diff --git a/README.md b/README.md index ef5bdc66ef03131318e1dde627e0224cca9137fd..0a309ebe2d828fc1934570b857d24fb289fcb832 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture lets you deploy computation to one +between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting code. TensorFlow also includes TensorBoard, a data visualization toolkit. @@ -22,6 +22,10 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. +Keep up to date with release announcements and security updates by +subscribing to +[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). + ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* diff --git a/RELEASE.md b/RELEASE.md index 6f54dee58f75c29a16545ba25de12fe059baf1eb..c63d9f20c9a842ceed97afc25690073d082c42cb 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,63 @@ +# Release 1.7.0 + +## Major Features And Improvements +* Eager mode is moving out of contrib, try `tf.enable_eager_execution()`. +* Graph rewrites emulating fixed-point quantization compatible with TensorFlow Lite, supported by new `tf.contrib.quantize` package. +* Easily customize gradient computation with `tf.custom_gradient`. +* [TensorBoard Debugger Plugin](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md), the graphical user interface (GUI) of TensorFlow Debugger (tfdbg), is now in alpha. +* Experimental support for reading a sqlite database as a `Dataset` with new `tf.contrib.data.SqlDataset`. +* Distributed Mutex / CriticalSection added to `tf.contrib.framework.CriticalSection`. +* Better text processing with `tf.regex_replace`. +* Easy, efficient sequence input with `tf.contrib.data.bucket_by_sequence_length` + +## Bug Fixes and Other Changes +* Accelerated Linear Algebra (XLA): + * Add `MaxPoolGradGrad` support for XLA + * CSE pass from Tensorflow is now disabled in XLA. +* `tf.data`: + * `tf.data.Dataset` + * Add support for building C++ Dataset op kernels as external libraries, using the `tf.load_op_library()` mechanism. + * `Dataset.list_files()` now shuffles its output by default. + * `Dataset.shuffle(..., seed=tf.constant(0, dtype=tf.int64))` now yields the same sequence of elements as `Dataset.shuffle(..., seed=0)`. + * Add `num_parallel_reads` argument to `tf.data.TFRecordDataset`. +* `tf.contrib`: + * `tf.contrib.bayesflow.halton_sequence` now supports randomization. + * Add support for scalars in `tf.contrib.all_reduce`. + * Add `effective_sample_size` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `potential_scale_reduction` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `BatchNormalization`, `Kumaraswamy` bijectors. + * Deprecate `tf.contrib.learn`. Please check contrib/learn/README.md for instructions on how to convert existing code. + * `tf.contrib.data` + * Remove deprecated `tf.contrib.data.Dataset`, `tf.contrib.data.Iterator`, `tf.contrib.data.FixedLengthRecordDataset`, `tf.contrib.data.TextLineDataset`, and `tf.contrib.data.TFRecordDataset` classes. + * Added `bucket_by_sequence_length`, `sliding_window_batch`, and `make_batched_features_dataset` + * Remove unmaintained `tf.contrib.ndlstm`. You can find it externally at https://github.com/tmbarchive/tfndlstm. + * Moved most of `tf.contrib.bayesflow` to its own repo: `tfp` +* Other: + * tf.py_func now reports the full stack trace if an exception occurs. + * Integrate `TPUClusterResolver` with GKE's integration for Cloud TPUs. + * Add a library for statistical testing of samplers. + * Add Helpers to stream data from the GCE VM to a Cloud TPU. + * Integrate ClusterResolvers with TPUEstimator. + * Unify metropolis_hastings interface with HMC kernel. + * Move LIBXSMM convolutions to a separate --define flag so that they are disabled by default. + * Fix `MomentumOptimizer` lambda. + * Reduce `tfp.layers` boilerplate via programmable docstrings. + * Add `auc_with_confidence_intervals`, a method for computing the AUC and confidence interval with linearithmic time complexity. + * `regression_head` now accepts customized link function, to satisfy the usage that user can define their own link function if the `array_ops.identity` does not meet the requirement. + * Fix `initialized_value` and `initial_value` behaviors for `ResourceVariables` created from `VariableDef` protos. + * Add TensorSpec to represent the specification of Tensors. + * Constant folding pass is now deterministic. + * Support `float16` `dtype` in `tf.linalg.*`. + * Add `tf.estimator.export.TensorServingInputReceiver` that allows `tf.estimator.Estimator.export_savedmodel` to pass raw tensors to model functions. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abe, Alistair Low, Andy Kernahan, Appledore, Ben, Ben Barsdell, Boris Pfahringer, Brad Wannow, Brett Koonce, Carl Thomé, cclauss, Chengzhi Chen, Chris Drake, Christopher Yeh, Clayne Robison, Codrut Grosu, Daniel Trebbien, Danny Goodman, David Goodwin, David Norman, Deron Eriksson, Donggeon Lim, Donny Viszneki, DosLin, DylanDmitri, Francisco Guerrero, Fred Reiss, gdh1995, Giuseppe, Glenn Weidner, gracehoney, Guozhong Zhuang, Haichen "Hc" Li, Harald Husum, harumitsu.nobuta, Henry Spivey, hsm207, Jekyll Song, Jerome, Jiongyan Zhang, jjsjann123, John Sungjin Park, Johnson145, JoshVarty, Julian Wolff, Jun Wang, June-One, Kamil Sindi, Kb Sriram, Kdavis-Mozilla, Kenji, lazypanda1, Liang-Chi Hsieh, Loo Rong Jie, Mahesh Bhosale, MandarJKulkarni, ManHyuk, Marcus Ong, Marshal Hayes, Martin Pool, matthieudelaro, mdfaijul, mholzel, Michael Zhou, Ming Li, Minmin Sun, Myungjoo Ham, MyungsungKwak, Naman Kamra, Peng Yu, Penghao Cen, Phil, Raghuraman-K, resec, Rohin Mohanadas, Sandeep N Gupta, Scott Tseng, seaotterman, Seo Sanghyeon, Sergei Lebedev, Ted Chang, terrytangyuan, Tim H, tkunic, Tod, vihanjain, Yan Facai (颜发才), Yin Li, Yong Tang, Yukun Chen, Yusuke Yamada + + + # Release 1.6.0 ## Breaking Changes diff --git a/tensorflow/SECURITY.md b/SECURITY.md similarity index 90% rename from tensorflow/SECURITY.md rename to SECURITY.md index fea24b273920885ba8a1ae96aafbf7710df46e1f..a5ce3a62ee202f6e7d83f0fedc2777d9c88ba9b5 100644 --- a/tensorflow/SECURITY.md +++ b/SECURITY.md @@ -6,7 +6,7 @@ report vulnerabilities in TensorFlow. ## TensorFlow models are programs -TensorFlow's runtime system interprets and executes programs. What machine +TensorFlow's runtime system interprets and executes programs. What machine learning practitioners term [**models**](https://developers.google.com/machine-learning/glossary/#model) are expressed as programs that TensorFlow executes. TensorFlow programs are encoded @@ -28,12 +28,12 @@ data you supply to TensorFlow to train a model, or to use a model to run inference on the data. **TensorFlow models are programs, and need to be treated as such from a security -perspective.** +perspective.** ## Running untrusted models As a general rule: **Always** execute untrusted models inside a sandbox (e.g., -[nsjail](https://github.com/google/nsjail)). +[nsjail](https://github.com/google/nsjail)). There are several ways in which a model could become untrusted. Obviously, if an untrusted party supplies TensorFlow kernels, arbitrary code may be executed. @@ -109,11 +109,11 @@ graphs known to the `ModelServer`. This means that an attacker may run graphs using untrusted inputs as described above, but they would not be able to execute arbitrary graphs. It is possible to safely expose a `ModelServer` directly to an untrusted network, **but only if the graphs it is configured to -use have been carefully audited to be safe**. +use have been carefully audited to be safe**. Similar to best practices for other servers, we recommend running any `ModelServer` with appropriate privileges (i.e., using a separate user with -reduced permisisons). In the spirit of defense in depth, we recommend +reduced permissions). In the spirit of defense in depth, we recommend authenticating requests to any TensorFlow server connected to an untrusted network, as well as sandboxing the server to minimize the adverse effects of any breach. @@ -129,11 +129,11 @@ with specially crafted inputs. ### What is a vulnerability? Given TensorFlow's flexibility, it is possible to specify computation graphs -which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models +which exhibit unexpected or unwanted behavior. The fact that TensorFlow models can perform arbitrary computations means that they may read and write files, communicate via the network, produce deadlocks and infinite loops, or run out of memory. It is only when these behaviors are outside the specifications of the -operations involved that such behavior is a vulnerability. +operations involved that such behavior is a vulnerability. A `FileWriter` writing a file is not unexpected behavior and therefore is not a vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution @@ -170,6 +170,17 @@ Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of the progress being made towards a fix and announcement. +In addition, please include the following information along with your report: + +* Your name and affiliation (if any). +* A description the technical details of the vulnerabilities. It is very + important to let us know how we can reproduce your findings. +* An explanation who can exploit this vulnerability, and what they gain when + doing so -- write an attack scenario. This will help us evaluate your report + quickly, especially if the issue is complex. +* Whether this vulnerability public or known to third parties. If it is, please + provide details. + If you believe that an existing (public) issue is security-related, please send an email to `security@tensorflow.org`. The email should include the issue ID and a short description of why it should be handled according to this security @@ -233,7 +244,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known vulnerabilities -| Type | Versions affected | Reported by | Additional Information | -|-------------------|:-----------------:|--------------------|-----------------------------| -| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +| Type | Versions affected | Reported by | Additional Information | +|--------------------|:-----------------:|-----------------------|-----------------------------| +| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/WORKSPACE b/WORKSPACE index 1e38a9a8cd754886fc5232531816b875de0879a3..11c5cdb2070e79b16540a39f13cab28608962340 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,12 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +# We must check the bazel version before trying to parse any other BUILD +# files, in case the parsing of those build files depends on the bazel +# version we require here. +load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") +check_bazel_version_at_least("0.10.0") + load("//tensorflow:workspace.bzl", "tf_workspace") # Uncomment and update the paths in these entries to build the Android demo. diff --git a/configure.py b/configure.py index 7d2e30cd8af53b74c9274e28c4b95bdd01e9d41a..6744082d5d55c3a039b7a4efa7a539e77185cabd 100644 --- a/configure.py +++ b/configure.py @@ -40,7 +40,7 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -250,7 +250,11 @@ def reset_tf_configure_bazelrc(workspace_path): if _TF_BAZELRC_FILENAME in l: continue f.write('%s\n' % l) - f.write('import %s\n' % _TF_BAZELRC) + if is_windows(): + tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") + else: + tf_bazelrc_path = _TF_BAZELRC + f.write('import %s\n' % tf_bazelrc_path) def cleanup_makefile(): @@ -444,7 +448,7 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', 'version']) + curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -498,7 +502,6 @@ def set_cc_opt_flags(environ_cp): write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -520,7 +523,7 @@ def set_tf_cuda_clang(environ_cp): def set_tf_download_clang(environ_cp): """Set TF_DOWNLOAD_CLANG action_env.""" - question = 'Do you want to download a fresh release of clang? (Experimental)' + question = 'Do you wish to download a fresh release of clang? (Experimental)' yes_reply = 'Clang will be downloaded and used to compile tensorflow.' no_reply = 'Clang will not be downloaded.' set_action_env_var( @@ -1044,7 +1047,10 @@ def set_tf_tensorrt_install_path(environ_cp): for lib_file in possible_files: if is_compatible(lib_file, cuda_ver, cudnn_ver): - ver_str = nvinfer_pattern.search(lib_file).group(1) + matches = nvinfer_pattern.search(lib_file) + if len(matches.groups()) == 0: + continue + ver_str = matches.group(1) ver = convert_version_to_int(ver_str) if len(ver_str) else 0 if ver > highest_ver[0]: highest_ver = [ver, ver_str, lib_file] @@ -1373,7 +1379,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.5.4') + check_bazel_version('0.10.0') reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() @@ -1390,6 +1396,9 @@ def main(): environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on + # Windows. + environ_cp['TF_DOWNLOAD_CLANG'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1404,7 +1413,7 @@ def main(): set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', - 'with_kafka_support', False, 'kafka') + 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1437,16 +1446,8 @@ def main(): set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': - if not is_windows(): - # Ask if we want to download clang release while building. - set_tf_download_clang(environ_cp) - else: - # We use bazel's generated crosstool on Windows and there is no - # way to provide downloaded toolchain for that yet. - # TODO(ibiryukov): Investigate using clang as a cuda compiler on - # Windows. - environ_cp['TF_DOWNLOAD_CLANG'] = '0' - + # Ask whether we should download the clang toolchain. + set_tf_download_clang(environ_cp) if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) @@ -1456,6 +1457,13 @@ def main(): if not is_windows(): set_gcc_host_compiler_path(environ_cp) set_other_cuda_vars(environ_cp) + else: + # CUDA not required. Ask whether we should download the clang toolchain and + # use it for the CPU build. + set_tf_download_clang(environ_cp) + if environ_cp.get('TF_DOWNLOAD_CLANG') == '1': + write_to_bazelrc('build --config=download_clang') + write_to_bazelrc('test --config=download_clang') set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index dc995d231d3e591771f801e28024a76610cdba26..31e64793de52a13530ebbf5ccc0e38cf570b16fd 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -240,6 +240,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_kafka_support_windows_override", + define_values = {"with_kafka_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -415,6 +422,17 @@ py_library( deps = ["//tensorflow/python"], ) +py_library( + name = "experimental_tensorflow_py", + srcs = ["experimental_api.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow/tools/api/tests:__subpackages__"], + deps = [ + "//tensorflow/python", + "//tensorflow/tools/api/generator:python_api", + ], +) + filegroup( name = "all_opensource_files", data = [ @@ -441,6 +459,7 @@ filegroup( "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", + "//tensorflow/compiler/xla/client/xla_client:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", "//tensorflow/compiler/xla/python:all_files", "//tensorflow/compiler/xla/service:all_files", @@ -455,6 +474,12 @@ filegroup( "//tensorflow/contrib:all_files", "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", + "//tensorflow/contrib/autograph:all_files", + "//tensorflow/contrib/autograph/converters:all_files", + "//tensorflow/contrib/autograph/impl:all_files", + "//tensorflow/contrib/autograph/pyct:all_files", + "//tensorflow/contrib/autograph/pyct/static_analysis:all_files", + "//tensorflow/contrib/autograph/utils:all_files", "//tensorflow/contrib/batching:all_files", "//tensorflow/contrib/bayesflow:all_files", "//tensorflow/contrib/boosted_trees:all_files", @@ -548,12 +573,6 @@ filegroup( "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/periodic_resample:all_files", "//tensorflow/contrib/predictor:all_files", - "//tensorflow/contrib/py2tf:all_files", - "//tensorflow/contrib/py2tf/converters:all_files", - "//tensorflow/contrib/py2tf/impl:all_files", - "//tensorflow/contrib/py2tf/pyct:all_files", - "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", - "//tensorflow/contrib/py2tf/utils:all_files", "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", @@ -598,6 +617,7 @@ filegroup( "//tensorflow/contrib/verbs:all_files", "//tensorflow/core:all_files", "//tensorflow/core/api_def:all_files", + "//tensorflow/core/common_runtime/eager:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", "//tensorflow/core/distributed_runtime/rpc:all_files", @@ -660,6 +680,7 @@ filegroup( "//tensorflow/python/kernel_tests/distributions:all_files", "//tensorflow/python/kernel_tests/linalg:all_files", "//tensorflow/python/kernel_tests/random:all_files", + "//tensorflow/python/kernel_tests/testdata:all_files", "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/ops/linalg:all_files", "//tensorflow/python/ops/losses:all_files", @@ -773,7 +794,7 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow/c:exported_symbols.lds", + "$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], @@ -782,11 +803,12 @@ tf_cc_shared_object( "-z defs", "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow/c:version_script.lds", + "$(location //tensorflow/c:version_script.lds)", ], }), deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", "//tensorflow/c:exported_symbols.lds", "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", @@ -799,7 +821,7 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow:tf_exported_symbols.lds", + "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], @@ -807,7 +829,7 @@ tf_cc_shared_object( "-z defs", "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow:tf_version_script.lds", + "$(location //tensorflow:tf_version_script.lds)", ], }), deps = [ diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 5dfb743681255d8c03e91ea43fd441d94fdee59d..426f97b84472ba475b7b16ea49b64b4671ba6e74 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -17,7 +17,10 @@ load( filegroup( name = "headers", - srcs = ["c_api.h"], + srcs = [ + "c_api.h", + "c_api_experimental.h", + ], visibility = ["//tensorflow:__subpackages__"], ) @@ -113,6 +116,10 @@ tf_cuda_library( ":c_api", ":c_api_internal", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/contrib/tpu:all_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) @@ -209,6 +216,27 @@ tf_cuda_cc_test( ], ) +tf_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["testdata/tf_record"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api_experimental", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "c_api_function_test", size = "small", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 85f1d1639b4d09f2de77d326481a86ec246270d0..18eeb2816807ec9986999cfc2c9a4c0f032683c0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -30,6 +30,7 @@ limitations under the License. #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eval_const_tensor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" @@ -62,6 +63,7 @@ limitations under the License. // brain namespace because we are defining 'extern "C"' functions. using tensorflow::AllocationDescription; using tensorflow::DataType; +using tensorflow::ExtendSessionGraphHelper; using tensorflow::Graph; using tensorflow::GraphDef; using tensorflow::mutex_lock; @@ -73,6 +75,7 @@ using tensorflow::NodeBuilder; using tensorflow::NodeDef; using tensorflow::OpDef; using tensorflow::OpRegistry; +using tensorflow::OutputTensor; using tensorflow::PartialTensorShape; using tensorflow::RunMetadata; using tensorflow::RunOptions; @@ -638,17 +641,17 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + const char* mutation_type) { // If any session has already run this node_id, mark this session as // unrunnable. for (auto it : graph->sessions) { + mutex_lock session_lock(it.first->mu); if (it.first->last_num_graph_nodes > op.node.id()) { - it.second = FailedPrecondition( + it.second = strings::StrCat( "Operation '", op.node.DebugString(), "' was changed by ", mutation_type, - " after it was run by a session. Nodes can be mutated " - "only before they are executed by a session. Either don't modify " + " after it was run by a session. This mutation will have no effect, " + "and will trigger an error in the future. Either don't modify " "nodes after running them or create a new session."); } } @@ -708,6 +711,61 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); +// TODO(josh11b,mrry): Change Session to be able to use a Graph* +// directly, instead of requiring us to serialize to a GraphDef and +// call Session::Extend(). +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { + if (session->graph != nullptr) { + // Take the graph lock before the session lock to avoid deadlock. This is + // safe since session->graph does not change. + session->graph->mu.lock(); + mutex_lock session_lock(session->mu); + const Graph& graph = session->graph->graph; + + const string& mutation_warning = session->graph->sessions[session]; + if (!mutation_warning.empty()) { + // TODO(b/74949947): turn this back into an error status + LOG(WARNING) << mutation_warning; + session->graph->sessions[session].clear(); + } + + const auto num_nodes = graph.num_node_ids(); + if (session->last_num_graph_nodes < num_nodes) { + status->status = tensorflow::ValidateNoCycles(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + + GraphDef graph_def; + *graph_def.mutable_versions() = graph.versions(); + // Fill graph_def with nodes with ids in the range + // [session->last_num_graph_nodes, num_nodes), that is the nodes + // added since the last TF_SessionRun() call. + for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { + Node* const node = graph.FindNodeId(id); + if (node != nullptr && node->IsOp()) { + NodeDef* const node_def = graph_def.add_node(); + *node_def = node->def(); + } + } + *graph_def.mutable_library() = graph.flib_def().ToProto(); + session->graph->mu.unlock(); + status->status = session->session->Extend(graph_def); + if (!status->status.ok()) { + // Contract is we always delete input_values[i]. + return false; + } + // Note: session->session is not modified if Extend() fails, so + // we only set last_num_graph_nodes if it succeeds. + session->last_num_graph_nodes = num_nodes; + } else { + session->graph->mu.unlock(); + } + } + return true; +} + } // namespace tensorflow static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, @@ -2408,11 +2466,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, // TF_Session functions ---------------------------------------------- TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) { - if (s->LocalDeviceManager(&device_mgr).ok()) { - devices = device_mgr->ListDevices(); - } -} + : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { @@ -2422,7 +2476,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->sessions[new_session] = Status::OK(); + graph->sessions[new_session] = ""; } return new_session; } else { @@ -2488,7 +2542,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->sessions[session] = Status::OK(); + graph->sessions[session] = ""; session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2512,58 +2566,6 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { delete s; } -// TODO(josh11b,mrry): Change Session to be able to use a Graph* -// directly, instead of requiring us to serialize to a GraphDef and -// call Session::Extend(). -static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { - if (session->graph != nullptr) { - mutex_lock session_lock(session->mu); - session->graph->mu.lock(); - const Graph& graph = session->graph->graph; - - status->status = session->graph->sessions[session]; - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - const auto num_nodes = graph.num_node_ids(); - if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - GraphDef graph_def; - *graph_def.mutable_versions() = graph.versions(); - // Fill graph_def with nodes with ids in the range - // [session->last_num_graph_nodes, num_nodes), that is the nodes - // added since the last TF_SessionRun() call. - for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { - Node* const node = graph.FindNodeId(id); - if (node != nullptr && node->IsOp()) { - NodeDef* const node_def = graph_def.add_node(); - *node_def = node->def(); - } - } - *graph_def.mutable_library() = graph.flib_def().ToProto(); - session->graph->mu.unlock(); - status->status = session->session->Extend(graph_def); - if (!status->status.ok()) { - // Contract is we always delete input_values[i]. - return false; - } - // Note: session->session is not modified if Extend() fails, so - // we only set last_num_graph_nodes if it succeeds. - session->last_num_graph_nodes = num_nodes; - } else { - session->graph->mu.unlock(); - } - } - return true; -} - void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, const TF_Output* outputs, @@ -2573,7 +2575,8 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2610,7 +2613,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, const char** handle, TF_Status* status) { *handle = nullptr; - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2653,7 +2657,8 @@ void TF_SessionPRun(TF_Session* session, const char* handle, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2682,6 +2687,24 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, + TF_Tensor** result, TF_Status* status) { + *result = nullptr; + mutex_lock l(graph->mu); + OutputTensor tensor(&output.oper->node, output.index); + bool evaluated; + Tensor result_tensor; + status->status = EvaluateConstantTensor( + tensor, graph->refiner, *graph->graph.op_registry(), + graph->graph.versions().producer(), &evaluated, &result_tensor); + if (evaluated) { + DCHECK(status->status.ok()); + *result = TF_TensorFromTensor(result_tensor, status); + if (!status->status.ok()) evaluated = false; + } + return evaluated; +} + TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { tensorflow::OpList op_list; if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index ad592ef70961ef427bfe9fd322a82bd64df7f9f1..b32f574628c4d1dc5c3bb3f1265a1b12adee28bc 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1275,13 +1275,22 @@ TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( // Deleting a function does not remove it from any graphs it was copied to. TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); +// Attempts to evaluate `output`. This will only be possible if `output` doesn't +// depend on any graph inputs (this function is safe to call if this isn't the +// case though). +// +// If the evaluation is successful, this function returns true and `output`s +// value is returned in `result`. Otherwise returns false. An error status is +// returned if something is wrong with the graph or input. Note that this may +// return false even if no error status is set. +TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, + TF_Output output, + TF_Tensor** result, + TF_Status* status); + // TODO(josh11b): Register OpDef, available to all operations added // to this graph. -// The following two may both benefit from a subgraph-definition API -// that re-uses most of the graph-definition API. -// TODO(andydavis): Add functions to a graph. - // -------------------------------------------------------------------------- // API for driving Graph execution. diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index be7f85a5bb06dce84579b109d506ded049042b50..bea93785717e2161fcec941485ac3c3f7f3e3ed5 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,8 +17,26 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/protobuf/config.pb.h" +using tensorflow::FunctionDef; +using tensorflow::Node; +using tensorflow::NodeBuilder; +using tensorflow::Status; + +namespace { +typedef std::unique_ptr + UniqueFuncPtr; +} + +// struct TF_Operation { tensorflow::Node node; }; +static TF_Operation* ToTF_Operation(Node* node) { + return static_cast(static_cast(node)); +} + void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { tensorflow::ConfigProto& config = options->options.config; auto* optimizer_options = @@ -37,3 +55,8340 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); } } + +void TF_InitializeTPU(TF_Session* session, TF_Status* status) { + VLOG(1) << "Initializing TPU"; + TF_Operation* config_op = + TF_GraphOperationByName(session->graph, "ConfigureDistributedTPU"); + if (config_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find node ConfigureDistributedTPU in the TF graph."); + return; + } + + TF_Output config_node{config_op, 0}; + + TF_Tensor* dummy_output; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ &config_node, /*output_values*/ &dummy_output, + /*noutputs*/ 1, + /*targets*/ nullptr, /*ntargets*/ 0, + /*run_metadata*/ nullptr, status); + if (status->status.ok()) { + TF_DeleteTensor(dummy_output); + } +} + +void TF_ShutdownTPU(TF_Session* session, TF_Status* status) { + { + tensorflow::mutex_lock c(session->graph->mu); + VLOG(1) << "Shutting down TPU, with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + } + + TF_Operation* shutdown_op = + TF_GraphOperationByName(session->graph, "ShutdownDistributedTPU"); + if (shutdown_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find node ShutdownDistributedTPU in the TF graph."); + return; + } + + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, + /*noutputs*/ 0, + /*targets*/ &shutdown_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); +} + +const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { + tensorflow::mutex_lock c(graph->mu); + const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + +// On success, returns a set of TF_Function instances from `text_proto` of +// GraphDef type. These functions must be deleted by calling TF_DeleteFunction. +// +// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto, +// before creating a TF_Function out of the possibly mutated proto. +static std::vector CreateFunctionsFromTextProto( + const char* text_proto, + std::function* mutate_proto_func, TF_Status* status) { + tensorflow::GraphDef gdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for GraphDef: ", text_proto); + return {}; + } + const auto& fdef_lib = gdef.library(); + if (fdef_lib.gradient_size() > 0) { + status->status = tensorflow::errors::Internal( + "GradientDef is not supported in reading Dataset related functions: ", + text_proto); + return {}; + } + std::vector ret; + for (const FunctionDef& fdef : fdef_lib.function()) { + // Make a copy so that we can mutate it. + FunctionDef fdef_to_load = fdef; + if (mutate_proto_func) { + (*mutate_proto_func)(&fdef_to_load); + } + VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString(); + std::vector binary_proto_buf(fdef_to_load.ByteSizeLong()); + fdef_to_load.SerializeToArray(binary_proto_buf.data(), + binary_proto_buf.size()); + TF_Function* func = TF_FunctionImportFunctionDef( + binary_proto_buf.data(), binary_proto_buf.size(), status); + if (!status->status.ok()) return {}; + ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction)); + } + return ret; +} + +// On success, returns a newly created TF_Function instance encoding a dataset +// node stack that returns a sequence of 3 floats, and sets `dataset_name` to +// the created dataset name. The returned function must be deleted by calling +// TF_DeleteFunction. +static UniqueFuncPtr CreateFakeDatasetFunction(std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "_make_dataset_d8de2712" + output_arg { + name: "TensorSliceDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000(B\000\000,B\000\0000B" + } + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/tensors/component_0:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key: "TensorSliceDataset" + value: "TensorSliceDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_d8de2712"; + auto functions = CreateFunctionsFromTextProto( + func_def, /*mutate_proto_func*/ nullptr, status); + DCHECK_EQ(functions.size(), 1); + return std::move(functions[0]); +} + +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateImagenetDatasetFunctions( + const char* file_path, std::string* dataset_name, TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_91295dea" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "FlatMapDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "flat_filenames/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node_def { + name: "flat_filenames" + op: "Reshape" + input: "arg0" + input: "flat_filenames/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "flat_filenames:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "FlatMapDataset" + op: "FlatMapDataset" + input: "TensorSliceDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_0cc8c35b" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + ret { + key: "FlatMapDataset" + value: "FlatMapDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_0cc8c35b" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "TFRecordDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "compression_type" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8388608 + } + } + } + } + node_def { + name: "TFRecordDataset" + op: "TFRecordDataset" + input: "arg0" + input: "compression_type:output:0" + input: "buffer_size:output:0" + } + ret { + key: "TFRecordDataset" + value: "TFRecordDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_74b6b15c" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "Reshape_1" + type: DT_FLOAT + } + output_arg { + name: "sub_1" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "ParseSingleExample/key_image/class/label" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape" + op: "Reshape" + input: "ParseSingleExample/key_image/class/label:output:0" + input: "ParseSingleExample/Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/class/text" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1" + op: "Reshape" + input: "ParseSingleExample/key_image/class/text:output:0" + input: "ParseSingleExample/Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/encoded" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2" + op: "Reshape" + input: "ParseSingleExample/key_image/encoded:output:0" + input: "ParseSingleExample/Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/format" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "jpeg" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3" + op: "Reshape" + input: "ParseSingleExample/key_image/format:output:0" + input: "ParseSingleExample/Reshape_3/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/ParseSingleExample" + op: "ParseSingleExample" + input: "arg0" + input: "ParseSingleExample/Reshape:output:0" + input: "ParseSingleExample/Reshape_1:output:0" + input: "ParseSingleExample/Reshape_2:output:0" + input: "ParseSingleExample/Reshape_3:output:0" + attr { + key: "Tdense" + value { + list { + type: DT_INT64 + type: DT_STRING + type: DT_STRING + type: DT_STRING + } + } + } + attr { + key: "dense_keys" + value { + list { + s: "image/class/label" + s: "image/class/text" + s: "image/encoded" + s: "image/format" + } + } + } + attr { + key: "dense_shapes" + value { + list { + shape { + } + shape { + } + shape { + } + shape { + } + } + } + } + attr { + key: "num_sparse" + value { + i: 5 + } + } + attr { + key: "sparse_keys" + value { + list { + s: "image/object/bbox/xmax" + s: "image/object/bbox/xmin" + s: "image/object/bbox/ymax" + s: "image/object/bbox/ymin" + s: "image/object/class/label" + } + } + } + attr { + key: "sparse_types" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_INT64 + } + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:2" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/Substr/pos:output:0" + input: "decode_image/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/is_jpeg/Substr/pos:output:0" + input: "decode_image/is_jpeg/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\377\330\377" + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal" + op: "Equal" + input: "decode_image/is_jpeg/Substr:output:0" + input: "decode_image/is_jpeg/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/Switch" + op: "Switch" + input: "decode_image/is_jpeg/Equal:z:0" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/pred_id" + op: "Identity" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/check_jpeg_channels/x:output:0" + input: "decode_image/cond_jpeg/check_jpeg_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/check_jpeg_channels:z:0" + input: "decode_image/cond_jpeg/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg" + op: "DecodeJpeg" + input: "decode_image/cond_jpeg/DecodeJpeg/Switch:output_true:0" + input: "^decode_image/cond_jpeg/Assert/Assert" + attr { + key: "acceptable_fraction" + value { + f: 1.0 + } + } + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dct_method" + value { + s: "" + } + } + attr { + key: "fancy_upscaling" + value { + b: true + } + } + attr { + key: "ratio" + value { + i: 1 + } + } + attr { + key: "try_recover_truncated" + value { + b: false + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\211PN" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png" + op: "Equal" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/is_png/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/Switch" + op: "Switch" + input: "decode_image/Substr:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png:z:0" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng" + op: "DecodePng" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1:output_true:0" + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dtype" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "GIF" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/is_gif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/is_gif/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd" + op: "LogicalAnd" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1:z:0" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif" + op: "DecodeGif" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1:output_true:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr" + op: "Substr" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "BM" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp" + op: "DecodeBmp" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + attr { + key: "channels" + value { + i: 0 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp:image:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Merge:output:0" + input: "decode_image/cond_jpeg/cond_png/DecodePng:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/Merge:output:0" + input: "decode_image/cond_jpeg/DecodeJpeg:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/Cast" + op: "Cast" + input: "decode_image/cond_jpeg/Merge:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.00392156885937 + } + } + } + } + node_def { + name: "convert_image" + op: "Mul" + input: "convert_image/Cast:y:0" + input: "convert_image/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 4 + } + } + tensor_content: "\000\000\000\000\000\000\000\000\000\000\200?\000\000\200?" + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.10000000149 + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2" + op: "SampleDistortedBoundingBoxV2" + input: "distorted_bounding_box_crop/Shape:output:0" + input: "Const:output:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "area_range" + value { + list { + f: 0.0799999982119 + f: 1.0 + } + } + } + attr { + key: "aspect_ratio_range" + value { + list { + f: 0.75 + f: 1.33333337307 + } + } + } + attr { + key: "max_attempts" + value { + i: 1 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "use_image_if_no_bounding_boxes" + value { + b: true + } + } + } + node_def { + name: "distorted_bounding_box_crop/Slice" + op: "Slice" + input: "convert_image:z:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:begin:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:size:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Shape_1" + op: "Shape" + input: "distorted_bounding_box_crop/Slice:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "Shape:output:0" + input: "Shape_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "Equal:z:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + } + node_def { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "Sum" + op: "Sum" + input: "Cast:y:0" + input: "Const_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node_def { + name: "GreaterEqual/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "GreaterEqual" + op: "GreaterEqual" + input: "Sum:output:0" + input: "GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Switch" + op: "Switch" + input: "GreaterEqual:z:0" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_t" + op: "Identity" + input: "cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_f" + op: "Identity" + input: "cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/pred_id" + op: "Identity" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/Shape" + op: "Shape" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Shape/Switch" + op: "Switch" + input: "convert_image:z:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@convert_image" + } + } + } + } + node_def { + name: "cond/Cast" + op: "Cast" + input: "cond/Shape:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice/stack:output:0" + input: "cond/strided_slice/stack_1:output:0" + input: "cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/strided_slice_1/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice_1/stack:output:0" + input: "cond/strided_slice_1/stack_1:output:0" + input: "cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Greater" + op: "Greater" + input: "cond/strided_slice:output:0" + input: "cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Switch" + op: "Switch" + input: "cond/Greater:z:0" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_t" + op: "Identity" + input: "cond/cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_f" + op: "Identity" + input: "cond/cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/pred_id" + op: "Identity" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/strided_slice/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice/stack:output:0" + input: "cond/cond/strided_slice/stack_1:output:0" + input: "cond/cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice_1/stack:output:0" + input: "cond/cond/strided_slice_1/stack_1:output:0" + input: "cond/cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv" + op: "RealDiv" + input: "cond/cond/strided_slice:output:0" + input: "cond/cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul/y" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul" + op: "Mul" + input: "cond/cond/truediv:z:0" + input: "cond/cond/mul/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast/x/1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast/x" + op: "Pack" + input: "cond/cond/mul:z:0" + input: "cond/cond/Cast/x/1:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast" + op: "Cast" + input: "cond/cond/Cast/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_2/stack:output:0" + input: "cond/cond/strided_slice_2/stack_1:output:0" + input: "cond/cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice_2/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_3/stack:output:0" + input: "cond/cond/strided_slice_3/stack_1:output:0" + input: "cond/cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv_1" + op: "RealDiv" + input: "cond/cond/strided_slice_2:output:0" + input: "cond/cond/strided_slice_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul_1/y" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul_1" + op: "Mul" + input: "cond/cond/truediv_1:z:0" + input: "cond/cond/mul_1/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast_1/x/0" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast_1/x" + op: "Pack" + input: "cond/cond/Cast_1/x/0:output:0" + input: "cond/cond/mul_1:z:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast_1" + op: "Cast" + input: "cond/cond/Cast_1/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Merge" + op: "Merge" + input: "cond/cond/Cast_1:y:0" + input: "cond/cond/Cast:y:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic/images" + op: "Pack" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic" + op: "ResizeBicubic" + input: "cond/ResizeBicubic/images:output:0" + input: "cond/cond/Merge:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_2/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2" + op: "StridedSlice" + input: "cond/ResizeBicubic:resized_images:0" + input: "cond/strided_slice_2/stack:output:0" + input: "cond/strided_slice_2/stack_1:output:0" + input: "cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_1" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_3/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3" + op: "StridedSlice" + input: "cond/Shape_1:output:0" + input: "cond/strided_slice_3/stack:output:0" + input: "cond/strided_slice_3/stack_1:output:0" + input: "cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_2" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_4/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4" + op: "StridedSlice" + input: "cond/Shape_2:output:0" + input: "cond/strided_slice_4/stack:output:0" + input: "cond/strided_slice_4/stack_1:output:0" + input: "cond/strided_slice_4/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/sub/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub" + op: "Sub" + input: "cond/strided_slice_3:output:0" + input: "cond/sub/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add" + op: "Add" + input: "cond/sub:z:0" + input: "cond/add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv/Cast" + op: "Cast" + input: "cond/add:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/Cast_1" + op: "Cast" + input: "cond/truediv/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv" + op: "RealDiv" + input: "cond/truediv/Cast:y:0" + input: "cond/truediv/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/sub_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub_1" + op: "Sub" + input: "cond/strided_slice_4:output:0" + input: "cond/sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add_1" + op: "Add" + input: "cond/sub_1:z:0" + input: "cond/add_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv_1/Cast" + op: "Cast" + input: "cond/add_1:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/Cast_1" + op: "Cast" + input: "cond/truediv_1/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1" + op: "RealDiv" + input: "cond/truediv_1/Cast:y:0" + input: "cond/truediv_1/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Shape_3" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Rank" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal" + op: "Equal" + input: "cond/Rank:output:0" + input: "cond/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Assert/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert" + op: "Assert" + input: "cond/Equal:z:0" + input: "cond/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/strided_slice_5/stack" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_2" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_5" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_5/stack:output:0" + input: "cond/strided_slice_5/stack_1:output:0" + input: "cond/strided_slice_5/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/stack/0" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack/1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack" + op: "Pack" + input: "cond/stack/0:output:0" + input: "cond/stack/1:output:0" + input: "cond/strided_slice_5:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/strided_slice_6/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_6/stack:output:0" + input: "cond/strided_slice_6/stack_1:output:0" + input: "cond/strided_slice_6/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual" + op: "GreaterEqual" + input: "cond/strided_slice_6:output:0" + input: "cond/GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_7/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_7/stack:output:0" + input: "cond/strided_slice_7/stack_1:output:0" + input: "cond/strided_slice_7/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual_1" + op: "GreaterEqual" + input: "cond/strided_slice_7:output:0" + input: "cond/GreaterEqual_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/LogicalAnd" + op: "LogicalAnd" + input: "cond/GreaterEqual:z:0" + input: "cond/GreaterEqual_1:z:0" + } + node_def { + name: "cond/Assert_1/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert" + op: "Assert" + input: "cond/LogicalAnd:z:0" + input: "cond/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/stack_1/2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_DOUBLE + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_DOUBLE + tensor_shape { + } + double_val: 0.0 + } + } + } + } + node_def { + name: "cond/stack_1" + op: "Pack" + input: "cond/truediv:z:0" + input: "cond/truediv_1:z:0" + input: "cond/stack_1/2:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ToInt32" + op: "Cast" + input: "cond/stack_1:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Slice" + op: "Slice" + input: "cond/strided_slice_2:output:0" + input: "cond/ToInt32:y:0" + input: "cond/stack:output:0" + input: "^cond/Assert_1/Assert" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/Reshape" + op: "Reshape" + input: "cond/Slice:output:0" + input: "cond/stack:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images" + op: "Pack" + input: "cond/ResizeBicubic_1/images/Switch:output_false:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images/Switch" + op: "Switch" + input: "distorted_bounding_box_crop/Slice:output:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@distorted_bounding_box_crop/Slice" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1/size" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\340\000\000\000\340\000\000\000" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1" + op: "ResizeBicubic" + input: "cond/ResizeBicubic_1/images:output:0" + input: "cond/ResizeBicubic_1/size:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_8/stack" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_1" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_2" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8" + op: "StridedSlice" + input: "cond/ResizeBicubic_1:resized_images:0" + input: "cond/strided_slice_8/stack:output:0" + input: "cond/strided_slice_8/stack_1:output:0" + input: "cond/strided_slice_8/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Merge" + op: "Merge" + input: "cond/strided_slice_8:output:0" + input: "cond/Reshape:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\354Q\370>\325x\351>;\337\317>" + } + } + } + } + node_def { + name: "sub" + op: "Sub" + input: "cond/Merge:output:0" + input: "Const_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\372~j>B`e>fff>" + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "sub:z:0" + input: "Const_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/control_dependency" + op: "Identity" + input: "truediv:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/min" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/max" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/RandomUniform" + op: "RandomUniform" + input: "random_flip_left_right/random_uniform/shape:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/sub" + op: "Sub" + input: "random_flip_left_right/random_uniform/max:output:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/mul" + op: "Mul" + input: "random_flip_left_right/random_uniform/RandomUniform:output:0" + input: "random_flip_left_right/random_uniform/sub:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform" + op: "Add" + input: "random_flip_left_right/random_uniform/mul:z:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Less/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node_def { + name: "random_flip_left_right/Less" + op: "Less" + input: "random_flip_left_right/random_uniform:z:0" + input: "random_flip_left_right/Less/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Switch" + op: "Switch" + input: "random_flip_left_right/Less:z:0" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_t" + op: "Identity" + input: "random_flip_left_right/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_f" + op: "Identity" + input: "random_flip_left_right/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/pred_id" + op: "Identity" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/axis" + op: "Const" + input: "^random_flip_left_right/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2" + op: "ReverseV2" + input: "random_flip_left_right/ReverseV2/Switch:output_true:0" + input: "random_flip_left_right/ReverseV2/axis:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/Switch" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Switch_1" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Merge" + op: "Merge" + input: "random_flip_left_right/Switch_1:output_false:0" + input: "random_flip_left_right/ReverseV2:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\340\000\000\000\340\000\000\000\003\000\000\000" + } + } + } + } + node_def { + name: "Reshape_1" + op: "Reshape" + input: "random_flip_left_right/Merge:output:0" + input: "Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape_2" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:0" + input: "Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast_1" + op: "Cast" + input: "Reshape_2:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_INT64 + } + } + } + node_def { + name: "sub_1/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "sub_1" + op: "Sub" + input: "Cast_1:y:0" + input: "sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "Reshape_1" + value: "Reshape_1:output:0" + } + ret { + key: "sub_1" + value: "sub_1:z:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_5fa5e1f4" + output_arg { + name: "PrefetchDataset_1" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "TensorSliceDataset/MatchingFiles" + op: "MatchingFiles" + input: "TensorSliceDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/MatchingFiles:filenames:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles" + op: "MatchingFiles" + input: "ShuffleDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "ShuffleDataset/Shape" + op: "Shape" + input: "ShuffleDataset/MatchingFiles:filenames:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice" + op: "StridedSlice" + input: "ShuffleDataset/Shape:output:0" + input: "ShuffleDataset/strided_slice/stack:output:0" + input: "ShuffleDataset/strided_slice/stack_1:output:0" + input: "ShuffleDataset/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "ShuffleDataset/Maximum/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/Maximum" + op: "Maximum" + input: "ShuffleDataset/strided_slice:output:0" + input: "ShuffleDataset/Maximum/y:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "TensorSliceDataset:handle:0" + input: "ShuffleDataset/Maximum:z:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ShuffleDataset_1/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed2_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1" + op: "ShuffleDataset" + input: "ShuffleDataset:handle:0" + input: "ShuffleDataset_1/buffer_size:output:0" + input: "ShuffleDataset_1/seed_1:output:0" + input: "ShuffleDataset_1/seed2_1:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "ShuffleDataset_1:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/cycle_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/block_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/sloppy" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: true + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/buffer_output_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/prefetch_input_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset" + op: "ParallelInterleaveDataset" + input: "RepeatDataset:handle:0" + input: "ParallelInterleaveDataset/cycle_length:output:0" + input: "ParallelInterleaveDataset/block_length:output:0" + input: "ParallelInterleaveDataset/sloppy:output:0" + input: "ParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_91295dea" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ShuffleDataset_2/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed2_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2" + op: "ShuffleDataset" + input: "ParallelInterleaveDataset:handle:0" + input: "ShuffleDataset_2/buffer_size_1:output:0" + input: "ShuffleDataset_2/seed_2:output:0" + input: "ShuffleDataset_2/seed2_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ParallelMapDataset/num_parallel_calls" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 64 + } + } + } + } + node_def { + name: "ParallelMapDataset" + op: "ParallelMapDataset" + input: "ShuffleDataset_2:handle:0" + input: "ParallelMapDataset/num_parallel_calls:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_74b6b15c" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "PrefetchDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "PrefetchDataset" + op: "PrefetchDataset" + input: "ParallelMapDataset:handle:0" + input: "PrefetchDataset/buffer_size_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "PrefetchDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + node_def { + name: "PrefetchDataset_1/buffer_size_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "PrefetchDataset_1" + op: "PrefetchDataset" + input: "FilterDataset:handle:0" + input: "PrefetchDataset_1/buffer_size_3:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 64 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: 64 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + ret { + key: "PrefetchDataset_1" + value: "PrefetchDataset_1:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_5fa5e1f4"; + std::function mutate_proto_func = + [dataset_name, file_path](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found = false; + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() != "TensorSliceDataset/MatchingFiles/pattern" && + node_def.name() != "ShuffleDataset/MatchingFiles/pattern") + continue; + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found = true; + DCHECK_EQ(node_def.attr().at("value").tensor().string_val(0), + "$(DATA_DIR)"); + VLOG(1) << "Setting the value of node_def " + "TensorSliceDataset/MatchingFiles/pattern to " + << file_path; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(file_path); + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +} + +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads an MNIST file dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateMNISTDatasetFunctions( + const char* file_path, int batch_size, std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_521bfd08" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "truediv" + type: DT_FLOAT + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "DecodeRaw:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 784 + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "Cast:y:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "truediv/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 255.0 + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "Reshape:output:0" + input: "truediv/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "truediv" + value: "truediv:z:0" + } + } + function { + signature { + name: "tf_map_func_9a08860d" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "ToInt32" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "DecodeRaw:output:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_UINT8 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ToInt32" + op: "Cast" + input: "Reshape:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + ret { + key: "ToInt32" + value: "ToInt32:y:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_2451e43a" + output_arg { + name: "FilterDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "FixedLengthRecordDataset/filenames" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-images-idx3-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/header_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/record_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 784 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/footer_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset/filenames:output:0" + input: "FixedLengthRecordDataset/header_bytes:output:0" + input: "FixedLengthRecordDataset/record_bytes:output:0" + input: "FixedLengthRecordDataset/footer_bytes:output:0" + input: "FixedLengthRecordDataset/buffer_size:output:0" + } + node_def { + name: "MapDataset" + op: "MapDataset" + input: "FixedLengthRecordDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_521bfd08" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/filenames_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-labels-idx1-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/header_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/record_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/footer_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset_1/filenames_1:output:0" + input: "FixedLengthRecordDataset_1/header_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/record_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/footer_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/buffer_size_1:output:0" + } + node_def { + name: "MapDataset_1" + op: "MapDataset" + input: "FixedLengthRecordDataset_1:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_9a08860d" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + } + node_def { + name: "ZipDataset" + op: "ZipDataset" + input: "MapDataset:handle:0" + input: "MapDataset_1:handle:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "CacheDataset/filename" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "CacheDataset" + op: "CacheDataset" + input: "ZipDataset:handle:0" + input: "CacheDataset/filename:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "CacheDataset:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "ShuffleDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 50000 + } + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "RepeatDataset:handle:0" + input: "ShuffleDataset/buffer_size_2:output:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "ShuffleDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + ret { + key: "FilterDataset" + value: "FilterDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_2451e43a"; + std::function mutate_proto_func = + [dataset_name, file_path, batch_size](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found_file_path = false, found_batch_size = false; + // `node_def` may be mutated. + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() == "FixedLengthRecordDataset/filenames" || + node_def.name() == "FixedLengthRecordDataset_1/filenames_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_file_path = true; + // Replace $(DATA_DIR)/foo with /foo + // TODO(hongm): Use StringPiece manipulation for better efficiency. + const std::string cur_value = + node_def.attr().at("value").tensor().string_val(0); + const std::string pattern = "$(DATA_DIR)"; + DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); + const std::string new_value = + file_path + cur_value.substr(pattern.length()); + VLOG(1) << "Setting the value of node_def " << node_def.name() + << " to " << new_value; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(new_value); + } else if (node_def.name() == "BatchDataset/batch_size" || + node_def.name() == "FilterDataset/batch_size_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_batch_size = true; + // Replace $(BATCH_SIZE) with `batch_size` + DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123); + VLOG(1) << "Setting the batch size attr value of node_def " + << node_def.name() << " to " << batch_size; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_int64_val(); + tensor->add_int64_val(batch_size); + } + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found_file_path); + DCHECK(found_batch_size); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +} + +// Adds the input functions to `graph`. On success, returns the created +// IteratorGetNext node. +static TF_Operation* AddDatasetFunctionAndIteratorNodesToGraph( + const std::vector& funcs, const std::string& dataset_name, + const std::vector& output_types, + const std::vector& output_shapes, + TF_Graph* graph, TF_Status* status) { + DCHECK(!dataset_name.empty()); + for (auto& func : funcs) { + TF_GraphCopyFunction(graph, func.get(), /*gradient*/ nullptr, status); + if (!status->status.ok()) { + return nullptr; + } + } + + tensorflow::mutex_lock c(graph->mu); + + tensorflow::NameAttrList func; + func.set_name(dataset_name); + // Run the iterator node on CPU. + Node* oneshot_iterator_node; + tensorflow::Status s = NodeBuilder("OneShotIterator", "OneShotIterator") + .Device("/device:CPU:0") + .Attr("container", "") + .Attr("dataset_factory", func) + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Attr("shared_name", "") + .Finalize(&graph->graph, &oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + // Run the iterator node on CPU. + Node* getnext_node; + s = NodeBuilder("IteratorGetNext", "IteratorGetNext") + .Input(oneshot_iterator_node) + .Device("/device:CPU:0") + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Finalize(&graph->graph, &getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + VLOG(1) << "Output graph: " << graph->graph.ToGraphDefDebug().DebugString(); + return ToTF_Operation(getnext_node); +} + +TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph, + TF_Status* status) { + tensorflow::Status s; + + std::string dataset_name; + UniqueFuncPtr result_func = CreateFakeDatasetFunction(&dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector funcs; + funcs.push_back(std::move(result_func)); + std::vector output_shape_list; + output_shape_list.push_back(tensorflow::TensorShapeProto()); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT}, output_shape_list, graph, + status); + if (!status->status.ok()) { + return nullptr; + } + + return getnext_node; +} + +TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status) { + tensorflow::Status s; + + std::string dataset_name; + const auto& funcs = + is_mnist + ? CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name, + status) + : CreateImagenetDatasetFunctions(file_path, &dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector output_shape_list; + // batch_size X 224 X 224 X 3 + auto image_shape = tensorflow::TensorShapeProto(); + image_shape.add_dim()->set_size(batch_size); + if (is_mnist) { + image_shape.add_dim()->set_size(784); + } else { + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(3); + } + output_shape_list.push_back(image_shape); + + // batch_size + auto label_shape = tensorflow::TensorShapeProto(); + label_shape.add_dim()->set_size(batch_size); + output_shape_list.push_back(label_shape); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT, tensorflow::DT_INT32}, + output_shape_list, graph, status); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::mutex_lock c(graph->mu); + VLOG(1) << "The extended graph: " + << graph->graph.ToGraphDefDebug().DebugString(); + + return getnext_node; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 5a7b007e40aa199889b2d00b2bde5976c19e2966..ebcec8176b63f9a91c847ebe96fba3ff023fc599 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -25,6 +25,7 @@ limitations under the License. // Experimental C API for TensorFlow. // // The API here is subject to changes in the future. +// -------------------------------------------------------------------------- // Macro to control visibility of exported symbols in the shared library (.so, // .dylib, .dll). @@ -59,6 +60,53 @@ extern "C" { TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Initializes TPU system. Must be called exactly once before TF_SessionRun() is +// called on a TPU graph. +// +// The session graph must contain a node named ConfigureDistributedTPU. +// TODO(b/74774824): Improve the API on initializing TPU system. +TF_CAPI_EXPORT extern void TF_InitializeTPU(TF_Session* session, + TF_Status* status); + +// Shuts down TPU system. For any `session` where TF_InitializeTPU() has +// been successfully called, this call must be made exactly once before the +// session is closed. +// The session graph must contain a node named ShutdownDistributedTPU. +TF_CAPI_EXPORT extern void TF_ShutdownTPU(TF_Session* session, + TF_Status* status); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Creates a stack of data set + iterator nodes, currently hard-coded to return +// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, +// returns the IteratorGetNext node, which caller can run or feed into an node. +// +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets( + TF_Graph* graph, TF_Status* status); + +// Similar to the above API, except that the returned iterator reads the +// file based dataset from `file_path`. +// If `is_mnist` is 0, the dataset corresponds to ImageNet. +// The iterators outputs 2 tensors: +// - A float tensor of shape `batch_size` X 784 when `is_mnist` is non-zero, or +// `batch_size` X 224 X 224 X 3 otherwise. +// - An int32 tensor of shape `batch_size` +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30fcfd401d9d634962d64aaa3bf348de91f2ecae --- /dev/null +++ b/tensorflow/c/c_api_experimental_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void TestFakeIteratorStack() { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + const float base_value = 42.0; + for (int i = 0; i < 3; ++i) { + csession.SetOutputs({get_next}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(out)); + ASSERT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(float), TF_TensorByteSize(out)); + float* output_contents = static_cast(TF_TensorData(out)); + ASSERT_EQ(base_value + i, *output_contents); + } + + // This should error out since we've exhausted the iterator. + csession.Run(s); + ASSERT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)) << TF_Message(s); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); } + +TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + const string file_path = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record"); + VLOG(1) << "data file path is " << file_path; + const int batch_size = 64; + TF_Operation* get_next = TF_MakeFileBasedIteratorGetNextWithDatasets( + graph, file_path.c_str(), batch_size, /*is_mnist*/ false, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + // The two output tensors should look like: + // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32) + // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32) + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Running iter " << i; + csession.SetOutputs({{get_next, 0}, {get_next, 1}}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + { + TF_Tensor* image = csession.output_tensor(0); + ASSERT_TRUE(image != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(image)); + // Confirm shape is 224 X 224 X 3 + ASSERT_EQ(4, TF_NumDims(image)); + ASSERT_EQ(batch_size, TF_Dim(image, 0)); + ASSERT_EQ(224, TF_Dim(image, 1)); + ASSERT_EQ(224, TF_Dim(image, 2)); + ASSERT_EQ(3, TF_Dim(image, 3)); + ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3, + TF_TensorByteSize(image)); + } + + { + TF_Tensor* label = csession.output_tensor(1); + ASSERT_TRUE(label != nullptr); + ASSERT_EQ(TF_INT32, TF_TensorType(label)); + ASSERT_EQ(1, TF_NumDims(label)); + ASSERT_EQ(batch_size, TF_Dim(label, 0)); + ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label)); + } + } + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 7ca50119eafe299b307f06c555aec1388e7e82e2..610274696f5940c063e68f2310cfd9cc1e0bd964 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 91667056e0eeb224b4b8a034766f11a123cd1a03..95652a11378d6276b5ba6540a07baa15aa77cc1c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -84,19 +84,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // The keys of this map are all the active sessions using this graph. - // Each value is the current "runnability" status of the corresponding - // session. Under normal conditions all statuses are Status::OK(), but - // if some operation is mutated after it was run by a session (this - // is detected in RecordMutation function), that session is no longer - // safe to run. Its status will contain the error that will be returned - // to the user, should she try running this session. + // The keys of this map are all the active sessions using this graph. Each + // value records whether the graph has been mutated since the corresponding + // session has been run (this is detected in RecordMutation function). If the + // string is empty, no mutation has occurred. Otherwise the string is a + // description of the mutation suitable for returning to the user. // // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. // TF_Graph may only / must be deleted when // sessions.size() == 0 && delete_requested == true - tensorflow::gtl::FlatMap sessions + // + // TODO(b/74949947): mutations currently trigger a warning instead of a bad + // status, this should be reverted when possible. + tensorflow::gtl::FlatMap sessions GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph @@ -124,15 +125,16 @@ struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; - TF_Graph* graph; + TF_Graph* const graph; - tensorflow::mutex mu; + tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); int last_num_graph_nodes; - // NOTE(ashankar): Experimental fields to help keep the - // buffers of a TF_Tensor pinned in device memory. - const tensorflow::DeviceMgr* device_mgr; // Owned by session. - std::vector devices; // Owned by device_mgr. + // If true, TF_SessionRun and similar methods will call + // ExtendSessionGraphHelper before running the graph (this is the default + // public behavior). Can be set to false if the caller needs to call + // ExtendSessionGraphHelper manually. + std::atomic extend_before_run; }; struct TF_ImportGraphDefOptions { @@ -210,7 +212,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, TF_Status* status); void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type); + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu); + +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) + LOCKS_EXCLUDED(session->graph->mu, session->mu); } // end namespace tensorflow diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 3db2852ce6560ba493d60ef54a110161c112d110..f3b28c1708129d39e451d927a89c0d10e2193b63 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -34,6 +34,10 @@ static void DoubleDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +static void FloatDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -78,21 +82,34 @@ TF_Tensor* DoubleTensor(double v) { &DoubleDeallocator, nullptr); } +TF_Tensor* FloatTensor(float v) { + const int num_bytes = sizeof(float); + float* values = new float[1]; + values[0] = v; + return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, + &FloatDeallocator, nullptr); +} + // All the *Helper methods are used as a workaround for the restrictions that // one cannot call ASSERT_* methods in non-void-returning functions (when // exceptions are disabled during compilation) void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_DataType dtype, const std::vector& dims, TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); - TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrType(desc, "dtype", dtype); + if (!dims.empty()) { + TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); + } *op = TF_FinishOperation(desc, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_NE(*op, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name, + TF_DataType dtype, const std::vector& dims) { TF_Operation* op; - PlaceholderHelper(graph, s, name, &op); + PlaceholderHelper(graph, s, name, dtype, dims, &op); return op; } @@ -126,6 +143,12 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op, bool check) { diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 2a70177c724c569844a5d8ad42b99bed20209946..cd19cf8d624d9b914b61132f93d918b046cdbd30 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -44,8 +44,12 @@ TF_Tensor* Int32Tensor(int32_t v); TF_Tensor* DoubleTensor(double v); +TF_Tensor* FloatTensor(float v); + TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, - const char* name = "feed"); + const char* name = "feed", + TF_DataType dtype = TF_INT32, + const std::vector& dims = {}); TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); @@ -56,6 +60,9 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index e55cb672e97e1403a3dd864c91c176426eb3f067..8df7b5662353e98eb82a13b9e65819a8f4d6261a 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -27,6 +27,12 @@ tf_cuda_library( ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -54,11 +60,17 @@ tf_cuda_library( ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) @@ -93,6 +105,7 @@ tf_cuda_library( "//conditions:default": [ "//tensorflow/c:c_api", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 98ef6f0d0ab094eae3e2e21624c3a4ba30d1c3d3..eaeb2fd07a3fdc2bfca97afc799bd65609955609 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -31,8 +31,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -40,6 +43,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" @@ -65,6 +69,7 @@ string DeviceName(const tensorflow::Device* d) { #ifdef TENSORFLOW_EAGER_USE_XLA std::atomic_int_fast64_t func_id_generator(0); #endif // TENSORFLOW_EAGER_USE_XLA + } // namespace extern "C" { @@ -76,182 +81,140 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, TF_SetConfig(&options->session_options, proto, proto_len, status); } +void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, + unsigned char async) { + options->async = async; +} void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, + unsigned char async, + TF_Status* status) { + status->status = ctx->context.SetAsyncForThread(async); +} + void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - TF_Graph* graph = TF_NewGraph(); - TF_Session* session = TF_NewSession(graph, &opts->session_options, status); - if (status->status.ok()) { - if (session->device_mgr == nullptr || session->devices.empty()) { - status->status = tensorflow::errors::InvalidArgument( - "Provided TF_SessionOptions are not compatible with eager execution " - "(perhaps the TF_SessionOptions alluded to session execution in a " - "remote address space?)"); - } - } + std::vector devices; + status->status = tensorflow::DeviceFactory::AddDevices( + opts->session_options.options, "/job:localhost/replica:0/task:0", + &devices); if (!status->status.ok()) { - TF_DeleteGraph(graph); return nullptr; } - - return new TFE_Context(*opts, session); + std::unique_ptr device_mgr( + new tensorflow::DeviceMgr(devices)); + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, std::move(device_mgr), r); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - status->status = tensorflow::Status::OK(); - { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); - } - TF_Graph* graph = ctx->session->graph; - TF_DeleteSession(ctx->session, status); - TF_DeleteGraph(graph); - ctx->rendezvous->Unref(); delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { - return TF_SessionListDevices(ctx->session, status); + TF_DeviceList* list = new TF_DeviceList; + ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); -} +void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - ctx->thread_local_policies[std::this_thread::get_id()] = policy; + ctx->context.SetThreadLocalDevicePlacementPolicy( + static_cast(policy)); } +// Note: this function looks up a thread local policy. So it should be called in +// the appropriate client thread. In particular, in async mode, it may not be +// safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - auto policy_map_it = - ctx->thread_local_policies.find(std::this_thread::get_id()); - if (policy_map_it != ctx->thread_local_policies.end()) { - return policy_map_it->second; - } - return ctx->policy; + return static_cast( + ctx->context.GetDevicePlacementPolicy()); +} + +void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.AsyncWait(); +} + +void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.GetStatus(); +} + +void TFE_ContextAsyncClearError(TFE_Context* ctx) { + ctx->context.ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; - return new TFE_TensorHandle(tensor, nullptr); + return new TFE_TensorHandle(tensor, nullptr, nullptr); } -void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } +void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { + DCHECK(h); + if (h->handle) { + h->handle->Unref(); + } + delete h; +} TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->t.dtype()); + return static_cast(h->handle->dtype); } -int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); } +int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->Tensor(&t); + return t == nullptr ? 0 : t->dims(); +} -int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { - return h->t.dim_size(dim_index); +int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, + TF_Status* status) { + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->Tensor(&t); + return t == nullptr ? 0 : t->dim_size(dim_index); } -const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { - // TODO(apassos) this will be potentially incorrect in the distributed case as - // our local device will have a name which depends on the ClusterSpec and - // hence will require the context to resolve. - return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" - : h->d->name().c_str(); +const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + tensorflow::Device* d = nullptr; + status->status = h->handle->OpDevice(&d); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { - if (!IsCPU(h->d)) { + // TODO(agarwal): move this implementation inside TFE_TensorHandle. + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + if (!IsCPU(d)) { TF_SetStatus(status, TF_UNIMPLEMENTED, tensorflow::strings::StrCat( "TFE_TensorHandle can be resolved iff it is on CPU (this " "handle is on ", - h->d->name(), + d->name(), "). Consider using TFE_TensorHandleCopyToDevice to get a " "copy of the tensor on CPU") .c_str()); return nullptr; } - return tensorflow::TF_TensorFromTensor(h->t, status); + return tensorflow::TF_TensorFromTensor(*t, status); } +} // extern "C" -TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status) { - tensorflow::Device* dstd = ctx->devices()[0]; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } - - tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d; - bool is_same_device = - (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); - const bool dst_cpu = IsCPU(dstd); - const bool src_cpu = IsCPU(srcd); - // both_on_cpu can be true and yet is_same_device is false, if one of src/dst - // has device type XLA_CPU, and the other CPU. - const bool both_on_cpu = src_cpu && dst_cpu; - if (is_same_device || both_on_cpu) { - return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); - } - tensorflow::Tensor* src = &(h->t); - if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && - !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat("Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), - " to device ", DeviceName(dstd), ".") - .c_str()); - return nullptr; - } - tensorflow::AllocatorAttributes attr; - if (src->dtype() == tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); - if (src->shape().num_elements() == 0) { - return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd); - } - tensorflow::DeviceContext* src_device_context = nullptr; - if (!src_cpu) { - src_device_context = srcd->tensorflow_gpu_device_info()->default_context; - } - tensorflow::DeviceContext* dst_device_context = nullptr; - if (!dst_cpu) { - dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; - } - // TODO(ashankar): The Sync() call below may be more aggressive than - // necessary. It is based on knowledge of implementation details - that - // GPU devices are implemented using 3 streams - one for host->device copies, - // one for device->host copies and one for sending operations to the GPU. - // With that setup, Sync()ing across all 3 streams should be sufficient - // but more than necessary (since it waits for operations that might have - // nothing to do with this tensor to complete). - status->status = srcd->Sync(); - tensorflow::Notification n; - tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, - srcd, dstd, tensorflow::AllocatorAttributes(), - tensorflow::AllocatorAttributes(), src, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); - n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) - ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd) - : nullptr; -} +extern "C" { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { @@ -260,8 +223,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, status->status = tensorflow::AttrTypeMapForOp(name, &types); if (status->status.ok()) return new TFE_Op(ctx, name, types); if (TF_GetCode(status) == TF_NOT_FOUND) { - tensorflow::mutex_lock l(ctx->functions_mu); - if (ctx->func_lib_def.Find(name) != nullptr) { + if (ctx->context.FindFunctionByName(name)) { status->status = tensorflow::Status::OK(); return new TFE_Op(ctx, name, nullptr); } @@ -274,16 +236,14 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - op->ctx->session->device_mgr->LookupDevice(device_name, &d); - if (!status->status.ok()) return; + status->status = op->ctx->context.FindDeviceByName(device_name, &d); } op->device = d; } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->devices()[0] : op->device; + (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; return device->name().c_str(); } @@ -296,17 +256,19 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - // Questionable heuristic ... - // - // Motivation: After an 'op' is placed on GPU because some of its earlier - // inputs are on GPU, we want to keep the 'op' there, even if some later - // inputs of it are not on GPU. - if (IsCPU(op->device) && !IsCPU(h->d)) { - op->device = h->d; + if (op->device == nullptr) { + // Questionable heuristic ... + // - If a device was explicitly set on the op, always use that. + // - If not, place on the first non-host device seen. + tensorflow::Device* d = nullptr; + // TODO(agarwal): This call may block if h is not ready. Avoid this if + // possible. + status->status = h->handle->Device(&d); + if (!status->status.ok()) return; + if (!IsCPU(d)) op->device = d; } - if (!status->status.ok()) return; - op->inputs.push_back(h->t); - op->input_devices.push_back(h->d); + h->handle->Ref(); + op->inputs.push_back(h->handle); op->attrs.NumInputs(op->inputs.size()); } @@ -468,14 +430,45 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, tensorflow::gtl::ArraySlice( funcs.get(), num_values)); } +} // extern "C" namespace { +// TODO(apassos) move to TensorHandle +tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( + tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status) { + status->status = ctx->context.GetStatus(); + if (!status->status.ok()) { + return nullptr; + } + tensorflow::Device* dstd = ctx->context.HostCPU(); + if (device_name != nullptr && strlen(device_name) > 0) { + status->status = + ctx->context.device_mgr()->LookupDevice(device_name, &dstd); + if (!status->status.ok()) return nullptr; + } + if (ctx->context.Async()) { + // Note that `h` may not be currently ready. However execution order will + // make sure that `h` is ready before the copy is actually done. + tensorflow::CopyToDeviceNode* node = + new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context); + tensorflow::TensorHandle* output = node->dst(); + // Note that calling Add makes `node` accessible by the EagerExecutor + // thread. So further accesses need to be thread-safe. + ctx->context.ExecutorAdd(node); + return output; + } else { + tensorflow::TensorHandle* output = nullptr; + status->status = h->CopyToDevice(&ctx->context, dstd, &output); + return output; + } +} + tensorflow::Status ValidateInputTypeAndPlacement( TFE_Context* ctx, tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel, - std::vector* copied_tensors) { + const tensorflow::OpKernel* kernel) { const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -484,14 +477,17 @@ tensorflow::Status ValidateInputTypeAndPlacement( for (int i = 0; i < op->inputs.size(); ++i) { const tensorflow::Device* expected_device = memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; + tensorflow::TensorHandle* handle = op->inputs[i]; + tensorflow::Device* handle_device = nullptr; + TF_RETURN_IF_ERROR(handle->Device(&handle_device)); const tensorflow::Device* actual_device = - op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; + handle_device == nullptr ? host_device : handle_device; if (expected_device != actual_device) { switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. - if (op->inputs[i].dtype() == tensorflow::DT_INT32) { + if (handle->dtype == tensorflow::DT_INT32) { // Note: enabling silent copies of int32 tensors to match behavior // of graph mode. break; @@ -522,35 +518,202 @@ tensorflow::Status ValidateInputTypeAndPlacement( } // We are only here if the policy is warn or silent copies, so we should // trigger a copy. - TFE_TensorHandle original{op->inputs[i], op->input_devices[i]}; TF_Status* s = TF_NewStatus(); - TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - &original, ctx, expected_device->name().c_str(), s); - if (!s->status.ok()) { - tensorflow::Status status = s->status; - delete s; + tensorflow::TensorHandle* copied_tensor = + TFE_TensorHandleCopyToDevice_Internal( + handle, ctx, expected_device->name().c_str(), s); + tensorflow::Status status = s->status; + TF_DeleteStatus(s); + if (!status.ok()) { + if (copied_tensor != nullptr) copied_tensor->Unref(); return tensorflow::errors::Internal( "Failed copying input tensor from ", actual_device->name(), " to ", expected_device->name(), " in order to run ", op->name, ": ", status.error_message()); } - op->inputs[i] = copied_tensor->t; - copied_tensors->push_back(copied_tensor); - op->input_devices[i] = copied_tensor->d; - delete s; + handle->Unref(); + handle = copied_tensor; + op->inputs[i] = copied_tensor; } - if (op->inputs[i].dtype() != kernel->input_type(i)) { + if (handle->dtype != kernel->input_type(i)) { return tensorflow::errors::InvalidArgument( "cannot compute ", op->name, " as input #", i, " was expected to be a ", tensorflow::DataTypeString(kernel->input_type(i)), - " tensor but is a ", - tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor"); + " tensor but is a ", tensorflow::DataTypeString(handle->dtype), + " tensor"); + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, + TFE_Context* ctx, TF_Status* status) { + tensorflow::DeviceSet ds; + for (tensorflow::Device* d : *ctx->context.devices()) { + ds.AddDevice(d); + } + tensorflow::DeviceTypeVector final_devices; + status->status = tensorflow::SupportedDeviceTypesForNode( + ds.PrioritizedDeviceTypeList(), ndef, &final_devices); + if (!status->status.ok()) { + return nullptr; + } + if (final_devices.empty()) { + status->status = tensorflow::errors::Internal( + "Could not find valid device for node ", ndef.DebugString()); + return nullptr; + } + for (tensorflow::Device* d : *ctx->context.devices()) { + if (d->device_type() == final_devices[0].type_string()) { + return d; + } + } + status->status = tensorflow::errors::Unknown( + "Could not find a device for node ", ndef.DebugString()); + return nullptr; +} + +tensorflow::Status Execute( + TFE_Context* ctx, tensorflow::Device* device, + const tensorflow::gtl::InlinedVector& + op_inputs, + tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, + tensorflow::TensorHandle** retvals, int num_retvals) { + if (!ctx->context.SoftPlacement() && device == nullptr) { + device = ctx->context.HostCPU(); + } + + if (device == nullptr) { + // TODO(apassos) debug how the assignment below might return a different + // device from the one requested above. + device = kernel->device(); + } + + std::vector outputs(1); + const tensorflow::MemoryTypeVector* output_memory_types = nullptr; + output_memory_types = &kernel->kernel()->output_memory_types(); + std::vector inputs(op_inputs.size()); + for (int i = 0; i < op_inputs.size(); ++i) { + const tensorflow::Tensor* input_tensor = nullptr; + TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor)); + inputs[i] = *input_tensor; + } + // WARNING: kernel->Run utilizes the FunctionLibraryRuntime + // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def. + // But knowledge of the implementation + // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by + // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. + // This is quite subtle. Re-work things to make this better? (Would it make + // sense for FunctionLibraryRuntime to ensure thread-safe access to + // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats + // for ops which are a part of functions. + // TODO(agarwal): change Run to take vector of handles ? + TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats)); + if (maybe_stats != nullptr) { + maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - + maybe_stats->all_start_micros()); + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + if (ctx->context.ShouldStoreMetadata()) { + auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats(); + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->context.devices()->size()) { + step_stats->add_dev_stats(); + } + // Find the current device's index. + int device_idx = 0; + for (int i = 0; i < ctx->context.devices()->size(); ++i) { + if (ctx->context.devices()->at(i) == device) { + device_idx = i; + break; + } + } + // Populate the device stats for this device. + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + dev_stats->set_device(device->name()); + *dev_stats->add_node_stats() = *maybe_stats; + } + } + DCHECK_EQ(num_retvals, outputs.size()); + tensorflow::Device* op_device = IsCPU(device) ? nullptr : device; + for (int i = 0; i < num_retvals; ++i) { + tensorflow::Device* d = op_device; + if (d != nullptr && output_memory_types != nullptr && + (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { + d = nullptr; + } + if (retvals[i] == nullptr) { + retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device); + } else { + retvals[i]->SetTensorAndDevice(outputs[i], d, op_device); } } return tensorflow::Status::OK(); } +// TODO(agarwal): move EagerExecutor and EagerNode related code to a separate +// file. +class ExecuteNode : public tensorflow::EagerNode { + public: + ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel, + tensorflow::NodeExecStats* maybe_stats, + const tensorflow::DataTypeVector& output_dtypes, + TFE_TensorHandle** retvals, int num_retvals) + : tensorflow::EagerNode(op->ctx->context.NextId()), + ctx_(op->ctx), + op_device_(op->device), + inputs_(op->inputs), + kernel_(kernel), + maybe_stats_(maybe_stats), + retvals_(num_retvals) { + for (auto handle : inputs_) { + handle->Ref(); + } + TFE_Context* ctx = op->ctx; + for (int i = 0; i < num_retvals; ++i) { + tensorflow::TensorHandle* h = + new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context); + h->Ref(); + retvals[i] = new TFE_TensorHandle(h); + retvals_[i] = h; + } + } + + ~ExecuteNode() override { + for (auto handle : inputs_) { + handle->Unref(); + } + for (auto handle : retvals_) { + handle->Unref(); + } + } + + tensorflow::Status Run() override { + const tensorflow::Status status = + Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(), + retvals_.begin(), retvals_.size()); + if (status.ok()) { + return status; + } else { + return tensorflow::Status( + status.code(), + tensorflow::strings::StrCat("Got error, \"", status.error_message(), + "\" while executing kernel ", + kernel_->kernel()->def().DebugString())); + } + } + + private: + TFE_Context* ctx_; + tensorflow::Device* op_device_; + tensorflow::gtl::InlinedVector inputs_; + tensorflow::KernelAndDevice* kernel_; + std::unique_ptr maybe_stats_; + tensorflow::gtl::InlinedVector retvals_; +}; + + #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a // primitive op (e.g. matmul). @@ -578,8 +741,7 @@ const tensorflow::FunctionDef* OpToFunction( TFE_Context* ctx = op->ctx; const tensorflow::OpRegistrationData* op_data; { - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + status->status = ctx->context.FindFunctionOpData(op->name, &op_data); if (!status->status.ok()) { return nullptr; } @@ -616,7 +778,7 @@ const tensorflow::FunctionDef* OpToFunction( (*op_input_to_func_input)[i] = const_index; func_input_arg = signature->mutable_input_arg(const_index++); const_input_types->push_back( - static_cast(op->inputs[i].dtype())); + static_cast(op->inputs[i]->dtype)); } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { VLOG(1) << "For resource input, mapping op input " << i << " to func input " << resource_index; @@ -628,11 +790,11 @@ const tensorflow::FunctionDef* OpToFunction( (*op_input_to_func_input)[i] = arg_index; func_input_arg = signature->mutable_input_arg(arg_index++); arg_input_types->push_back( - static_cast(op->inputs[i].dtype())); + static_cast(op->inputs[i]->dtype)); } func_input_arg->set_name(op_input_arg.name()); - func_input_arg->set_type(op->inputs[i].dtype()); + func_input_arg->set_type(op->inputs[i]->dtype); } VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); @@ -675,10 +837,9 @@ const tensorflow::FunctionDef* OpToFunction( } VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(fdef); + status->status = ctx->context.AddFunctionDef(fdef); if (!status->status.ok()) return nullptr; - const auto ret = ctx->func_lib_def.Find(signature->name()); + const auto ret = ctx->context.FindFunctionDef(signature->name()); DCHECK(ret != nullptr); return ret; } @@ -697,8 +858,7 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { const tensorflow::FunctionDef* fdef; { - tensorflow::tf_shared_lock l(op->ctx->functions_mu); - fdef = op->ctx->func_lib_def.Find(op->name); + fdef = op->ctx->context.FindFunctionDef(op->name); } std::vector const_input_types; std::vector arg_input_types; @@ -725,21 +885,16 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { // Since input param reordering may have occurred between `op` and `launch_op` // via `op_input_to_func_input`, adjust the actual inputs accordingly. launch_op->inputs = op->inputs; - launch_op->input_devices = op->input_devices; + for (tensorflow::TensorHandle* h : launch_op->inputs) { + h->Ref(); + } if (!op_input_to_func_input.empty()) { DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); - if (!op->input_devices.empty()) { - DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); - } for (int i = 0; i < op_input_to_func_input.size(); ++i) { VLOG(1) << "mapping op input " << i << " to func input " << op_input_to_func_input[i]; launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; - if (!op->input_devices.empty()) { - launch_op->input_devices[op_input_to_func_input[i]] = - op->input_devices[i]; - } } } launch_op->attrs.NumInputs(op->inputs.size()); @@ -772,15 +927,18 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { return launch_op; } #endif // TENSORFLOW_EAGER_USE_XLA + } // namespace +extern "C" { + void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { TFE_Context* ctx = op->ctx; - // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU - tensorflow::Device* device = - (op->device == nullptr) ? ctx->devices()[0] : op->device; - + status->status = ctx->context.GetStatus(); + if (!status->status.ok()) { + return; + } #ifdef TENSORFLOW_EAGER_USE_XLA std::unique_ptr xla_launch_op; if (op->use_xla && op->name != "_XlaLaunch") { @@ -791,45 +949,99 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, op = xla_launch_op.get(); } #endif // TENSORFLOW_EAGER_USE_XLA - - std::vector outputs(1); - const tensorflow::MemoryTypeVector* output_memory_types = nullptr; - tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); - tensorflow::KernelAndDevice* kernel; - { - tensorflow::tf_shared_lock l(ctx->cache_mu); - kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + // Ensure all resource-touching ops run in the device the resource is, + // regardless of anything else that has been specified. This is identical to + // the graph mode behavior. + for (int i = 0; i < op->inputs.size(); ++i) { + tensorflow::Device* input_op_device = nullptr; + status->status = op->inputs[i]->OpDevice(&input_op_device); + if (!status->status.ok()) return; + if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && + input_op_device != op->device) { + tensorflow::Device* d = + input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; + VLOG(1) << "Changing device of operation " << op->name << " to " + << d->name() << " because input #" << i + << " is a resource in this device."; + op->device = d; + } + } + tensorflow::Device* device = op->device; + if (!ctx->context.SoftPlacement() && device == nullptr) { + device = ctx->context.HostCPU(); } + + tensorflow::Fprint128 cache_key = + op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); + tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); + if (ctx->context.SoftPlacement() && device == nullptr) { + device = SelectDevice(ndef, ctx, status); + if (!status->status.ok()) { + return; + } + } + CHECK(device != nullptr); + if (ctx->context.LogDevicePlacement()) { + LOG(INFO) << "Executing op " << ndef.op() << " in device " + << device->name(); + } + kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def // will be accessed, so grab on to the lock. - // See WARNING comment below - would be nice to rework to avoid this - // subtlety. - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = - tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); + // See WARNING comment in Execute (before kernel->Run) - would be nice to + // rework to avoid this subtlety. + tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); + status->status = tensorflow::KernelAndDevice::Init( + ndef, ctx->context.func_lib(device), kernel); if (!status->status.ok()) { delete kernel; return; } - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); - } - std::vector copied_tensors; - status->status = ValidateInputTypeAndPlacement( - ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors); - output_memory_types = &kernel->kernel()->output_memory_types(); - if (!status->status.ok()) { - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); + // Update output_dtypes inside `kernel`. + const tensorflow::OpDef* op_def = nullptr; + const tensorflow::FunctionDef* function_def = + ctx->context.FuncLibDef()->Find(ndef.op()); + if (function_def != nullptr) { + op_def = &(function_def->signature()); + } + if (op_def == nullptr) { + status->status = OpDefForOp(ndef.op().c_str(), &op_def); + if (!status->status.ok()) { + return; + } } + tensorflow::DataTypeVector input_dtypes; + status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, + kernel->mutable_output_dtypes()); + if (!status->status.ok()) { + return; + } + ctx->context.AddKernelToCache(cache_key, kernel); + } + const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); + const int output_dtypes_size = output_dtypes.size(); + if (output_dtypes_size > *num_retvals) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat("Expecting ", output_dtypes.size(), + " outputs, but *num_retvals is ", + *num_retvals) + .c_str()); return; } + *num_retvals = output_dtypes_size; + if (device == nullptr) { + // TODO(apassos) debug how the assignment below might return a different + // device from the one requested above. + device = kernel->device(); + } + status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(), + device, op, kernel->kernel()); + if (!status->status.ok()) return; std::unique_ptr maybe_stats; - if (ctx->should_store_metadata.load()) { + if (ctx->context.ShouldStoreMetadata()) { maybe_stats.reset(new tensorflow::NodeExecStats); maybe_stats->set_node_name(op->name); maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); @@ -837,54 +1049,38 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); // TODO(apassos) track referenced tensors } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, - // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. - status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); - } - if (!status->status.ok()) return; - if (maybe_stats != nullptr) { - maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(ctx->metadata_mu); - if (ctx->should_store_metadata.load()) { - auto* step_stats = ctx->run_metadata.mutable_step_stats(); - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices().size()) { - step_stats->add_dev_stats(); - } - // Find the current device's index. - int device_idx = 0; - for (int i = 0; i < ctx->devices().size(); ++i) { - if (ctx->devices()[i] == device) { - device_idx = i; - break; - } - } - // Populate the device stats for this device. - auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - dev_stats->set_device(device->name()); - *dev_stats->add_node_stats() = *maybe_stats; + if (ctx->context.Async()) { + // Note that for async mode, execution order will make sure that all + // input handles are ready before executing them. + // TODO(agarwal): Consider executing "cheap" kernels inline for performance. + tensorflow::EagerNode* node = + new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes, + retvals, *num_retvals); + ctx->context.ExecutorAdd(node); + } else { + // Execute checks if retvals[i] is nullptr or not to figure if it needs to + // allocate it. + std::vector handle_retvals(*num_retvals, + nullptr); + status->status = + Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(), + handle_retvals.data(), *num_retvals); + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } } - *num_retvals = std::min(*num_retvals, outputs.size()); - for (int i = 0; i < *num_retvals; ++i) { - tensorflow::Device* d = IsCPU(device) ? nullptr : device; - if (d != nullptr && output_memory_types != nullptr && - (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { - d = nullptr; - } - retvals[i] = new TFE_TensorHandle(outputs[i], d); +} + +TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status) { + tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal( + h->handle, ctx, device_name, status); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); } + return nullptr; } void TFE_ContextAddFunctionDef(TFE_Context* ctx, @@ -896,46 +1092,127 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function_def); + status->status = ctx->context.AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); + status->status = ctx->context.AddFunctionDef(function->fdef); +} + +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->context.SetShouldStoreMetadata(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + ctx->context.SetShouldStoreMetadata(false); } } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { - return new TFE_TensorHandle(t, nullptr); + return new TFE_TensorHandle(t, nullptr, nullptr); } const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status) { - if (h->d != nullptr) { + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( "TFE_TensorHandle is placed in device (not host) memory. Cannot return " "a tensorflow::Tensor"); return nullptr; } - return &h->t; + return t; } -void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->should_store_metadata.store(true); +void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); + ctx->context.RunMetadataProto()->Clear(); } -void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - ctx->should_store_metadata.store(false); - ctx->run_metadata.Clear(); +namespace { +TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, + TF_Status* status) { + TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); + for (const auto& attr : func.attr()) { + if (TF_GetCode(status) != TF_OK) return nullptr; + SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return func_op; } +} // namespace -void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, - TF_Status* status) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - status->status = MessageToBuffer(ctx->run_metadata, buf); - ctx->run_metadata.Clear(); +namespace tensorflow { +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status) { + switch (default_value.value_case()) { + case tensorflow::AttrValue::kS: + TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + break; + case tensorflow::AttrValue::kI: + TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); + break; + case tensorflow::AttrValue::kF: + TFE_OpSetAttrFloat(op, attr_name, default_value.f()); + break; + case tensorflow::AttrValue::kB: + TFE_OpSetAttrBool(op, attr_name, default_value.b()); + break; + case tensorflow::AttrValue::kType: + TFE_OpSetAttrType(op, attr_name, + static_cast(default_value.type())); + break; + case tensorflow::AttrValue::kShape: { + const auto& tensor_shape = default_value.shape(); + if (tensor_shape.unknown_rank()) { + TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); + } else { + const auto num_dims = tensor_shape.dim_size(); + std::unique_ptr dims(new int64_t[num_dims]); + for (int i = 0; i < num_dims; ++i) { + dims[i] = tensor_shape.dim(i).size(); + } + TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); + } + } break; + case tensorflow::AttrValue::kFunc: { + const auto func_op = GetFunc(ctx, default_value.func(), status); + if (TF_GetCode(status) != TF_OK) return; + // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList + // require TFE_Op* and just convert it internally a NameAttrValue, so + // consider adding an overload to the C API to make this case easier. + TFE_OpSetAttrFunction(op, attr_name, func_op); + } break; + case tensorflow::AttrValue::kList: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kTensor: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kPlaceholder: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::VALUE_NOT_SET: + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to get setfor default value: ", + default_value.DebugString()) + .data()); + } +} +} // namespace tensorflow + + +TFE_Op::~TFE_Op() { + for (tensorflow::TensorHandle* h : inputs) { + h->Unref(); + } } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 7a321b54da343fd2b8912187bc620c1e7456db0c..a5029bf2115c7dac54d03b8bc6397bc63349c068 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -61,7 +61,8 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( // Controls how to act when we try to run an operation on a given device but // some input tensors are not on that device. typedef enum TFE_ContextDevicePlacementPolicy { - // Running operations with input tensors on the wrong device will fail. + // Running operations with input tensors on the wrong device will fail. When + // soft placement is enabled acts like TFE_DEVICE_PLACEMENT_SILENT. TFE_DEVICE_PLACEMENT_EXPLICIT = 0, // Copy the tensor to the right device but log a warning. TFE_DEVICE_PLACEMENT_WARN = 1, @@ -69,10 +70,16 @@ typedef enum TFE_ContextDevicePlacementPolicy { // operation will be blocked till the copy completes. TFE_DEVICE_PLACEMENT_SILENT = 2, // Default placement policy which silently copies int32 tensors but not other - // dtypes. + // dtypes. When soft placement is enabled acts like + // TFE_DEVICE_PLACEMENT_SILENT. TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; +// Sets the default execution mode (sync/async). Note that this can be +// overridden per thread using TFE_ContextSetAsyncForThread. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, + unsigned char async); + TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -108,6 +115,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(TFE_Context*); +// Overrides the execution mode (sync/async) for the current thread. +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, + unsigned char async, + TF_Status* status); + +// Causes the calling thread to block till all ops dispatched in async mode +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. +TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*, + TF_Status* status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*); + // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, @@ -117,13 +148,21 @@ typedef struct TFE_TensorHandle TFE_TensorHandle; TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status); +// Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); -TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, - int dim_index); + int dim_index, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( - TFE_TensorHandle* h); + TFE_TensorHandle* h, TF_Status* status); + +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); @@ -133,6 +172,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). +// If async execution is enabled, the copy may be enqueued and the call will +// return "non-ready" handle. Else, this function returns after the copy has +// been done. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); @@ -153,6 +195,7 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); + TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, @@ -238,13 +281,21 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, int num_values); // Execute the operation defined by 'op' and return handles to computed -// tensors in 'retvals'. +// tensors in `retvals`. +// +// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. // -// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* -// and '*num_retvals' should be set to the size of this array. +// If async execution is enabled, the call may simply enqueue the execution +// and return "non-ready" handles in `retvals`. Note that any handles contained +// in 'op' should not be mutated till the kernel execution actually finishes. // -// On return, 'num_retvals' will be set to the actual number of outputs -// returned by the operation. +// For sync execution, if any of the inputs to `op` are not ready, this call +// will block till they become ready and then return when the kernel execution +// is done. +// TODO(agarwal): change num_retvals to int from int*. TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status); @@ -270,6 +321,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); // Populates the passed-in buffer with a serialized RunMetadata protocol buffer // containing any run metadata information accumulated so far and clears this // information. +// If async mode is enabled, this call blocks till all currently pending ops are +// done. TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 7b9f1db02ed9c53a280c7bd1284165cac4fb6353..e6d2ab75ffd2849d7fafb630eb452122ef36339b 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include #include +#include #include #include #include @@ -28,82 +30,56 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" + struct TFE_ContextOptions { TF_SessionOptions session_options; + // true if async execution is enabled. + bool async = false; TFE_ContextDevicePlacementPolicy policy{ TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; }; struct TFE_Context { - explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) - : policy(opts.policy), - session(s), - rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), - pflr(new tensorflow::ProcessFunctionLibraryRuntime( - session->device_mgr, opts.session_options.options.env, - TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} - - const TFE_ContextDevicePlacementPolicy policy; - - // Note: we cannot use C++11 thread_local here as there is no concept of a - // thread-local-object-local variable in C++11. - tensorflow::mutex policy_map_mu; - std::unordered_map - thread_local_policies GUARDED_BY(policy_map_mu); - - // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* const session; - tensorflow::Rendezvous* const rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - const std::unique_ptr pflr; - - tensorflow::mutex cache_mu; - std::unordered_map - kernel_cache GUARDED_BY(cache_mu); - - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { - return pflr->GetFLR(d->name()); - } - - const std::vector& devices() { return session->devices; } - - // Whether we should compute RunMetadata. - std::atomic should_store_metadata{false}; - tensorflow::mutex metadata_mu; - tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); + explicit TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, + bool async, + std::unique_ptr device_mgr, + tensorflow::Rendezvous* rendezvous) + : context(opts, + static_cast( + default_policy), + async, std::move(device_mgr), rendezvous) {} + + tensorflow::EagerContext context; }; struct TFE_TensorHandle { - TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) - : t(t), d(d) {} - - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, + tensorflow::Device* op_device) + : handle(new tensorflow::TensorHandle(t, d, op_device)) {} + + TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, + tensorflow::EagerContext* ctx) + : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} + + TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} + + tensorflow::TensorHandle* handle; }; struct TFE_Op { @@ -112,16 +88,24 @@ struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + ~TFE_Op(); + bool const is_function() const { return attr_types == nullptr; } TFE_Context* ctx; // Must outlive the TFE_Op. const tensorflow::string name; tensorflow::AttrBuilder attrs; const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; + tensorflow::gtl::InlinedVector inputs; tensorflow::Device* device; bool use_xla = false; }; +namespace tensorflow { +// Set an AttrValue on the op. Doesn't handle the list types. +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status); +} // namespace tensorflow + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4a3ecbc0abb16296a84c0d2184dc3fc9f7f3ebb4..2268aba90d60b7b2f10e99f64fd7aa3ae719badb 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -29,6 +29,20 @@ using tensorflow::string; namespace { +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestMatrixTensorHandle() { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -43,6 +57,20 @@ TFE_TensorHandle* TestMatrixTensorHandle() { return th; } +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TF_Status* status = TF_NewStatus(); @@ -139,10 +167,12 @@ void BM_InitOp(int iters) { } BENCHMARK(BM_InitOp); -void BM_Execute(int iters) { +void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteAsync" : "Execute"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -156,6 +186,9 @@ void BM_Execute(int iters) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); @@ -163,7 +196,7 @@ void BM_Execute(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_Execute); +BENCHMARK(BM_Execute)->Arg(0)->Arg(1); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); @@ -205,10 +238,11 @@ TEST(CAPI, TensorHandle) { TFE_DeleteTensorHandle(h); } -TEST(CAPI, TensorHandleCopyBetweenDevices) { +void TensorHandleCopyBetweenDevices(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -274,10 +308,56 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { +TEST(CAPI, TensorHandleCopyBetweenDevices) { + TensorHandleCopyBetweenDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) { + TensorHandleCopyBetweenDevices(true); +} + +void TensorHandleCopyBetweenDevicesError(bool async) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* kErrorDevice = "NoSuchDevice:0"; + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get()); + EXPECT_NE(TF_OK, TF_GetCode(status.get())); + const char* msg = "NoSuchDevice:0 unknown device"; + EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr) + << TF_Message(status.get()); + TF_SetStatus(status.get(), TF_OK, ""); + const char* kCPUDevice = "CPU:0"; + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())); + TFE_DeleteTensorHandle(hcopy); + TFE_DeleteTensorHandle(hcpu); + if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); + TFE_DeleteContext(ctx, status.get()); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesError) { + TensorHandleCopyBetweenDevicesError(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) { + TensorHandleCopyBetweenDevicesError(true); +} + +void TensorHandleCopyBetweenTwoGPUDevices(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -332,11 +412,20 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopy) { +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { + TensorHandleCopyBetweenTwoGPUDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) { + TensorHandleCopyBetweenTwoGPUDevices(true); +} + +void TensorHandleSilentCopy(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -366,14 +455,20 @@ TEST(CAPI, TensorHandleSilentCopy) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopyLocal) { +TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } +TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); } + +void TensorHandleSilentCopyLocal(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status.get()); @@ -407,11 +502,17 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } +TEST(CAPI, TensorHandleSilentCopyLocalAsync) { + TensorHandleSilentCopyLocal(true); +} -TEST(CAPI, SetAndGetOpDevices) { +void SetAndGetOpDevices(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -442,27 +543,28 @@ TEST(CAPI, SetAndGetOpDevices) { TF_DeleteStatus(status); } -TEST(CAPI, Execute_MatMul_CPU) { +void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; + int num_retvals = 2; TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(1, num_retvals); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -474,7 +576,101 @@ TEST(CAPI, Execute_MatMul_CPU) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); } +TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); } + +void Execute_MatMul_CPU_Runtime_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_Op* matmul2 = MatMulOp(ctx, m1, m1); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + TFE_DeleteOp(matmul); + if (!async) { + EXPECT_NE(TF_OK, TF_GetCode(status)); + } else { + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + EXPECT_EQ(nullptr, t); + const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) + << TF_Message(status); + // Since error is not cleared, the following copy with correct device will + // still fail. + TF_SetStatus(status, TF_OK, ""); + TFE_DeleteTensorHandle(retvals[0]); + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_ContextAsyncClearError(ctx); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + } + // Following works in async mode since TFE_ContextAsyncClearError was called. + TF_SetStatus(status, TF_OK, ""); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteTensor(t); + TFE_DeleteOp(matmul2); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { + Execute_MatMul_CPU_Runtime_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) { + Execute_MatMul_CPU_Runtime_Error(true); +} + +void Execute_MatMul_CPU_Type_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Type_Error) { + Execute_MatMul_CPU_Type_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) { + Execute_MatMul_CPU_Type_Error(true); +} TEST(CAPI, Execute_Min_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -485,8 +681,8 @@ TEST(CAPI, Execute_Min_CPU) { TFE_TensorHandle* input = TestMatrixTensorHandle(); TFE_TensorHandle* axis = TestAxisTensorHandle(); TFE_Op* minOp = MinOp(ctx, input, axis); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); @@ -509,9 +705,10 @@ TEST(CAPI, Execute_Min_CPU) { } #ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Execute_MatMul_XLA_CPU) { +void Execute_MatMul_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -521,15 +718,14 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { TFE_OpSetXLACompilation(matmul, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); // Running a primitive TF operator via XLA is not yet supported. ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(1, num_retvals); @@ -545,13 +741,16 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } +TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } -TEST(CAPI, Execute_Min_XLA_CPU) { +void Execute_Min_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -562,14 +761,13 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TFE_OpSetXLACompilation(minOp, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); @@ -582,13 +780,17 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } +TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } #endif // TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, ExecuteWithTracing) { +void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); TFE_ContextEnableRunMetadata(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -596,8 +798,8 @@ TEST(CAPI, ExecuteWithTracing) { TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); @@ -609,12 +811,12 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_TRUE( rm.ParseFromString({reinterpret_cast(b->data), b->length})); TF_DeleteBuffer(b); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -626,6 +828,8 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); } TEST(CAPI, Function_ident_CPU) { // First create a simple identity function. @@ -657,32 +861,37 @@ TEST(CAPI, Function_ident_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -719,35 +928,40 @@ TEST(CAPI, Function_ident_XLA_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - // Now run it via XLA. - TFE_OpSetXLACompilation(op, true); + // Now run it via XLA. + TFE_OpSetXLACompilation(op, true); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -788,9 +1002,10 @@ string MatMulFunction() { return def.SerializeAsString(); } -TEST(CAPI, FunctionDefAndExecute) { +void FunctionDefAndExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -827,11 +1042,16 @@ TEST(CAPI, FunctionDefAndExecute) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); } +TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); } -void BM_ExecuteFunction(int iters) { +void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync" + : "ExecuteFunction"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -853,6 +1073,9 @@ void BM_ExecuteFunction(int iters) { TFE_Execute(matmul, &retval[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); @@ -860,7 +1083,7 @@ void BM_ExecuteFunction(int iters) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_ExecuteFunction); +BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, TF_Status* status) { @@ -932,7 +1155,8 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle)); - EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle)); + EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status)); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float value = 0.0f; TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -974,7 +1198,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(1, num_retvals); CHECK(h); CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); - CHECK_EQ(0, TFE_TensorHandleNumDims(h)); + CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; } tensorflow::testing::StopTiming(); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index f77a937f1ffc2d146224cb3191a5ca127daefc22..abe2793ce894ad07c252575c5d55d98342916eac 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -41,17 +42,26 @@ const uint32 kIsList = 1U << 31; } // namespace +Status OpDefForOp(const char* op_name, const OpDef** op_def) { + const OpRegistrationData* op_reg_data = nullptr; + Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (s.ok()) { + *op_def = &op_reg_data->op_def; + } + return s; +} + Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { mutex_lock l(g_op_name_to_attr_type_map_lock); *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); if (*out != nullptr) return Status::OK(); - const OpRegistrationData* op_reg_data = nullptr; - Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + const OpDef* op_def = nullptr; + Status s = OpDefForOp(op_name, &op_def); if (!s.ok()) return s; std::unique_ptr m(new AttrTypeMap); // TODO(agarwal): Avoid having to create this "registry" at runtime, // perhaps can be done at op registration time? - for (const auto& attr : op_reg_data->op_def.attr()) { + for (const auto& attr : op_def->attr()) { string type = attr.type(); const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0); if (is_list) { @@ -86,22 +96,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, - TF_AttrType* out, unsigned char* is_list) { - auto* t = gtl::FindOrNull(m, attr_name); - if (t == nullptr) { - return errors::InvalidArgument("Attribute '", attr_name, - "' does not exist for this operation"); - } - *out = static_cast(*t & ~kIsList); - if (*t & kIsList) { - *is_list = 1; - } else { - *is_list = 0; - } - return Status::OK(); -} - #define DEFINE_SET_ATTR(value_type, value_field) \ template <> \ AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \ @@ -159,6 +153,22 @@ const NodeDef& AttrBuilder::BuildNodeDef() { return *node_def_; } +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list) { + auto* t = gtl::FindOrNull(m, attr_name); + if (t == nullptr) { + return errors::InvalidArgument("Attribute '", attr_name, + "' does not exist for this operation"); + } + *out = static_cast(*t & ~kIsList); + if (*t & kIsList) { + *is_list = 1; + } else { + *is_list = 0; + } + return Status::OK(); +} + namespace { inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, const tensorflow::Fprint128& b) { @@ -236,93 +246,4 @@ void AttrBuilder::MayBeInitializeNodeDef() { } } -// static -Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = CreateOpKernel(device->device_type().c_str(), device, - device->GetAllocator(AllocatorAttributes()), - nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); - out->device_ = device; - out->kernel_.reset(k); - out->flib_ = nullptr; - return s; -} - -// static -Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = flib->CreateKernel(ndef, &k); - out->device_ = flib->device(); - out->kernel_.reset(k); - out->flib_ = flib; - return s; -} - -Status KernelAndDevice::Run(std::vector* input_tensors, - std::vector* output_tensors, - NodeExecStats* stats) { - gtl::InlinedVector inputs; - for (Tensor& t : *input_tensors) { - inputs.push_back(TensorValue(&t)); - } - - std::vector out_attrs(kernel_->num_outputs()); - for (size_t i = 0; i < out_attrs.size(); ++i) { - out_attrs[i].set_on_host(kernel_->output_memory_types()[i] == - tensorflow::HOST_MEMORY); - } - - OpKernelContext::Params params; - params.device = device_; - params.frame_iter = FrameAndIter(0, 0); - params.inputs = &inputs; - params.op_kernel = kernel_.get(); - params.resource_manager = device_->resource_manager(); - params.output_attr_array = gtl::vector_as_array(&out_attrs); - params.function_library = flib_; - params.slice_reader_cache = &slice_reader_cache_; - params.rendezvous = rendez_; - if (stats != nullptr) { - params.track_allocations = true; - } - // TODO(apassos): use a thread pool. - std::function)> runner = - [](std::function f) { f(); }; - params.runner = &runner; - - OpKernelContext context(¶ms); - device_->Compute(kernel_.get(), &context); - if (!context.status().ok()) return context.status(); - - output_tensors->clear(); - for (int i = 0; i < context.num_outputs(); ++i) { - output_tensors->push_back(Tensor(*context.mutable_output(i))); - } - if (stats != nullptr) { - for (const auto& allocator_pair : context.wrapped_allocators()) { - AllocatorMemoryUsed* memory = stats->add_memory(); - memory->set_allocator_name(allocator_pair.first->Name()); - auto sizes = allocator_pair.second->GetSizes(); - memory->set_total_bytes(std::get<0>(sizes)); - memory->set_peak_bytes(std::get<1>(sizes)); - memory->set_live_bytes(std::get<2>(sizes)); - - AllocatorStats allocator_stats; - allocator_pair.first->GetStats(&allocator_stats); - memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); - allocator_pair.second->GetRecordsAndUnRef(); - } - auto* ms = stats->mutable_memory_stats(); - ms->set_temp_memory_size(context.temp_memory_allocated()); - for (const auto& alloc_id : context.persistent_alloc_ids()) { - ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); - } - - ms->set_persistent_memory_size(context.persistent_memory_allocated()); - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index 4d20b5244a46fcde2eed0a429dced2a77b86aedd..929b1b8296faf61c11c68af06ffc4ca3770ae929 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -39,9 +40,16 @@ namespace tensorflow { // represent the TF_AttrType type of the values in the list. typedef std::unordered_map AttrTypeMap; +// Look up OpDef for `op_name`. +Status OpDefForOp(const char* op_name, const OpDef** op_def); + // Returns the AttrTypeMap for the TensorFlow operation named op_name. Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); +// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list); + // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); @@ -146,47 +154,6 @@ template <> AttrBuilder& AttrBuilder::Set(StringPiece attr_name, tensorflow::DataType&& value); -// KernelAndDevice encapsulates an instantiated kernel and the device it is on. -// -// Also see: -// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h -// and -// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h -class KernelAndDevice { - public: - // Populates 'out' with a kernel appropriate for 'ndef'. - // - // The provided FunctionLibraryRuntime MUST outlive all calls to - // Run() on the returned KernelAndDevice. - // - // TODO(ashankar): Figure out thread-safety concerns around - // FunctionLibraryRuntime (in particular, how the underlying - // FunctionLibraryDefinition might be mutated by another thread as new - // functions are registered with it). Conservatively, thread-safe usage of - // the FunctionLibraryRuntime is pushed on to the caller (see locking in - // c_api.cc). - static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, - KernelAndDevice* out); - // TODO(ashankar): Remove this - static Status InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out); - - KernelAndDevice(tensorflow::Rendezvous* rendez) - : device_(nullptr), flib_(nullptr), rendez_(rendez) {} - - // TODO(ashankar): Handle list-valued inputs. - Status Run(std::vector* inputs, std::vector* outputs, - NodeExecStats* stats); - - const OpKernel* kernel() const { return kernel_.get(); } - - private: - std::unique_ptr kernel_; - Device* device_; - FunctionLibraryRuntime* flib_; - checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; - Rendezvous* rendez_; -}; } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 643153058ce3d6f0c88dd23a0dec4c6eff060319..27ebeb0508844ee1ee89e0733b66f6ed129b7757 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -33,27 +33,6 @@ limitations under the License. namespace tensorflow { namespace { -class TestEnv { - public: - TestEnv() : flib_def_(OpRegistry::Global(), {}) { - Device* device = - DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); - device_mgr_.reset(new DeviceMgr({device})); - flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(), - device, TF_GRAPH_DEF_VERSION, - &flib_def_, {}, nullptr); - } - - FunctionLibraryRuntime* function_library_runtime() const { - return flib_runtime_.get(); - } - - private: - FunctionLibraryDefinition flib_def_; - std::unique_ptr device_mgr_; - std::unique_ptr flib_runtime_; -}; - TEST(AttrTypeMap, Lookup) { const AttrTypeMap* m = nullptr; Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m); @@ -79,113 +58,5 @@ TEST(AttrTypeMap, Lookup) { EXPECT_NE(is_list, 0); } -TEST(KernelAndDevice, Run) { - Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); - std::vector inputs; - inputs.push_back(t); - inputs.push_back(t); - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(inputs.size()) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice kernel(nullptr); - Status s = - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); - ASSERT_TRUE(s.ok()) << s; - std::vector outputs; - s = kernel.Run(&inputs, &outputs, nullptr); - ASSERT_TRUE(s.ok()) << s; - ASSERT_EQ(1, outputs.size()); - const Tensor& out = outputs[0]; - EXPECT_EQ(7, out.matrix()(0, 0)); - EXPECT_EQ(10, out.matrix()(0, 1)); - EXPECT_EQ(15, out.matrix()(1, 0)); - EXPECT_EQ(22, out.matrix()(1, 1)); -} - -void BM_CreateGraph(int iters) { - for (int i = 0; i < iters; ++i) { - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - TF_CHECK_OK(root.status()); - } -} -BENCHMARK(BM_CreateGraph); - -void BM_RunGraph(int iters) { - tensorflow::testing::StopTiming(); - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - SessionOptions opts; - opts.config.set_inter_op_parallelism_threads(1); - opts.config.set_intra_op_parallelism_threads(1); - ClientSession sess(root, opts); - std::vector outputs; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - outputs.clear(); - TF_CHECK_OK(sess.Run({M}, &outputs)); - } -} -BENCHMARK(BM_RunGraph); - -void BM_CreateAndDestroySession(int iters) { - tensorflow::testing::StopTiming(); - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - ClientSession sess(root); - } -} -BENCHMARK(BM_CreateAndDestroySession); - -void BM_KernelAndDeviceInit(int iters) { - tensorflow::testing::StopTiming(); - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(2) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice k(nullptr); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &k)); - } -} -BENCHMARK(BM_KernelAndDeviceInit); - -void BM_KernelAndDeviceRun(int iters) { - tensorflow::testing::StopTiming(); - Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); - std::vector inputs; - inputs.push_back(t); - inputs.push_back(t); - std::vector outputs; - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(inputs.size()) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice kernel(nullptr); - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); - } -} -BENCHMARK(BM_KernelAndDeviceRun); } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index bdb0815d6b68444ec1c89b835d563db20ce4d8a1..c7bd3bdafd787e5c72625b190ea8bf8b8264d22d 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -152,6 +152,8 @@ class GradientTape { gtl::ArraySlice output_gradients, std::vector* result); + bool IsPersistent() const { return persistent_; } + private: TensorTape tensor_tape_; OpTape op_tape_; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index f553142d15f476ad2c1af68016a4254ed211b9b2..cd604538f1fa142c6fe6a76624c048baddaa52fb 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -104,4 +104,9 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { graph->refiner.set_require_shape_inference_fns(require); } +void ExtendSession(TF_Session* session, TF_Status* status) { + ExtendSessionGraphHelper(session, status); + session->extend_before_run = false; +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 542d70f42c2a5df8309a722b32d850dd249e496f..13b680b3a24afa2d285ea18207578aff4350f6d5 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -41,6 +41,16 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); // error. The default is true. void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. +// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. +void ExtendSession(TF_Session* session, TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/testdata/tf_record b/tensorflow/c/testdata/tf_record new file mode 100644 index 0000000000000000000000000000000000000000..6e16076bfb79ad8151952e96567565e8820b0f5b Binary files /dev/null and b/tensorflow/c/testdata/tf_record differ diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a40ad1ffc3b262840e6ca0043139b1b61e04510d..d73121c7b701ec06c03836d1a765f4b35d88fe92 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" @@ -697,7 +698,8 @@ string OpInfo::GetOpAttrStruct() const { attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); - strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); + strings::StrAppend(&setters, " TF_MUST_USE_RESULT Attrs ", attr_func_def, + " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), "_ = x;\n"); diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index 0734075fc6144d7c9f4fdb48c5e097faa58b8355..81870a0efa309ae6dbd5cc05a5dbe8c3e2d437c8 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -72,9 +72,9 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, }; // Body function that adds one to input. - BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope, - const std::vector& inputs, - std::vector* outputs) { + BodyGraphBuilderFn body_fn = [](const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { DCHECK_EQ(inputs.size(), 1); outputs->emplace_back(ops::Add(scope, inputs[0], 1)); return scope.status(); diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d..0cb3132e94e381f672d69aefe4a199d2b590830c 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -48,8 +48,8 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); Status LogSoftmaxGrad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { + const std::vector& grad_inputs, + std::vector* grad_outputs) { auto softmax = Exp(scope, op.output(0)); auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true)); auto mul = Mul(scope, sum, softmax); @@ -107,11 +107,10 @@ Status BiasAddGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string data_format; - BiasAddGrad::Attrs input_attrs; TF_RETURN_IF_ERROR( GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format)); - input_attrs.DataFormat(data_format); - auto dx_1 = BiasAddGrad(scope, grad_inputs[0], input_attrs); + auto dx_1 = + BiasAddGrad(scope, grad_inputs[0], BiasAddGrad::DataFormat(data_format)); grad_outputs->push_back(Identity(scope, grad_inputs[0])); grad_outputs->push_back(dx_1); return scope.status(); @@ -130,19 +129,16 @@ Status Conv2DGrad(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu)); - Conv2DBackpropInput::Attrs input_attrs; - input_attrs.DataFormat(data_format); - input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu); - auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), - op.input(1), grad_inputs[0], - strides, padding, input_attrs); + auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), op.input(1), + grad_inputs[0], strides, padding, + Conv2DBackpropInput::DataFormat(data_format) + .UseCudnnOnGpu(use_cudnn_on_gpu)); grad_outputs->push_back(dx_1); - Conv2DBackpropFilter::Attrs filter_attrs; - filter_attrs.DataFormat(data_format); - filter_attrs.UseCudnnOnGpu(use_cudnn_on_gpu); - auto dx_2 = Conv2DBackpropFilter(scope, op.input(0), - Shape(scope, op.input(1)), grad_inputs[0], - strides, padding, filter_attrs); + auto dx_2 = + Conv2DBackpropFilter(scope, op.input(0), Shape(scope, op.input(1)), + grad_inputs[0], strides, padding, + Conv2DBackpropFilter::DataFormat(data_format) + .UseCudnnOnGpu(use_cudnn_on_gpu)); grad_outputs->push_back(dx_2); return scope.status(); } @@ -160,13 +156,9 @@ Status MaxPoolGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); - internal::MaxPoolGrad::Attrs grad_attrs; - grad_attrs.DataFormat(data_format); - auto dx = internal::MaxPoolGrad(scope, op.input(0), - op.output(0), - grad_inputs[0], - ksize, strides, - padding, grad_attrs); + auto dx = internal::MaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], ksize, strides, padding, + internal::MaxPoolGrad::DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -180,15 +172,9 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); - MaxPoolGradV2::Attrs grad_attrs; - grad_attrs.DataFormat(data_format); - auto dx = MaxPoolGradV2(scope, op.input(0), - op.output(0), - grad_inputs[0], - op.input(1), - op.input(2), - padding, - grad_attrs); + auto dx = MaxPoolGradV2(scope, op.input(0), op.output(0), grad_inputs[0], + op.input(1), op.input(2), padding, + MaxPoolGradV2::DataFormat(data_format)); grad_outputs->push_back(dx); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); @@ -196,13 +182,74 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + MaxPool3DGrad::Attrs grad_attrs; + auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); + +Status AvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + internal::AvgPoolGrad::Attrs grad_attrs; + auto dx = + internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); + +Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + AvgPool3DGrad::Attrs grad_attrs; + auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper); + Status LRNGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, - std::vector* grad_outputs){ - internal::LRNGrad::Attrs grad_attrs; - - auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), - grad_attrs); + std::vector* grad_outputs) { + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0)); grad_outputs->push_back(dx); return scope.status(); } diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 0cfe5f6e3c49f7c4a3cafbf48ff4e54a0ffd0d47..c4eba7ecb017fe4628140d75a63bc7f0f09deb7f 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -31,8 +31,11 @@ using ops::Elu; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; +using ops::AvgPool; +using ops::AvgPool3D; using ops::MaxPool; using ops::MaxPoolV2; +using ops::MaxPool3D; using ops::Placeholder; using ops::Relu; using ops::Relu6; @@ -70,9 +73,9 @@ class NNGradTest : public ::testing::Test { // Sets tensor with random values, ensuring that the max value is largest by // a reasonable amount. - // This is an issue for MaxPool and MaxPoolV2, in which perturbations by the - // numeric gradient computation in the gradient checker can change the max - // value if values are too close together. + // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which + // perturbations by the numeric gradient computation in the gradient checker + // can change the max value if values are too close together. template void SetRandomValuesWithBumpedMax(Tensor* tensor) { auto tensor_flat = tensor->flat(); @@ -203,6 +206,41 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, MaxPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one MaxPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesWithBumpedMax(&x_init_value); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NNGradTest, AvgPoolGradHelper) { + TensorShape x_shape({1, 2, 2, 1}); + TensorShape y_shape({1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool. + const std::vector ksize{1, 2, 2, 1}; + const std::vector strides{1, 2, 2, 1}; + auto y = AvgPool(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + +TEST_F(NNGradTest, AvgPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(NNGradTest, LRN){ TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index 6077c45c5854fd5812ccb7c91522f93ed4e54883..64edbb5766c3604fbe0f15c2299843718381aa3f 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -61,18 +61,18 @@ class Profiler { /// Adds tracing information `run_meta` to profiler. A `run_meta` is /// generated by a TensorFlow session run call. `step` is the key /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to seletively profile the corresponding `run_meta`. + /// `step` in `options` to selectively profile the corresponding `run_meta`. /// Multiple different `run_meta` can be keyed by the same `step` in order /// to group them together. void AddStep(int64 step, const RunMetadata& run_meta); /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are contected by the op inputs/outputs. + /// Each node is an op and the nodes are connected by the op inputs/outputs. GraphNodeProto ProfileGraph(const Options& options); /// Profiles the model by organizing nodes in name scope structure. /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a filesystem tree. + /// scope, similar to a file system tree. /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. GraphNodeProto ProfileNameScope(const Options& options); diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 97f66e79b8ad9f383b22f56e9385fc6d2080e1f8..f413a5cc52e9eb4bc393b8186f5b591681fa2e5e 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -32,6 +32,7 @@ tf_cc_test( deps = [ ":freeze_saved_model", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index ddf372cdef21e1b3892c9a03714478d5a5785517..4ddddcb5863c9ffb1e5367db750b0d2ffd29cd5e 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -75,16 +75,13 @@ void GetNodeNameToNodeDefMap( // variable nodes to convert. void GetReachableNodesAndVariables( GraphDef* graph_def, const std::unordered_set& outputs, + const std::unordered_map& name_to_node_map, std::unordered_set* reachable_node_names, std::unordered_set* variable_node_names) { // TODO(suharshs): Add support for ResourceVariables. static const std::unordered_set* kVariableTypes = - new std::unordered_set({"Variable", "VariableV2"}); - // name_to_node_map is needed to get the inputs from the NodeDef corresponding - // the a string node name. These inputs are used when doing our backwards - // traversal. - std::unordered_map name_to_node_map; - GetNodeNameToNodeDefMap(graph_def, &name_to_node_map); + new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); + std::queue nodes_to_visit; for (const string& tensor_name : outputs) { // We need to strip off the tensor part to get the node name. @@ -99,7 +96,7 @@ void GetReachableNodesAndVariables( continue; } reachable_node_names->insert(node_name); - NodeDef* node = name_to_node_map[node_name]; + NodeDef* node = name_to_node_map.at(node_name); if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } @@ -111,7 +108,9 @@ void GetReachableNodesAndVariables( // Gets a map from variable name to variable value. Status GetVariableNameToTensorMap( - Session* session, std::unordered_set variable_names_set, + Session* session, + const std::unordered_map& name_to_node_map, + std::unordered_set variable_names_set, std::unordered_map* variable_name_to_value_map) { if (variable_names_set.empty()) { return Status::OK(); @@ -120,8 +119,14 @@ Status GetVariableNameToTensorMap( std::vector tensor_names; for (const string& node_name : variable_names_set) { variable_names.push_back(node_name); - // We need to run tensors, so append ":0". - tensor_names.push_back(node_name + ":0"); + NodeDef* node_def = name_to_node_map.at(node_name); + if (node_def->op() == "VarHandleOp") { + // If this is a resource variable, we have to run the corresponding + // ReadVariableOp. + tensor_names.push_back(node_name + "/Read/ReadVariableOp:0"); + } else { + tensor_names.push_back(node_name + ":0"); + } } std::vector outputs; TF_RETURN_IF_ERROR( @@ -143,6 +148,15 @@ void ConvertVariableToConstant(const NodeDef& variable_node, (*const_node->mutable_attr())["value"].mutable_tensor()); } +// Converts a ReadVariableOp NodeDef to an Identity NodeDef. +void ConvertReadVariableOpToIdentity(const NodeDef& node, + NodeDef* identity_node) { + identity_node->set_name(node.name()); + identity_node->set_op("Identity"); + (*identity_node->mutable_attr())["T"] = node.attr().at("dtype"); + identity_node->add_input(node.input(0)); +} + // Freezes the subgraph of all nodes needed by `outputs`. Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, const std::unordered_set& outputs, @@ -155,14 +169,19 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, if (graph_def.node_size() == 0) { return Status::OK(); } + // name_to_node_map is needed to get the inputs from the NodeDef corresponding + // the a string node name. These inputs are used when doing our backwards + // traversal. + std::unordered_map name_to_node_map; + GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map); std::unordered_set reachable_node_names; std::unordered_set variable_node_names; - GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names, - &variable_node_names); + GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map, + &reachable_node_names, &variable_node_names); std::unordered_map variable_to_value_map; - TF_RETURN_IF_ERROR( - GetVariableNameToTensorMap(saved_model_bundle.session.get(), - variable_node_names, &variable_to_value_map)); + TF_RETURN_IF_ERROR(GetVariableNameToTensorMap( + saved_model_bundle.session.get(), name_to_node_map, variable_node_names, + &variable_to_value_map)); // We copy the nodes in the same order they were in the original graph_def. for (const NodeDef& node : graph_def.node()) { if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { @@ -171,6 +190,12 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, if (variable_node_names.find(node.name()) != variable_node_names.end()) { ConvertVariableToConstant(node, variable_to_value_map[node.name()], frozen_graph_def->add_node()); + } else if (node.op() == "ReadVariableOp" && + variable_node_names.find(node.input(0)) != + variable_node_names.end()) { + // If the node is a ReadVariableOp, its input VarHandleOp will be + // converted to a Constant, so we will need to convert it to an Identity. + ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node()); } else { // If the node isn't a variable, just copy the node as-is. *frozen_graph_def->add_node() = node; diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 52a81a50284aec36bba4e56a0232c886cb0cb6cf..cd35fd3b95deec669218cfa4f25fea2c3ac9e56e 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" @@ -113,6 +114,160 @@ class FreezeTest : public ::testing::Test { test::ExpectTensorEqual(unfrozen_outputs[0], frozen_outputs[0]); } + + void TestFreezeGraphWithoutDependentVariables(bool use_resource) { + // Test freezing a graph with variables that are not needed by the outputs + // in the SignatureDef. The resulting graph shouldn't be frozen, but + // non-dependent nodes should be pruned. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + Output read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + } else { + Output var = + ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + } + + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + GraphDef expected_graph_def; + Scope expected_scope = Scope::NewRootScope(); + Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); + Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); + Output expected_c = + ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); + TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); + + GraphDefEqual(frozen_graph_def, expected_graph_def); + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } + + void TestFreezeGraphWithDependentVariables(bool use_resource) { + // Test freezing a graph with variables that are needed by outputs in the + // SignatureDef. The variables should be frozen. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output read_var; + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + } else { + Output read_var = + ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a); + } + Output c = ops::Mul(scope.WithOpName("c"), a, read_var); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + // If using normal variables there should be 3 nodes in the resulting + // graph_def. If using resource variables there should be 4 nodes in the + // resulting graph_def. + // In both cases, none should be variables. + size_t expected_nodes = use_resource ? 4 : 3; + EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + EXPECT_NE(node.op(), "VarHandleOp") << node.name(); + EXPECT_NE(node.op(), "ReadVariableOp") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } + + void TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource) { + // Test freezing a graph with some variables that are needed and not needed + // by + // the outputs in the SignatureDef. The resulting graph should only freeze + // dependent variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output read_var; + + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + Output var_1 = + ops::VarHandleOp(scope.WithOpName("var_1"), DataType::DT_FLOAT, {}); + Output read_var_1 = + ops::ReadVariableOp(scope.WithOpName("var_1/Read/ReadVariableOp"), + var, DataType::DT_FLOAT); + auto assign_1 = + ops::AssignVariableOp(scope.WithOpName("assign_1"), var_1, a); + } else { + read_var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a); + Output var_1 = + ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); + Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var_1, a); + } + + Output c = ops::Mul(scope.WithOpName("c"), a, read_var); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + size_t expected_nodes = use_resource ? 4 : 3; + EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + EXPECT_NE(node.op(), "VarHandleOp") << node.name(); + EXPECT_NE(node.op(), "ReadVariableOp") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } }; TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) { @@ -196,111 +351,28 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } -TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { - // Test freezing a graph with variables that are not needed by the outputs in - // the SignatureDef. The resulting graph shouldn't be frozen, but - // non-dependent nodes should be pruned. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); - Output c = ops::Mul(scope.WithOpName("c"), a, b); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); - - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); - - GraphDef expected_graph_def; - Scope expected_scope = Scope::NewRootScope(); - Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); - Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); - Output expected_c = - ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); - TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); - - GraphDefEqual(frozen_graph_def, expected_graph_def); - - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { + TestFreezeGraphWithoutDependentVariables(false); } -TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { - // Test freezing a graph with variables that are needed by outputs in the - // SignatureDef. The variables should be frozen. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output c = ops::Mul(scope.WithOpName("c"), a, var); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); - - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); - - // There should be 3 nodes in the resulting graph_def, and none should be - // variables. - EXPECT_EQ(frozen_graph_def.node_size(), 3); - for (const NodeDef& node : frozen_graph_def.node()) { - EXPECT_NE(node.op(), "Variable") << node.name(); - EXPECT_NE(node.op(), "VariableV2") << node.name(); - } - - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithoutDependentResourceVariables) { + TestFreezeGraphWithoutDependentVariables(true); } -TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) { - // Test freezing a graph with some variables that are needed and not needed by - // the outputs in the SignatureDef. The resulting graph should only freeze - // dependent variables. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output c = ops::Mul(scope.WithOpName("c"), a, var); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - Output var_1 = - ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); - Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); +TEST_F(FreezeTest, GraphDefWithDependentVariables) { + TestFreezeGraphWithDependentVariables(false); +} - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); +TEST_F(FreezeTest, GraphDefWithDependentResourceVariables) { + TestFreezeGraphWithDependentVariables(true); +} - // There should be 3 nodes in the resulting graph_def, and none should be - // variables. - EXPECT_EQ(frozen_graph_def.node_size(), 3); - for (const NodeDef& node : frozen_graph_def.node()) { - EXPECT_NE(node.op(), "Variable") << node.name(); - EXPECT_NE(node.op(), "VariableV2") << node.name(); - } +TEST_F(FreezeTest, GraphDefWithAndWithoutDependentVariables) { + TestFreezeGraphWithAndWithoutDependentVariables(false); +} - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) { + TestFreezeGraphWithAndWithoutDependentVariables(true); } } // namespace diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 0900e87ebabd378e6237b77ca0ef01677c07c244..ffa2d088295375bbbcd2cdd9365982907f2bf480 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd..7c833878818022c86fd3171ec9cef9fcd3217a24 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 9dff1be09fede6f65f82c2f36d94be07e781949f..3a877c5337ff76193a7f27fb9681e5a9ca500961 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -132,7 +132,7 @@ def tf_library(name, graph, config, header_file = name + ".h" metadata_object_file = name + "_tfcompile_metadata.o" function_object_file = name + "_tfcompile_function.o" - ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags else: diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a711319607f4ff2b83aa0ebe50e215b3d0e2258e..8e505da6221b23b0130548405f12a61dcda100d7 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -29,7 +29,10 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( name = "jit", - visibility = [":friends"], + visibility = [ + ":friends", + "//learning/tfx:__subpackages__", + ], deps = [ ":xla_cpu_device", ":xla_cpu_jit", @@ -73,6 +76,7 @@ cc_library( ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -102,22 +106,44 @@ cc_library( cc_library( name = "xla_interpreter_device", srcs = ["xla_interpreter_device.cc"], + visibility = [":friends"], deps = [ + ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_tensor_info", + srcs = ["xla_tensor_info.cc"], + hdrs = ["xla_tensor_info.h"], + deps = [ + ":common", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], - alwayslink = True, ) cc_library( name = "xla_device", srcs = [ + "xla_compile_on_demand_op.cc", "xla_device.cc", "xla_device_context.cc", "xla_device_ops.cc", ], hdrs = [ + "xla_compile_on_demand_op.h", "xla_device.h", "xla_device_context.h", "xla_device_ops.h", @@ -127,6 +153,8 @@ cc_library( deps = [ ":common", ":jit_compilation_passes", + ":xla_launch_util", + ":xla_tensor_info", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", @@ -166,6 +194,29 @@ cc_library( visibility = [":friends"], ) +cc_library( + name = "xla_launch_util", + srcs = ["xla_launch_util.cc"], + hdrs = ["xla_launch_util.h"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_tensor_info", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:variable_ops", + ], +) + cc_library( name = "xla_compilation_cache", srcs = ["xla_compilation_cache.cc"], @@ -200,6 +251,7 @@ cc_library( name = "graph_to_functiondef", srcs = ["graph_to_functiondef.cc"], hdrs = ["graph_to_functiondef.h"], + visibility = [":friends"], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -296,6 +348,7 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", + ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9c372a012789fc25ca0a711349c09ca62edc6754..7fc43fb26318335909d52d5bbd83ebf61f42a703 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -53,6 +53,8 @@ namespace tensorflow { const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; namespace { @@ -143,7 +145,7 @@ struct NodeSlot { // everything to use it. static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kHostComputeOp = "_XlaHostCompute"; +static const char* const kHostComputeOp = "XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -328,12 +330,14 @@ class Encapsulator { Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to - // all the downstream nodes of call_node_outputs. - void ConnectSequencerToOutputs(Graph* graph_out); + // the call node. + void ConnectSequencerToCallNode(Graph* graph_out); Status AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph); + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library); Status ReplaceFunctionDef(FunctionLibraryDefinition* library); @@ -381,12 +385,24 @@ class Encapsulator { Node* send_from_host = nullptr; }; + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Returns the (possible newly created) subgraph for + // outside_compilation_id. + OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( + const string& outside_compilation_id); + // Builds a ParallelCheck op that compares the output of the original // subgraph with the encapsulated subgraph. Status BuildParallelCheckOp( const std::unordered_map& node_images, Graph* graph_out); + // Builds a placeholder node used to provide the key input to a RecvAtHost + // or SendFromHost node. This placeholder node will be removed by a later + // pass. + Status AddHostComputeKeyPlaceholder(OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); + // Builds a _RecvAtHost node producing all the inputs of an // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. Status AddRecvAtHostNode(const string& subgraph_name, @@ -413,6 +429,14 @@ class Encapsulator { // NodeDef for the function call node. NodeDef call_node_def_; + // Name that is used for the call node. This may not be + // call_node_def_.name() if the client supplies a rewrite lambda. + string function_def_name_; + + // Placeholder node simulating the host compute key in the output graph. + // Not owned. + Node* host_compute_key_placeholder_ = nullptr; + // Function call node(s) in the output graph. Not owned. // If parallel_checking is enabled, 'call_node_inputs' is the function call // node to which inputs should be fed, and 'call_node_outputs' is the @@ -551,7 +575,7 @@ class Encapsulator { const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out); + std::unique_ptr* graph_out); // Makes a copy of graph containing only nodes that are ancestors of at least // one node in send_from_host_nodes and store it in pruned_graph. On exit @@ -712,39 +736,44 @@ Status Encapsulator::Subgraph::RecordResult( return Status::OK(); } -void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( - const string& outside_compilation_id, const Edge* edge) { +Encapsulator::Subgraph::OutsideCompilationSubgraph* +Encapsulator::Subgraph::LookupOrCreateOutsideCompilationSubgraph( + const string& outside_compilation_id) { auto iter = outside_compilation_subgraphs_ .emplace(outside_compilation_id, OutsideCompilationSubgraph()) .first; - OutsideCompilationSubgraph& outside_subgraph = iter->second; + OutsideCompilationSubgraph* outside_subgraph = &iter->second; + return outside_subgraph; +} + +void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge) { + OutsideCompilationSubgraph* outside_subgraph = + LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); if (edge->IsControlEdge()) { - outside_subgraph.control_inputs.insert(edge->src()); + outside_subgraph->control_inputs.insert(edge->src()); } else { - int input_index = outside_subgraph.inputs.size(); - outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + int input_index = outside_subgraph->inputs.size(); + outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), + input_index); } } void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge) { - auto subgraph_iter = - outside_compilation_subgraphs_ - .emplace(outside_compilation_id, OutsideCompilationSubgraph()) - .first; - OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; + OutsideCompilationSubgraph* outside_subgraph = + LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); if (edge->IsControlEdge()) { - outside_subgraph.control_outputs.insert(edge->dst()); + outside_subgraph->control_outputs.insert(edge->dst()); } else { DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = - outside_subgraph.outputs_by_src + outside_subgraph->outputs_by_src .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph.outputs_by_src.size()) + outside_subgraph->outputs_by_src.size()) .first; int output_index = output_iter->second; - outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = output_index; } } @@ -842,25 +871,21 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, NodeDef seq_def; NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), "NoOp"); + builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); + builder.Device(device_); Status s = builder.Finalize(&seq_def); if (!s.ok()) return s; sequencer_ = graph_out->AddNode(seq_def, &s); if (!s.ok()) return s; - sequencer_->set_assigned_device_name(device_); } return Status::OK(); } -void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { +void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { - std::unordered_set output_dependencies; - for (Node* node : call_node_outputs_->out_nodes()) { - output_dependencies.insert(node); - } - for (Node* node : output_dependencies) { - graph_out->AddControlEdge(sequencer_, node); - } + VLOG(2) << "ConnectSequencerToCallNode"; + graph_out->AddControlEdge(sequencer_, call_node_inputs_); } } @@ -906,6 +931,8 @@ Status Encapsulator::Subgraph::BuildFunctionDef( name = call_node_def_.op(); } + function_def_name_ = name; + FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -924,8 +951,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( } Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph) { + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library) { OutsideCompilationSubgraph& oc_subgraph = outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); @@ -947,21 +976,22 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shape_inference_graph", ""); host_compute->AddAttr("shapes", shapes); } else { - string serialized_graph; - if (!inference_graph->SerializeToString(&serialized_graph)) { - return errors::Internal( - "Failed to serialize graph for outside compilation subgraph ", - oc_subgraph.host_compute_name); - } - host_compute->AddAttr("shape_inference_graph", serialized_graph); + string inference_graph_name = + strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); + FunctionDef fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); + host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); } return Status::OK(); } Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { - const string& name = call_node_def_.name(); + const string& name = function_def_name_; FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -1060,9 +1090,36 @@ Status Encapsulator::Subgraph::AddFunctionCallNode( return Status::OK(); } +Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + TensorShapeProto shape_proto; + TensorShape shape({2}); + shape.AsProto(&shape_proto); + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeDef key_def; + NodeDefBuilder builder( + strings::StrCat(call_node_def_.name(), "_key_placeholder"), + "Placeholder"); + builder.Attr("dtype", DT_STRING); + builder.Attr("shape", shape_proto); + builder.Attr("_host_compute_call_node", call_node_def_.name()); + Status s = builder.Finalize(&key_def); + if (!s.ok()) return s; + + host_compute_key_placeholder_ = graph_out->AddNode(key_def, &s); + if (!s.ok()) return s; + host_compute_key_placeholder_->set_assigned_device_name(device_); + + return Status::OK(); +} + Status Encapsulator::Subgraph::AddRecvAtHostNode( const string& subgraph_name, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + if (host_compute_key_placeholder_ == nullptr) { + TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); + } + std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); for (const auto& input : oc_subgraph->inputs) { @@ -1078,15 +1135,22 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); + // TODO(misard) When we add replication the device placement will have to be + // redone. + builder.Device(device_); builder.Attr("Toutputs", dtypes); + // TODO(misard) For now we only support TPU device 0. + builder.Attr("device_ordinal", 0); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); if (!s.ok()) return s; - oc_subgraph->recv_at_host->set_assigned_device_name(device_); + graph_out->AddEdge(host_compute_key_placeholder_, 0, + oc_subgraph->recv_at_host, 0); // Add a control dependency forcing the RecvAtHost to run before the subgraph // completes. This has no effect on execution order but prevents the @@ -1101,6 +1165,10 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( const std::unordered_map& node_images, const string& subgraph_name, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + if (host_compute_key_placeholder_ == nullptr) { + TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); + } + std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); std::vector inputs( oc_subgraph->outputs_by_src.size()); @@ -1120,16 +1188,23 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); + // TODO(misard) When we add replication the device placement will have to be + // redone. + builder.Device(device_); builder.Attr("Tinputs", dtypes); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + // TODO(misard) For now we only support TPU device 0. + builder.Attr("device_ordinal", 0); builder.Input(inputs); + builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&send_def); if (!s.ok()) return s; oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); if (!s.ok()) return s; - oc_subgraph->send_from_host->set_assigned_device_name(device_); + graph_out->AddEdge(host_compute_key_placeholder_, 0, + oc_subgraph->send_from_host, inputs.size()); // Add a control dependency forcing the SendFromHost to run before the // subgraph completes. This has no effect on execution order but prevents the @@ -1611,7 +1686,7 @@ Status Encapsulator::AddEdgesToOutputGraph( for (auto& subgraph_entry : subgraphs_) { Subgraph& subgraph = subgraph_entry.second; - subgraph.ConnectSequencerToOutputs(graph_out); + subgraph.ConnectSequencerToCallNode(graph_out); } return Status::OK(); @@ -1690,7 +1765,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out) { + std::unique_ptr* graph_out) { // Maps from nodes in graph_in to nodes in graph_out. // // When an edge has fully defined shape the source node in graph_in is @@ -1707,9 +1782,11 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( std::unordered_map dummy_node_images; std::unordered_map copied_node_images; - std::unique_ptr graph_out(new Graph(graph_in.op_registry())); - graph_out->set_versions(graph_in.versions()); - static_shape_out->resize(send_node->num_inputs()); + graph_out->reset(new Graph(graph_in.op_registry())); + (*graph_out)->set_versions(graph_in.versions()); + // The final input to the send node is the dynamic key, which we don't include + // in the static shapes. + static_shape_out->resize(send_node->num_inputs() - 1); // We don't use the standard ReverseDFS because we want to cut off traversal // whenever we find an output with fully defined shape. @@ -1728,7 +1805,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (w.leave) { TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( n, send_node, dummy_node_images, library, &copied_node_images, - graph_out.get())); + graph_out->get())); } else { if (visited[n->id()]) continue; visited[n->id()] = true; @@ -1750,14 +1827,23 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // continue. TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); - dummy_node_images[src_node] = AddDummyShapedNode( - src_node->output_type(src_port), proto, graph_out.get()); - if (n == send_node) { + if (dummy_node_images.find(src_node) == dummy_node_images.end()) { + dummy_node_images[src_node] = AddDummyShapedNode( + src_node->output_type(src_port), proto, graph_out->get()); + } + // The final input to the send node is the dynamic key, which we + // don't include in the static shapes. + if (n == send_node && + in_edge->dst_input() < static_shape_out->size()) { (*static_shape_out)[in_edge->dst_input()] = proto; } } else { + has_parent_with_unknown_shape = true; if (!visited[src_node->id()]) { - has_parent_with_unknown_shape = true; + if (VLOG_IS_ON(2)) { + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + } stack.push_back({src_node, false}); } } @@ -1768,7 +1854,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // The shapes of all the inputs to send_node are statically known. We // won't have to do any inference at compile time so return now: the // shapes were stored in static_shape_out above. - graphdef_out->reset(); + graph_out->reset(); return Status::OK(); } else { // Any shape that is being processed is either the original send node @@ -1791,9 +1877,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( } } - graphdef_out->reset(new GraphDef()); - graph_out->ToGraphDef(graphdef_out->get()); - return Status::OK(); } @@ -1910,14 +1993,20 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", + *pruned_graph, library); + } + for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; // Find all the recv_at_host nodes in this subgraph. std::vector outside_compilation_names; subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); std::unordered_set recv_at_host_names; - for (const auto& name : outside_compilation_names) { - Node* recv_node = subgraph.GetRecvAtHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(oc_name); if (recv_node != nullptr) { recv_at_host_names.insert(recv_node->name()); } @@ -1926,26 +2015,30 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( // without knowing the shape of the recv_at_host nodes, and store the // result, along with enough information to complete the job at compile time // once the recv_at_host shapes are known. - for (const auto& name : outside_compilation_names) { - Node* send_node = subgraph.GetSendFromHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(oc_name); std::vector static_shape; - std::unique_ptr graphdef; + std::unique_ptr graph; if (send_node != nullptr) { TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( *pruned_graph, shape_refiner, recv_at_host_names, - node_images[send_node], library, &static_shape, &graphdef)); - if (graphdef == nullptr) { + node_images[send_node], library, &static_shape, &graph)); + if (graph == nullptr) { VLOG(2) << "Send node " << send_node->name() << " shapes"; for (int i = 0; i < static_shape.size(); ++i) { VLOG(2) << static_shape[i].DebugString(); } } else { - VLOG(2) << "Send node " << send_node->name() << " graph\n" - << graphdef->DebugString(); + if (VLOG_IS_ON(2)) { + GraphDef graphdef; + graph->ToGraphDef(&graphdef); + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef.DebugString(); + } } } - TF_RETURN_IF_ERROR( - subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo( + subgraph_name, oc_name, static_shape, graph.get(), library)); } if (!outside_compilation_names.empty()) { TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index aed9cae0f1799c4524da8ee309344849798755d5..94481a1fde986b705764f6f0c6de14fb28002496 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -29,6 +31,27 @@ limitations under the License. namespace tensorflow { namespace { +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; + +Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, + const string& name_suffix, + FunctionDefLibrary* library) { + GraphDef graphdef; + TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); + std::unique_ptr graph = + std::unique_ptr(new Graph(OpRegistry::Global())); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get())); + FunctionDef* fdef = library->add_function(); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *graph, + strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + fdef)); + return Status::OK(); +} + template bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const ::tensorflow::protobuf::Map& b, @@ -112,23 +135,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, a.attr(), b.attr(), [](const string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, [](const string& key, const AttrValue& av, const AttrValue& bv) { - if (key == "shape_inference_graph") { - // Default serialization of GraphDef is unstable because maps don't - // serialize deterministically. Rather than go through the hoops to - // turn on deterministic serialization of this attr just for this - // test, add logic here to compare determinstically. - GraphDef ga; - if (!ga.ParseFromString(av.s())) { - return false; - } - GraphDef gb; - if (!gb.ParseFromString(bv.s())) { - return false; - } - return EqualGraphDef(ga, gb, nullptr); - } else { - return av.DebugString() == bv.DebugString(); - } + return av.DebugString() == bv.DebugString(); }, strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); @@ -246,26 +253,32 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) -// TODO(misard): remove these fake registrations once there are real Ops to be -// compiled. -REGISTER_OP("_XlaHostCompute") +// These dummy Op registrations are here because the real Op registrations live +// in contrib and there can't be a dependence from this test to contrib. +REGISTER_OP("XlaHostCompute") .Input("inputs: Tinputs") .Output("outputs: Toutputs") .Attr("Tinputs: list(type) >= 0") .Attr("Toutputs: list(type) >= 0") .Attr("key: string") + .Attr("shape_inference_graph: string = ''") + .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaSendFromHost") - .Input("input: Tinputs") + .Input("inputs: Tinputs") + .Input("dynamic_key: string") .Attr("Tinputs: list(type) >= 0") .Attr("key: string") + .Attr("device_ordinal: int") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaRecvAtHost") - .Output("output: Toutputs") + .Input("dynamic_key: string") + .Output("outputs: Toutputs") .Attr("Toutputs: list(type) >= 0") .Attr("key: string") + .Attr("device_ordinal: int") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("InputTest") @@ -315,8 +328,13 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); -Node* NoOp(const GraphDefBuilder::Options& opts) { - return ops::SourceOp("NoOp", opts); +Node* Sequencer(const GraphDefBuilder::Options& opts, + const string& call_node_name) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp", + opts.op_registry()); + return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name) + .FinalizeBuilder(&node_builder); } Node* Input(const GraphDefBuilder::Options& opts) { @@ -327,43 +345,71 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShape(const gtl::ArraySlice& shape, - const GraphDefBuilder::Options& opts) { +Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", opts.op_registry()); TensorProto value; - value.set_dtype(DT_FLOAT); + value.set_dtype(dtype); for (int dim : shape) { value.mutable_tensor_shape()->add_dim()->set_size(dim); } return opts.WithAttr("value", value) - .WithAttr("dtype", DT_FLOAT) + .WithAttr("dtype", dtype) .FinalizeBuilder(&node_builder); } -Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, +Node* KnownShape(const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { + return KnownShapeBase(DT_FLOAT, shape, opts); +} + +Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) { + return KnownShapeBase(DT_STRING, {2}, opts); +} + +Node* KeyPlaceholder(const string& call_node, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Placeholder"), "Placeholder", + opts.op_registry()); + TensorShapeProto shape; + shape.add_dim()->set_size(2); + return opts.WithAttr("shape", shape) + .WithAttr("dtype", DT_STRING) + .WithAttr("_host_compute_call_node", call_node) + .FinalizeBuilder(&node_builder); +} + +Node* RecvAtHost(ops::NodeOut key_input, const string& key, + const gtl::ArraySlice& dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); + node_builder.Input(std::move(key_input)); return opts.WithAttr("Toutputs", dtypes) .WithAttr("key", key) + .WithAttr("device_ordinal", 0) .FinalizeBuilder(&node_builder); } -Node* SendFromHost(const string& key, const std::vector& inputs, +Node* SendFromHost(ops::NodeOut key_input, const string& key, + const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); + node_builder.Input(std::move(key_input)); std::vector dtypes; for (const auto& node : inputs) { dtypes.push_back(node.dt); } - return opts.WithAttr("key", key) - .WithAttr("Tinputs", dtypes) + return opts.WithAttr("Tinputs", dtypes) + .WithAttr("key", key) + .WithAttr("device_ordinal", 0) .FinalizeBuilder(&node_builder); } @@ -806,19 +852,20 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, shape.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {e}, shape.opts().WithName("outside_compilation_F1_O1_send")); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -833,12 +880,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, {"shapes", gtl::ArraySlice({})}}, {"c"}}, }, @@ -851,24 +899,30 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), + "host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = + b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -918,38 +972,41 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected_1; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, shape1.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape1.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape1.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape1_graph; - TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); - EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {e}, shape1.opts().WithName("outside_compilation_F1_O1_send")); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } - string shape_string_expected_2; { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, shape2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), shape2.opts().WithName("E")); Node* recv2 = - RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", + {DT_FLOAT, DT_FLOAT}, shape2.opts().WithName("outside_compilation_F1_O2_recv")); Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); - SendFromHost("host_compute_channel_F1_O2", {h}, - shape2.opts().WithName("outside_compilation_F1_O2_send")); - GraphDef shape2_graph; - TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); - EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", + {h}, shape2.opts().WithName("outside_compilation_F1_O2_send")); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } *library_expected.add_function() = FunctionDefHelper::Create( @@ -966,21 +1023,23 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0", "F:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", shape_string_expected_2}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O2"}, {"shapes", gtl::ArraySlice({})}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected_1}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, {"shapes", gtl::ArraySlice({})}}, {"D"}}, }, @@ -993,35 +1052,41 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), + "host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); Node* recv2 = - RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", + {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_recv")); Node* g = Binary(e, ops::NodeOut(recv2, 1), b2.opts().WithName("G").WithControlInputs({recv2, e})); Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); - Node* send2 = - SendFromHost("host_compute_channel_F1_O2", {h}, - b2.opts().WithName("outside_compilation_F1_O2_send")); + Node* send2 = SendFromHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", {h}, + b2.opts().WithName("outside_compilation_F1_O2_send")); - Node* s = NoOp(b2.opts() - .WithName("F1_sequencer") - .WithControlInputs({recv1, send1, recv2, send2})); + Node* s = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); - Binary(g, call, b2.opts().WithName("J").WithControlInput(s)); + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); + + Binary(g, call, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1070,19 +1135,20 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, shape.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {e}, shape.opts().WithName("outside_compilation_F1_O1_send")); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } TensorShapeProto shape_proto_expected; @@ -1100,12 +1166,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, {"shapes", gtl::ArraySlice({})}}, {"D"}}, }, @@ -1120,7 +1187,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"G:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, @@ -1138,39 +1205,47 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = InputShaped(b2.opts().WithName("B")); + Node* key_constant1 = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + RecvAtHost(ops::NodeOut(key_constant1, 0), "host_compute_channel_F1_O1", + {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), + "host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); - - Node* recv2 = - RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F2_O1_recv")); - Node* h = Binary(ops::NodeOut(call1, 1), recv2, - b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = - SendFromHost("host_compute_channel_F2_O1", {h}, - b2.opts().WithName("outside_compilation_F2_O1_send")); - + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + + Node* key_constant2 = + KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1", + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* h = Binary(ops::NodeOut(call1, 1), recv2, b2.opts().WithName("H")); + Node* send2 = SendFromHost( + ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1", {h}, + b2.opts().WithName("outside_compilation_F2_O1_send")); + + Node* s2 = Sequencer( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), + "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); Node* call2 = b2.opts() - .WithControlInputs({s1, e, call1}) + .WithControlInputs({s2, e, call1}) .FinalizeBuilder(&node_builder2); - Node* s2 = NoOp( - b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); - Binary(call2, ops::NodeOut(call2, 1), - b2.opts().WithName("J").WithControlInput(s2)); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1218,7 +1293,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, @@ -1237,15 +1312,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* b = Input(b2.opts().WithName("B")); Node* e = Unary(a, b2.opts().WithName("E")); - Node* send1 = - SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* send1 = SendFromHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1294,7 +1373,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, @@ -1313,20 +1392,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); - Node* send1 = - SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = SendFromHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1368,7 +1451,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, @@ -1385,16 +1468,20 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1441,7 +1528,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, @@ -1458,21 +1545,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), + "host_compute_channel_F1_O1", {}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1569,19 +1661,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1")); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", + {e}, shape.opts().WithName("outside_compilation_F1_O1_send")); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -1595,12 +1687,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, {"shapes", gtl::ArraySlice({})}}, {"c"}}, }, @@ -1614,26 +1707,30 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* b = Input(b2.opts().WithName("B")); Node* c = Unary(a, b2.opts().WithName("C")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(b).Input(c); - Node* call = - b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); - - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = BinaryUnknownShape( c, ops::NodeOut(recv, 0), b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), + "host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(b).Input(c); + Node* call = + b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 9bea5663319c8a25249fdc265cee0191556a7c04..616a7f8f1541d3debff97a90bd390c76c665d196 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -14,6 +14,7 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", + "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 6353149e4afdf739fe44dd5c76502ef5d98b8477..8a8e8bb8df1a8d0a40af054e6713616745224cc8 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -40,111 +41,6 @@ namespace gpu = perftools::gputools; namespace tensorflow { -// Adapter class that wraps a Tensorflow allocator as an XLA allocator. -// Assumes that the Tensorflow allocator permits asynchronous deallocation: -// see comment on `AllowsAsynchronousDeallocation()`. -class XlaAllocator : public xla::DeviceMemoryAllocator { - public: - XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context); - ~XlaAllocator() override; - xla::StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override; - - // Register an Tensor (input or resource variable) with the allocator. If - // the operation returns an alias to one of its inputs, then the allocator - // needs to be able to handle it. - Status RegisterArgument(const Tensor* t); - - // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is - // interpreted as having data type 'dtype' and shape 'shape'. - Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype, - const TensorShape& shape, Tensor* tensor) const; - - // The Tensorflow BFC allocator used on GPU allows host-side deallocation - // before GPU execution takes place. Tensorflow uses the ordering of the main - // compute stream to enforce a happens-before relationship between a memory - // allocation and code that reuses the same memory. If Tensorflow adds - // support for multiple GPU streams or allocators with different ordering - // requirements, this code may need to change. - // (This attribute has no effect on CPU.) - bool AllowsAsynchronousDeallocation() const override { return true; } - - private: - OpKernelContext* const op_context_; - - // Map from pointer address to the owning Tensor; used by - // MakeTensorFromBuffer. Also used to automatically release Tensors when the - // allocator is freed. - std::unordered_map tensors_; -}; - -XlaAllocator::XlaAllocator(const gpu::Platform* platform, - OpKernelContext* op_context) - : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} - -XlaAllocator::~XlaAllocator() = default; - -xla::StatusOr XlaAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { - AllocatorAttributes allocator_attrs; - allocator_attrs.set_on_host(false); - - AllocationAttributes allocation_attrs; - allocation_attrs.no_retry_on_failure = !retry_on_failure; - - Tensor t; - Status status = op_context_->allocate_temp( - DT_UINT8, TensorShape({static_cast(size)}), &t, allocator_attrs, - allocation_attrs); - if (!status.ok()) { - VLOG(2) << "Allocation failed " << size; - return status; - } - void* data = - reinterpret_cast(const_cast(t.tensor_data().data())); - tensors_[data] = t; - return gpu::DeviceMemoryBase(data, size); -} - -Status XlaAllocator::RegisterArgument(const Tensor* t) { - void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); - tensors_[data] = *t; - return Status::OK(); -} - -Status XlaAllocator::Deallocate(int device_ordinal, - gpu::DeviceMemoryBase* mem) { - if (mem->opaque() != nullptr) { - if (tensors_.erase(mem->opaque()) == 0) { - return tensorflow::errors::InvalidArgument("Unknown tensor address"); - } - } - return Status::OK(); -} - -Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, - DataType dtype, - const TensorShape& shape, - Tensor* out_tensor) const { - void* ptr = const_cast(buffer.opaque()); - auto it = tensors_.find(ptr); - if (it == tensors_.end()) { - return errors::InvalidArgument("Unknown tensor address"); - } - const Tensor& tensor = it->second; - - int64 output_size = DataTypeSize(dtype) * shape.num_elements(); - if (tensor.TotalBytes() == output_size) { - out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape); - } else { - Tensor slice = tensor.Slice(0, output_size); - out_tensor->UnsafeCopyFromInternal(slice, dtype, shape); - } - return Status::OK(); -} - XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) : OpKernel(ctx), device_type_(ctx->device_type()) { const NameAttrList* func; @@ -196,23 +92,6 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -std::vector SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { - std::vector snapshot(num_variables); - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { - Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - if (LookupResource(ctx, handle, &variable).ok()) { - tf_shared_lock lock(*variable->mu()); - snapshot[i].name = handle.name(); - snapshot[i].present = true; - snapshot[i].value = *variable->tensor(); - } - } - return snapshot; -} - void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOp::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); @@ -235,16 +114,22 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); + const XlaDevice::Metadata* metadata; + Status s = XlaDevice::GetMetadata(ctx, &metadata); + + XlaTensorInfoManager* tensor_info_manager = nullptr; + if (s.ok()) { + tensor_info_manager = &metadata->tensor_info_manager(); + } + // Get the platform_id_ for XLA_* devices. if (platform_id_ == nullptr) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { platform_id_ = metadata->platform()->id(); } } - std::vector variables = + std::map variables = SnapshotResourceVariables(ctx, num_resource_args_); xla::LocalClient* client = static_cast(cache->client()); @@ -263,49 +148,19 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, + std::map constant_args; + for (int i = 0; i < num_constant_args_; ++i) { + constant_args.insert({i, ctx->input(i)}); + } + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, variables, ctx, &kernel, &executable, /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; - std::unique_ptr output; - // Build xla::ShapedBuffers that point directly to the Tensor buffers. - std::vector> arg_buffers; - arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); - arg_buffers.resize(kernel->xla_input_shapes.size()); - std::vector arg_ptrs(arg_buffers.size()); - - const int first_variable_arg = ctx->num_inputs() - num_resource_args_; - // Pass remaining parameters. - const Tensor* t; - for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { - int arg_num = kernel->input_mapping[i]; - const xla::Shape& shape = kernel->xla_input_shapes[i]; - if (arg_num >= first_variable_arg) { - t = &(variables[arg_num - first_variable_arg].value); - } else { - t = &(ctx->input(arg_num)); - } - - gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( - const_cast(t->tensor_data().data()), t->tensor_data().size()); - - const xla::Shape on_device_shape = - client->backend().transfer_manager()->HostShapeToDeviceShape(shape); - CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) - << "On-device shape " - << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) - << " not the same as on-host shape " - << xla::ShapeUtil::HumanStringWithLayout(shape); - arg_buffers[i] = xla::MakeUnique( - /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), - client->default_device_ordinal()); - arg_buffers[i]->set_buffer(dmem, /*index=*/{}); - arg_ptrs[i] = arg_buffers[i].get(); - - OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); - } + XlaComputationLaunchContext launch_context( + num_resource_args_, client, &xla_allocator, tensor_info_manager); + launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. VLOG(2) << "Executing computation."; @@ -315,93 +170,13 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); Env* env = Env::Default(); auto start_time = env->NowMicros(); - auto run_result = executable->Run(arg_ptrs, run_options); + auto run_result = executable->Run(launch_context.arguments(), run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - // Computation output should always be a tuple. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); - } - CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); - - // Copy XLA results to the OpOutputList. - int output_num = 0; - for (int i = 0; i < ctx->num_outputs(); ++i) { - if (kernel->outputs[i].is_constant) { - // Output is a constant. - const Tensor& const_tensor = kernel->outputs[i].constant_value; - const size_t total_bytes = const_tensor.TotalBytes(); - if (stream && total_bytes > 0) { - // Copy host -> device. (Empty tensors don't have backing buffers.) - VLOG(1) << "Constant output tensor on device"; - Tensor* output_tensor; - TF_CHECK_OK( - ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - - const void* src_ptr = DMAHelper::base(&const_tensor); - void* dst_ptr = DMAHelper::base(output_tensor); - gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); - } else { - // No copy required. - ctx->set_output(i, const_tensor); - } - } else { - const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - Tensor output_tensor; - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer( - buffer, ctx->expected_output_dtype(i), shape, - &output_tensor)); - ctx->set_output(i, output_tensor); - ++output_num; - } - - if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DebugString(); - } - } - - // Apply variable updates, if any. - VLOG(2) << "Applying variable updates"; - for (int i = 0; i < kernel->resource_updates.size(); ++i) { - const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); - - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - - Var* variable = nullptr; - // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not - // a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); - - core::ScopedUnref s(variable); - - mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); - - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, - variable->tensor())); - ++output_num; - } - + launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 47fd912b12abbbe876e933ab57f6f586fd299909..c6cc0986af0300c51283d432c671e92a1e4d8145 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,14 +26,6 @@ limitations under the License. namespace tensorflow { -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back -// resource variables since concurrent updates may modify the shape, and it is -// important that the shapes used for compilation match the true shapes of the -// buffers. -std::vector SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); - // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 4491dd6ac8f2b84f341162eb469cc8194f817c9a..9cd66fc13c9e0658fdf105d5d9d92f0320ddd179 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -52,6 +52,18 @@ cc_library( ], ) +cc_library( + name = "xla_device_flags", + srcs = ["xla_device_flags.cc"], + hdrs = ["xla_device_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 4bc209b7ecf499d82e7567f7eff12b17cefa9863..7277a1d1f8ad5fa045645ead839ab9efa01e89c7 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -40,6 +40,8 @@ static void AllocateFlags() { flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; flags->tf_xla_cpu_global_jit = false; + flags->tf_xla_clustering_fuel = std::numeric_limits::max(); + flags->tf_xla_fusion_only = false; flag_list = new std::vector( {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, "Control compilation of operators into XLA computations on CPU and " @@ -55,7 +57,13 @@ static void AllocateFlags() { Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, "Dump graphs during XLA compilation."), Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions.")}); + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index e1ccd7ddb8706ca445b6811ca1fec369af7cd5d5..2affda6ab4e0fbad32a246744fa5b38aeb629c1b 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -48,6 +48,13 @@ typedef struct { bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU // via SessionOptions. + int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this + // many ops will be marked as eligible for + // clustering. + bool tf_xla_fusion_only; // This flag is effective only when global_jit_level + // is set to ON* and overrides its behavior. If + // true, enable fusion of element-wise operations + // only using XLA. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..1bb2fce2dbad5bffce2e33b665b7222090d0855a --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for the XLA bridge's xla_device module. + +#include +#include + +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static XlaDeviceFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new XlaDeviceFlags; + flags->tf_xla_compile_on_demand = false; + flag_list = new std::vector({ + Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +// Return a pointer to the XlaDeviceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..27b22121ac1e089bd5d5a494e1e3fb60b05bc76d --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ + +// Legacy flags for the XLA bridge's xla_device module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// The values of flags associated with the XLA bridge's +// xla_device module. +typedef struct { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +} XlaDeviceFlags; + +// Return a pointer to the XlaDeviceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +XlaDeviceFlags* GetXlaDeviceFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index a0211acbbe9eec77d30c7d14293650de8826f41c..f651768a67278628e40445291d7fb271bb1ae611 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -174,10 +174,164 @@ bool HasResourceInputOrOutput(const Node& node) { } struct NodeCompare { - bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } + bool operator()(const Node* a, const Node* b) const { + return a->id() < b->id(); + } }; using OrderedNodeSet = std::set; +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +// +// TODO(hpucha): Consider a black list instead of a white list as +// implemented below. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/batchtospace_op.cc + "BatchToSpace", "BatchToSpaceND", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/categorical_op.cc + "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/, + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/cross_op.cc + "Cross", + // tf2xla/kernels/depthtospace_op.cc + "DepthToSpace", + // tf2xla/kernels/diag_op.cc + "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart", + // tf2xla/kernels/dynamic_stitch_op.cc + "DynamicStitch", "ParallelDynamicStitch", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fake_quantize_ops.cc + "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/, + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/gather_op.cc + "Gather", "GatherV2", "GatherNd", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", "StopGradient", + "Snapshot", + // tf2xla/kernels/image_ops.cc + "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/, + "AdjustSaturation", "AdjustHue", + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/l2loss_op.cc + "L2Loss" /*(Reduce)*/, + // tf2xla/kernels/lrn_ops.cc (ReduceWindow) + "LRN", "LRNGrad", + // tf2xla/kernels/matrix_band_part_op.cc + "MatrixBandPart", + // tf2xla/kernels/matrix_set_diag_op.cc + "MatrixSetDiag", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/no_op.cc + "NoOp", "ControlTrigger", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/pooling_ops.cc + "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool", + "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/ + "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad", + "AvgPool3DGrad", + // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce) + "QuantizeAndDequantizeV2", + // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend + // currently) + "RandomUniform", "RandomUniformInt", "RandomStandardNormal", + "TruncatedNormal", + // tf2xla/kernels/reduction_ops.cc (Reduce) + "Sum", "Prod", "Min", "Max", "Mean", "All", "Any", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/scan_ops.cc (ReduceWindow) + "Cumsum", "Cumprod", + // tf2xla/kernels/scatter_nd_op.cc (Reduce) + "ScatterNd", + // tf2xla/kernels/segment_reduction_ops.cc (Reduce) + "UnsortedSegmentSum", + // tf2xla/kernels/select_op.cc + "Select", + // tf2xla/kernels/sequence_ops.cc + "Range", "LinSpace", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/softmax_op.cc (Reduce) + "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + // tf2xla/kernels/spacetobatch_op.cc + "SpaceToBatchND", "SpaceToBatch", + // tf2xla/kernels/spacetodepth_op.cc + "SpaceToDepth", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/stack_ops.cc + "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2", + // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on + // GPU + // backend currently) + "StatelessRandomUniform", + "StatelessRandomNormal" + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", + "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/training_ops.cc + "ResourceApplyGradientDescent", "ResourceApplyMomentum", + "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp", + "ResourceApplyFtrl", "ResourceApplyFtrlV2", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, @@ -189,7 +343,27 @@ Status FindCompilationCandidates( FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + int64& fuel = + legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + + // Iterate over nodes in sorted order so that compiler fuel is deterministic. + // We can't simply pass op_nodes().begin() and op_nodes().end to the + // std::vector constructor because they're not proper iterators, with + // iterator_traits defined and so on. + std::vector sorted_nodes; for (Node* node : graph.op_nodes()) { + sorted_nodes.push_back(node); + } + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + + for (Node* node : sorted_nodes) { + VLOG(2) << "Fuel: " << fuel; + if (fuel <= 0) { + VLOG(2) + << "Hit fuel limit; not marking any remaining ops as clusterable."; + break; + } + VLOG(2) << "FindCompilationCandidates(): Processing " << node->DebugString(); @@ -234,7 +408,9 @@ Status FindCompilationCandidates( continue; } candidates->insert(node); + --fuel; } + VLOG(2) << "candidates->size() = " << candidates->size(); return Status::OK(); } @@ -314,10 +490,13 @@ Status MarkForCompilationPass::Run( static_cast(flags->tf_xla_auto_jit); } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + bool fusion_only = flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; + VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -340,6 +519,11 @@ Status MarkForCompilationPass::Run( status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; + // Check for fusable ops only if requested. + if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + return false; + } + // Otherwise use the value of global_jit_level. // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6d854a920eb0b4c01b09024ceaef5035e847d392..6430975335f5eef5b53c80213e6090ffd6166a91 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -92,38 +92,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( } Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, OpKernelContext* ctx, + const NameAttrList& function, const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, Signature* signature) { signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.resize(num_constant_args); - - signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); - - // Inputs are in the order: constants, non-constants, resource variables. - int input_num = 0; - // Use the values of compile time constants in the signature-> - while (input_num < num_constant_args) { - signature->arg_values[input_num] = ctx->input(input_num); - ++input_num; - } - // Add the types and shapes of the remaining arguments. - while (input_num < ctx->num_inputs() - variable_args.size()) { - signature->arg_types.emplace_back(ctx->input_dtype(input_num), - ctx->input(input_num).shape()); - ++input_num; - } - // For variable signatures, use the type and shape of the variable's - // current value. - for (const OptionalTensor& variable : variable_args) { - TF_RET_CHECK(input_num < ctx->num_inputs()); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); + signature->arg_values.reserve(constant_args.size()); + + signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); + + for (int i = 0; i < ctx->num_inputs(); ++i) { + if (constant_args.count(i) > 0) { + // Use the values of compile time constants in the signature. + signature->arg_values.push_back(constant_args.at(i)); + } else if (variable_args.count(i) > 0) { + const OptionalTensor& variable = variable_args.at(i); + if (variable.present) { + signature->arg_types.emplace_back(variable.value.dtype(), + variable.value.shape()); + } else { + signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + } } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + signature->arg_types.emplace_back(ctx->input_dtype(i), + ctx->input(i).shape()); } - ++input_num; } return Status::OK(); } @@ -131,74 +123,58 @@ Status XlaCompilationCache::BuildSignature( namespace { // Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. The first `num_constant_args` arguments must be host-memory Tensors. -Status BuildArguments(int num_constant_args, - const std::vector& variable_args, +// op. +Status BuildArguments(const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, std::vector* args) { args->resize(ctx->num_inputs()); - int input_num = 0; - - // Handles compile-time constants. - TF_RET_CHECK(num_constant_args <= ctx->num_inputs()); - while (input_num < num_constant_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - XlaCompiler::Argument& arg = (*args)[input_num]; - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - ++input_num; - } - - // Handles the non-constant arguments. - int num_variable_args = variable_args.size(); - int num_nonconst_args = - ctx->num_inputs() - num_variable_args - num_constant_args; - TF_RET_CHECK(num_nonconst_args >= 0); - while (input_num < num_constant_args + num_nonconst_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { XlaCompiler::Argument& arg = (*args)[input_num]; - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. + const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - ++input_num; - } - - // Handles resource variables. - TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs()); - for (int variable_id = 0; variable_id < num_variable_args; ++variable_id) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - - XlaCompiler::Argument& arg = (*args)[input_num]; - - arg.name = variable_args[variable_id].name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable_args[variable_id].present) { - const Tensor& value = variable_args[variable_id].value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do permit - // uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } } - ++input_num; } return Status::OK(); @@ -233,16 +209,43 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - int num_constant_args, const std::vector& variable_args, - OpKernelContext* ctx, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options) { + return CompileImpl(options, function, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, false); +} + +Status XlaCompilationCache::CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { + const NodeDef& def = ctx->op_kernel().def(); + NameAttrList name; + name.set_name(def.op()); + *name.mutable_attr() = def.attr(); + return CompileImpl(options, name, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, true); +} + +Status XlaCompilationCache::CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << num_constant_args + << " num_constant_args=" << constant_args.size() << " num_variable_args=" << variable_args.size(); for (int i = 0; i < ctx->num_inputs(); i++) { TensorShape shape = ctx->input(i).shape(); @@ -250,10 +253,12 @@ Status XlaCompilationCache::Compile( << " present=" << ctx->has_input(i) << " shape=" << shape.DebugString(); } - for (const OptionalTensor& variable : variable_args) { + for (auto& iterator : variable_args) { + const OptionalTensor& variable = iterator.second; VLOG(2) << "variable present=" << variable.present << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString(); + << " shape=" << variable.value.shape().DebugString() + << " TF arg= " << iterator.first; } VLOG(2) << "num_outputs = " << ctx->num_outputs(); for (int i = 0; i < ctx->num_outputs(); i++) { @@ -261,11 +266,12 @@ Status XlaCompilationCache::Compile( } } - TF_RET_CHECK(num_constant_args + variable_args.size() <= ctx->num_inputs()); + TF_RET_CHECK(constant_args.size() + variable_args.size() <= + ctx->num_inputs()); Signature signature; - TF_RETURN_IF_ERROR(BuildSignature(function, num_constant_args, variable_args, - ctx, &signature)); + TF_RETURN_IF_ERROR( + BuildSignature(function, constant_args, variable_args, ctx, &signature)); VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not @@ -292,13 +298,20 @@ Status XlaCompilationCache::Compile( // a long time.) std::vector args; TF_RETURN_IF_ERROR( - BuildArguments(num_constant_args, variable_args, ctx, &args)); + BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + + if (compile_single_op) { + entry->compilation_status = compiler.CompileSingleOp( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + signature.name, ctx, args, &entry->compilation_result); + } else { + entry->compilation_status = compiler.CompileFunction( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); + } } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 0858020716fcf4763e42dc0699ad22cfda756942..5c0c79b880c474969464f23b4485734c404cef07 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -52,8 +52,8 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `num_constant_args` is the number of compile-time constant arguments to - // `function`. `variable_args` is a snapshot of the current values of the + // `constant_args` is a maps of tensorflow argument number to constant value. + // `variable_args` is a snapshot of the current values of the // resource variable arguments to `function`; uninitialized variables are // represented by an absent OptionalTensor. // The result of compilation is written to `*compilation_result`, which must @@ -62,19 +62,40 @@ class XlaCompilationCache : public ResourceBase { // executable pointer may be null if the computation has no non-constant // outputs. Status Compile(const XlaCompiler::Options& options, - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options); + // As above, but calls XlaCompiler::CompileSingleOp instead of + // XlaCompiler::CompileFunction. + Status CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); + xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl(const XlaCompiler::Options& options, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, + OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op); // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. Status BuildExecutable(const XlaCompiler::Options& options, @@ -104,8 +125,9 @@ class XlaCompilationCache : public ResourceBase { static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, + Status BuildSignature(const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, Signature* signature); // The value associated with a cache entry. diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..915b9ce84ab8268ef4e652351bc981aa5bf7b10c --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -0,0 +1,178 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines the XlaCompileOnDemandOp. + +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +namespace { +std::map GetVariables(OpKernelContext* ctx) { + std::map variables; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dtype() == DT_RESOURCE) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& optional = variables[i]; + optional.name = handle.name(); + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + optional.present = true; + optional.value = *variable->tensor(); + } + } + } + return variables; +} +} // namespace + +Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, + const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable) { + std::map variables = GetVariables(ctx); + int64 num_resource_args = variables.size(); + + xla::LocalClient* client = metadata.client(); + XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager(); + + // Builds an XLA allocator for the device. + XlaAllocator xla_allocator(client->platform(), ctx); + XlaComputationLaunchContext launch_context( + num_resource_args, client, &xla_allocator, tensor_info_manager); + + launch_context.PopulateInputs(ctx, result, variables); + + perftools::gputools::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + TF_RET_CHECK(stream); + + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(&xla_allocator); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + TF_RETURN_IF_ERROR(run_result.status()); + + launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + return Status::OK(); +} + +bool XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // TODO(jmolloy): This could be expensive, so memoize. + auto* constant_inputs = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + op_kernel->def().op()); + CHECK(constant_inputs); + std::set constant_input_indices; + for (const auto& name : *constant_inputs) { + int start, stop; + TF_CHECK_OK(op_kernel->InputRange(name, &start, &stop)); + for (int i = start; i < stop; ++i) { + constant_input_indices.insert(i); + } + } + return constant_input_indices.count(argument_idx) > 0; +} + +bool XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // Right now we only create kConstant arguments when absolutely required, but + // there may be benefit in eagerly constant-folding a larger subset of + // arguments in the future. + return MustArgumentBeConstant(op_kernel, argument_idx); +} + +Status XlaCompileOnDemandOp::Compile( + OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable) { + XlaTensorInfoManager* tensor_info_manager = &metadata.tensor_info_manager(); + + std::map constant_arguments; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& device_tensor = ctx->input(i); + if (const XlaTensorInfo* tensor_info = + tensor_info_manager->GetTensorInfo(device_tensor)) { + if (tensor_info->has_host_tensor() && + ShouldArgumentBeConstant(&ctx->op_kernel(), i)) { + constant_arguments[i] = tensor_info->host_tensor(); + } + } + if (constant_arguments.count(i) == 0 && + MustArgumentBeConstant(&ctx->op_kernel(), i)) { + // Slow path; the argument is not available as a host constant so we must + // fetch it synchronously. + Tensor host_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + device_tensor.dtype(), device_tensor.shape(), &host_tensor)); + Notification n; + ctx->op_device_context()->CopyDeviceTensorToCPU( + &device_tensor, "ConstantArgument", + reinterpret_cast(ctx->device()), &host_tensor, + [&](Status status) { n.Notify(); }); + n.WaitForNotification(); + constant_arguments[i] = host_tensor; + } + } + + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + CHECK(rm); + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + *cache = new XlaCompilationCache(metadata.client(), + metadata.jit_device_type()); + return Status::OK(); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + XlaCompiler::Options options; + DeviceType device_type = metadata.jit_device_type(); + options.device_type = &device_type; + options.client = metadata.client(); + options.flib_def = + new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + + std::map variable_args = GetVariables(ctx); + return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, + result, executable, + /*compile_options=*/nullptr); +} + +void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { + const XlaCompiler::CompilationResult* result; + xla::LocalExecutable* executable; + const XlaDevice::Metadata* metadata; + OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); + OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable)); + OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h new file mode 100644 index 0000000000000000000000000000000000000000..23c6f3903f841a6c39104983c6f7f409757a7319 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The XlaCompileOnDemandOp is an OpKernel that, when its Compute method is +// called, will generate an xla::Computation and run it asynchronously. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ + +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// An OpKernel that compiles an op to an XLA computation and runs it. Unlike +// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// vanilla TensorFlow op as long as the bridge supports it. +// +// Importantly _XlaLaunch assumes all input and output tensors are on the host, +// whereas XlacompileOnDemandOp works with tensors in device memory. +class XlaCompileOnDemandOp : public OpKernel { + public: + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override; + + private: + XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); + bool ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + bool MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable); + Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e238252751e677eb947f6df03e3b2f2e948ffe19..d2dfdeea68129b536477aa75f66c9d267f5a9434 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,6 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend. #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -34,14 +36,24 @@ class XlaCpuDeviceFactory : public DeviceFactory { Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); + bool compile_on_demand = flags->tf_xla_compile_on_demand; + + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_CPU_XLA_JIT; + registration.requires_compilation = !compile_on_demand; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT); (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index d4d8fe1c1d575b4e35d624621cc709e3a16569d5..82048f5d78957dfeaf9656d332374ba86a5e920b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -108,21 +109,15 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, bool register_device_for_compilation, - std::unique_ptr* device) { + const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - if (register_device_for_compilation) { - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); - } + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { @@ -137,15 +132,17 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( device->reset(new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie())); + platform.ValueOrDie(), transfer_as_literal)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + std::unique_ptr* tensor_info_manager) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + tensor_info_manager_(*tensor_info_manager) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -160,6 +157,10 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return device_type_; } +XlaTensorInfoManager& XlaDevice::Metadata::tensor_info_manager() const { + return *tensor_info_manager_; +} + /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { XlaDevice* xla_device = @@ -177,13 +178,19 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { XlaDevice::XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform) + const DeviceType& jit_device_name, se::Platform* platform, + bool transfer_as_literal) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_( + device_ordinal, platform, jit_device_name, + // Pass tensor_info_manager_ by reference as it is initialized lazily. + &tensor_info_manager_), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), - platform_(platform) {} + platform_(platform), + tensor_info_manager_(nullptr), + transfer_as_literal_(transfer_as_literal) {} XlaDevice::~XlaDevice() {} @@ -208,6 +215,7 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { xla::Backend* backend = client()->mutable_backend(); xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( backend, device_ordinal_); + tensor_info_manager_.reset(new XlaTensorInfoManager(xla_allocator_)); } return xla_allocator_; } @@ -225,7 +233,11 @@ Status XlaDevice::FillContextMap(const Graph* graph, VLOG(1) << "XlaDevice::FillContextMap"; device_context_map->resize(graph->num_node_ids()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - auto ctx = new XlaDeviceContext(stream); + // Call GetAllocator for the side-effect of ensuring the allocator and + // XlaTensorInfoManager is created. + (void)GetAllocator({}); + auto ctx = new XlaDeviceContext(stream, tensor_info_manager_.get(), + transfer_as_literal_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -273,7 +285,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream); + XlaTransferManager manager(stream, tensor_info_manager_.get(), + transfer_as_literal_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; @@ -288,19 +301,23 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { + // Any op assigned to the device that isn't rewritten by the graph rewriter + // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes + // it just-in-time. + kernel_factory::OpKernelRegistrar::Factory factory = + [](OpKernelConstruction* context) -> OpKernel* { + return new XlaCompileOnDemandOp(context); + }; XlaOpRegistry::RegisterCompilationKernels(); XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; - auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { - return new XlaDeviceDummyOp(context); - }; for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( jit_device, /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( - new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp", - dummy_factory)); + new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp", + factory)); } return registrations; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d2ec38293c429f04f088bf3726ba97eb4e4b0dba..9cd9167e523961c0ddd99fbc9ca9bdc20b9be7b5 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -26,6 +26,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_tensor_info.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -48,7 +50,8 @@ class XlaDevice : public LocalDevice { class Metadata { public: Metadata(int device_ordinal, perftools::gputools::Platform* platform, - const DeviceType& device_type); + const DeviceType& device_type, + std::unique_ptr* tensor_info_manager); // The index of the device on this host. int device_ordinal() const; @@ -56,11 +59,13 @@ class XlaDevice : public LocalDevice { perftools::gputools::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + XlaTensorInfoManager& tensor_info_manager() const; private: const int device_ordinal_; const DeviceType device_type_; perftools::gputools::Platform* platform_; // Not owned. + std::unique_ptr& tensor_info_manager_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -71,15 +76,20 @@ class XlaDevice : public LocalDevice { // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. + // 'transfer_as_literal' is true if device<->host transfers must be done using + // XLA's TransferLiteral{To,From}Device interface. If false, we can use + // ThenMemcpy instead. static Status Create(const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, - bool register_device_for_compilation, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - ::perftools::gputools::Platform* platform); + ::perftools::gputools::Platform* platform, + bool transfer_as_literal); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -104,7 +114,7 @@ class XlaDevice : public LocalDevice { // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - const DeviceType& jit_device_name_; + DeviceType jit_device_name_; // Memory allocator associated with this device. Allocator* xla_allocator_; // Not owned. ::perftools::gputools::Platform* platform_; // Not owned. @@ -113,9 +123,19 @@ class XlaDevice : public LocalDevice { // copying back and forth between CPU and the device, and // computations enqueued by XLA. xla::Backend::StreamPtr stream_; + // Manages sideband data about tensors, in particular the on-device shape tree + // if the tensor requires multiple device buffers to represent (for example, + // tuple shapes). + // This is a unique_ptr because XlaTensorInfoManager is non-copy-constructible + // and we need to initialize this lazily (as we also lazily initialize the + // underlying allocator). + std::unique_ptr tensor_info_manager_; + // Must we use XLA's transfer manager for correct host<->device transfers? if + // false, we can use ThenMemcpy() instead. + bool transfer_as_literal_; }; -// Builds dummy OpKernel registrations on 'device' for the JIT operators +// Builds OpKernel registrations on 'device' for the JIT operators // registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations // object that encapsulates the kernel registrations. struct XlaDeviceOpRegistrations { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index c936222f32056e92efced82d5adb3a96c8041a17..88f7c15f0b74a8c99935647f75352e7dec4689fc 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -52,7 +53,12 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream) : stream_(stream) {} +XlaTransferManager::XlaTransferManager( + se::Stream* stream, XlaTensorInfoManager* tensor_info_manager, + bool transfer_as_literal) + : stream_(stream), + tensor_info_manager_(tensor_info_manager), + transfer_as_literal_(transfer_as_literal) {} void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -72,15 +78,25 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); Status status; - stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = xla::Unimplemented( + "XlaTransferManager::CopyCPUTensorToDevice not implemented for " + "literals"); + } else { + stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } + XlaTensorInfo* tensor_info = + tensor_info_manager_->GetOrCreateTensorInfo(*device_tensor); + tensor_info->set_host_tensor(*cpu_tensor); + done(status); return; } @@ -108,13 +124,19 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, void* dst_ptr = DMAHelper::base(cpu_tensor); Status status; - stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = xla::Unimplemented( + "XlaTransferManager::CopyDeviceTensorToCPU not implemented for " + "literals"); + } else { + stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } done(status); @@ -125,7 +147,10 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream) : manager_(stream) {} +XlaDeviceContext::XlaDeviceContext(se::Stream* stream, + XlaTensorInfoManager* tensor_info_manager, + bool transfer_as_literal) + : manager_(stream, tensor_info_manager, transfer_as_literal) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index c4edcd474e48f791af9340c3cd6e4d031407bb68..df02f4eac482f385f8864476d11c5430971f00c8 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/jit/xla_tensor_info.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -49,7 +50,9 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(perftools::gputools::Stream* stream); + explicit XlaTransferManager(perftools::gputools::Stream* stream, + XlaTensorInfoManager* tensor_info_manager, + bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -62,6 +65,10 @@ class XlaTransferManager { // Stream obtained from a Device, used to transfer tensors between // CPU and device. perftools::gputools::Stream* stream_; + // The tensor info manager, for access to sideband information about tensors. + XlaTensorInfoManager* tensor_info_manager_; + // True if we must use XLA's TransferManager for correct device transfers. + bool transfer_as_literal_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -69,7 +76,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(perftools::gputools::Stream* stream); + explicit XlaDeviceContext(perftools::gputools::Stream* stream, + XlaTensorInfoManager* tensor_info_manager, + bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 2326070358d67c0cf30ef17fab5c93862cd8932c..5a1db817745f56d6bcc26ff6fc441b7c902ee2b5 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -34,14 +34,21 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_GPU_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); (void)registrations; std::unique_ptr device; - Status status = XlaDevice::Create( - "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device); + Status status = + XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 2614deefd8823dcb8f38e9e22ae4e78145d0d96a..9e098c46f422b436c722bb909dc58930ab7c0ef6 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -25,8 +25,8 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; +constexpr std::array kExecAllTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: @@ -41,10 +41,17 @@ Status XlaInterpreterDeviceFactory::CreateDevices( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, - options, name_prefix, /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, + DEVICE_INTERPRETER_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb7316c60c61f8755b6cdd575676fab343f26d11 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -0,0 +1,280 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_launch_util.h" + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace gpu = perftools::gputools; + +namespace tensorflow { + +std::map SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables) { + std::map snapshot; + int first_variable = ctx->num_inputs() - num_variables; + for (int i = 0; i < num_variables; ++i) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, first_variable + i); + OptionalTensor& tensor = snapshot[first_variable + i]; + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + tensor.name = handle.name(); + tensor.present = true; + tensor.value = *variable->tensor(); + } + } + return snapshot; +} + +XlaAllocator::XlaAllocator(const gpu::Platform* platform, + OpKernelContext* op_context) + : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} + +XlaAllocator::~XlaAllocator() { CHECK(allocated_.empty()); } + +xla::StatusOr XlaAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + void* data = op_context_->device()->GetAllocator({})->AllocateRaw( + Allocator::kAllocatorAlignment, size); + allocated_.insert(data); + return gpu::DeviceMemoryBase(data, size); +} + +void XlaAllocator::Release(void* ptr) { allocated_.erase(ptr); } + +Status XlaAllocator::Deallocate(int device_ordinal, + gpu::DeviceMemoryBase* mem) { + if (allocated_.count(mem->opaque())) { + op_context_->device()->GetAllocator({})->DeallocateRaw(mem->opaque()); + allocated_.erase(mem->opaque()); + } + return Status::OK(); +} + +namespace { +// Return the 'index''th subtree of the given ShapedBuffer as a ShapedBuffer. +xla::ShapedBuffer ExtractSubShapedBuffer(const xla::ShapedBuffer& shaped_buffer, + int index) { + xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer.on_host_shape(), index); + xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer.on_device_shape(), index); + + xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, + shaped_buffer.platform(), + shaped_buffer.device_ordinal()); + + auto& shape_tree = shaped_buffer.buffers(); + auto& sub_shape_tree = sub_shaped_buffer.buffers(); + sub_shape_tree.CopySubtreeFrom(shape_tree, + /*source_base_index=*/{index}, + /*target_base_index=*/{}); + return sub_shaped_buffer; +} +} // namespace + +XlaComputationLaunchContext::XlaComputationLaunchContext( + int64 num_resource_args, xla::LocalClient* client, + XlaAllocator* xla_allocator, XlaTensorInfoManager* tensor_info_manager) + : num_resource_args_(num_resource_args), + client_(client), + xla_allocator_(xla_allocator), + tensor_info_manager_(tensor_info_manager) {} + +void XlaComputationLaunchContext::PopulateInputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + const std::map& variables) { + // Build xla::ShapedBuffers that point directly to the Tensor buffers. + arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); + arg_buffers_.resize(kernel->xla_input_shapes.size()); + arg_ptrs_ = std::vector(arg_buffers_.size()); + + // Pass remaining parameters. + const Tensor* t; + for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { + int arg_num = kernel->input_mapping[i]; + const xla::Shape& shape = kernel->xla_input_shapes[i]; + if (variables.count(arg_num)) { + t = &(variables.at(arg_num).value); + CHECK(t); + } else { + t = &(ctx->input(arg_num)); + } + + const xla::Shape on_device_shape = + client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); + if (xla::ShapeUtil::IsTuple(on_device_shape)) { + CHECK(tensor_info_manager_); + const XlaTensorInfo* tensor_info = + tensor_info_manager_->GetTensorInfo(*t); + CHECK(tensor_info && tensor_info->has_shaped_buffer()); + arg_ptrs_[i] = + const_cast(&tensor_info->shaped_buffer()); + } else { + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( + const_cast(t->tensor_data().data()), t->tensor_data().size()); + arg_buffers_[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, + client_->platform(), client_->default_device_ordinal()); + arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); + arg_ptrs_[i] = arg_buffers_[i].get(); + } + } +} + +void XlaComputationLaunchContext::PopulateOutputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + std::unique_ptr output) { + gpu::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + // Computation output should always be a tuple. + if (VLOG_IS_ON(2)) { + VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); + } + CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + + // Copy XLA results to the OpOutputList. + int output_num = 0; + for (int i = 0; i < ctx->num_outputs(); ++i) { + AllocatorAttributes alloc_attrs = ctx->output_alloc_attr(i); + Allocator* allocator = ctx->device()->GetAllocator({}); + if (tensor_info_manager_ && !alloc_attrs.on_host()) { + allocator = tensor_info_manager_; + } + if (kernel->outputs[i].is_constant) { + // Output is a constant. + const Tensor& const_tensor = kernel->outputs[i].constant_value; + Tensor* output_tensor; + const size_t total_bytes = const_tensor.TotalBytes(); + if (stream && total_bytes > 0) { + // Copy host -> device. (Empty tensors don't have backing buffers.) + VLOG(1) << "Constant output tensor on device"; + + TF_CHECK_OK( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + + const void* src_ptr = DMAHelper::base(&const_tensor); + void* dst_ptr = DMAHelper::base(output_tensor); + gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); + // Memcpying asynchronously is safe for the GPU, but the CPU uses a + // shared allocator so hold a reference to the copied-to buffer until + // complete. + TensorReference ref(*output_tensor); + stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); + stream->ThenDoHostCallback([ref] { ref.Unref(); }); + } else { + // No copy required. + ctx->set_output(i, const_tensor); + output_tensor = ctx->mutable_output(i); + } + if (tensor_info_manager_) { + XlaTensorInfo* tensor_info = + tensor_info_manager_->GetOrCreateTensorInfo(*output_tensor); + tensor_info->set_host_tensor(const_tensor); + } + } else { + const TensorShape& shape = kernel->outputs[i].shape; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); + + gpu::DeviceMemoryBase buffer = output->buffer({output_num}); + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + xla_allocator_->Release(buffer.opaque()); + + xla::Shape output_shape = xla::ShapeUtil::GetTupleElementShape( + output->on_device_shape(), output_num); + if (xla::ShapeUtil::IsTuple(output_shape)) { + CHECK(tensor_info_manager_); + XlaTensorInfo* tensor_info = + tensor_info_manager_->GetOrCreateTensorInfo(output_tensor); + tensor_info->set_shaped_buffer( + ExtractSubShapedBuffer(*output, output_num)); + } + ctx->set_output(i, output_tensor); + ++output_num; + } + + if (VLOG_IS_ON(3)) { + VLOG(3) << ctx->mutable_output(i)->DebugString(); + } + } + + // Apply variable updates, if any. + VLOG(2) << "Applying variable updates"; + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); + if (tensor_info_manager_) { + allocator = tensor_info_manager_; + } + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; + OP_REQUIRES(ctx, + write.input_index >= 0 && write.input_index < ctx->num_inputs(), + errors::Internal("Invalid input index for variable write.")); + + gpu::DeviceMemoryBase buffer = output->buffer({output_num}); + + Var* variable = nullptr; + // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, + // not a Tensor. + OP_REQUIRES_OK(ctx, LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), + &variable, [this, ctx, &write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); + + core::ScopedUnref s(variable); + + mutex_lock ml(*variable->mu()); + OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, + errors::Internal("Mismatched type in variable write")); + *variable->tensor() = + XlaTensorBuffer::MakeTensor(write.type, write.shape, buffer, allocator); + xla_allocator_->Release(buffer.opaque()); + + xla::Shape output_shape = xla::ShapeUtil::GetTupleElementShape( + output->on_device_shape(), output_num); + if (xla::ShapeUtil::IsTuple(output_shape)) { + CHECK(tensor_info_manager_); + XlaTensorInfo* tensor_info = + tensor_info_manager_->GetOrCreateTensorInfo(*variable->tensor()); + tensor_info->set_shaped_buffer( + ExtractSubShapedBuffer(*output, output_num)); + } + ++output_num; + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8694f6ce58b72ca188bf831528db30daf93b905d --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for launching compiled XLA kernels for a KernelContext. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_tensor_info.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class XlaAllocator; + +// Takes a snapshot of the values of resource variable arguments, which are +// the last `num_variables` arguments. We snapshot tensors that back +// resource variables since concurrent updates may modify the shape, and it is +// important that the shapes used for compilation match the true shapes of the +// buffers. +// +// Returns a map of TensorFlow argument index to resource variable. +std::map SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables); + +// Adapter class that wraps a Tensorflow allocator as an XLA allocator. +// Assumes that the Tensorflow allocator permits asynchronous deallocation: +// see comment on `AllowsAsynchronousDeallocation()`. +class XlaAllocator : public xla::DeviceMemoryAllocator { + public: + XlaAllocator(const perftools::gputools::Platform* platform, + OpKernelContext* op_context); + ~XlaAllocator() override; + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + Status Deallocate(int device_ordinal, + perftools::gputools::DeviceMemoryBase* mem) override; + + // Un-track 'ptr' - do not delete it on destruction. + void Release(void* ptr); + + // The Tensorflow BFC allocator used on GPU allows host-side deallocation + // before GPU execution takes place. Tensorflow uses the ordering of the main + // compute stream to enforce a happens-before relationship between a memory + // allocation and code that reuses the same memory. If Tensorflow adds + // support for multiple GPU streams or allocators with different ordering + // requirements, this code may need to change. + // (This attribute has no effect on CPU.) + bool AllowsAsynchronousDeallocation() const override { return true; } + + private: + OpKernelContext* const op_context_; + std::unordered_set allocated_; +}; + +// Helper class to perform the marshalling of TensorFlow inputs and outputs to +// ShapedBuffers suitable for passing to an XLA computation. +class XlaComputationLaunchContext { + public: + XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, + XlaAllocator* xla_allocator, + XlaTensorInfoManager* tensor_info_manager); + + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). + // `variables` is a map from TensorFlow argument number to resource variable. + void PopulateInputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + const std::map& variables); + + // Given the XLA output in `output`, populate all outputs of `ctx`. + void PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + std::unique_ptr output); + + // Return the argument list. Only valid after PopulateInputs() has been + // called. + const std::vector& arguments() const { return arg_ptrs_; } + + private: + int64 num_resource_args_; + xla::LocalClient* client_; + XlaAllocator* xla_allocator_; + XlaTensorInfoManager* tensor_info_manager_; + std::vector> arg_buffers_; + std::vector arg_ptrs_; +}; + +// A simple TensorBuffer implementation that allows us to create Tensors that +// take ownership of pre-allocated memory. +class XlaTensorBuffer : public TensorBuffer { + public: + XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, + Allocator* allocator) + : expected_size_(expected_size), + actual_size_(actual_size), + allocator_(allocator) { + data_ = const_cast(ptr); + } + + ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); } + + void* data() const override { return data_; } + size_t size() const override { return expected_size_; } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_allocated_bytes(actual_size_); + } + + static Tensor MakeTensor(DataType dtype, const TensorShape& shape, + perftools::gputools::DeviceMemoryBase buffer, + Allocator* allocator) { + size_t expected_size = shape.num_elements() * DataTypeSize(dtype); + auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size, + buffer.size(), allocator); + Tensor t(dtype, shape, tensor_buffer); + tensor_buffer->Unref(); + return t; + } + + private: + void* data_; + size_t expected_size_; + size_t actual_size_; + Allocator* allocator_; +}; + +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/jit/xla_tensor_info.cc b/tensorflow/compiler/jit/xla_tensor_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ce18c27cbe1d46eb61f8000506396fedc509e9c --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor_info.cc @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_tensor_info.h" + +namespace tensorflow { + +const XlaTensorInfo* XlaTensorInfoManager::GetTensorInfo( + const void* device_ptr) const { + mutex_lock lock(lock_); + auto iterator = tensor_infos_.find(device_ptr); + return (iterator == tensor_infos_.end()) ? nullptr + : tensor_infos_.at(device_ptr).get(); +} + +XlaTensorInfo* XlaTensorInfoManager::GetOrCreateTensorInfo( + const void* device_ptr) { + mutex_lock lock(lock_); + auto iterator = tensor_infos_.find(device_ptr); + if (iterator != tensor_infos_.end()) { + return iterator->second.get(); + } + auto iterator_and_inserted = + tensor_infos_.emplace(device_ptr, MakeUnique()); + CHECK(iterator_and_inserted.second); + return iterator_and_inserted.first->second.get(); +} + +const XlaTensorInfo* XlaTensorInfoManager::GetTensorInfo(const Tensor& tensor) { + return GetTensorInfo(tensor.tensor_data().data()); +} + +XlaTensorInfo* XlaTensorInfoManager::GetOrCreateTensorInfo( + const Tensor& tensor) { + return GetOrCreateTensorInfo(tensor.tensor_data().data()); +} + +void XlaTensorInfoManager::DeallocateRaw(void* ptr) { + wrapped()->DeallocateRaw(ptr); + mutex_lock lock(lock_); + tensor_infos_.erase(ptr); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tensor_info.h b/tensorflow/compiler/jit/xla_tensor_info.h new file mode 100644 index 0000000000000000000000000000000000000000..fbd6ad770fbf9b80829ca80f1a85704e3288a680 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor_info.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_ + +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// Information about a tensor. The XlaTensorInfoManager can maintain one of +// these per device Tensor. +class XlaTensorInfo { + public: + XlaTensorInfo() {} + + // Some Tensors can have complex on-device shapes, including tuple shapes. To + // manage the memory for these tensors a ShapedBuffer may be required. + + // Return true if this TensorInfo contains a ShapedBuffer. + bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } + // Return the contained ShapedBuffer. + // REQUIRES: has_shaped_buffer() + const xla::ShapedBuffer& shaped_buffer() const { return *shaped_buffer_; } + // Mutates the TensorInfo to set the ShapedBuffer. + void set_shaped_buffer(xla::ShapedBuffer shaped_buffer) { + shaped_buffer_.reset(new xla::ShapedBuffer(std::move(shaped_buffer))); + } + + // Some tensors on the device may have known values on the host. We use these + // in on-demand mode to avoid re-copying values from the device if we know the + // host value already. + + // Return true if this TensorInfo contains a host tensor. + bool has_host_tensor() const { return host_tensor_ != nullptr; } + // Return the contained host tensor. + // REQUIRES: has_host_tensor() + const Tensor& host_tensor() const { return *host_tensor_; } + // Sets the contained host tensor. + void set_host_tensor(const Tensor& tensor) { + host_tensor_.reset(new Tensor(tensor)); + } + + private: + // The optional contained ShapedBuffer. + std::unique_ptr shaped_buffer_; + // An optional host tensor value. + std::unique_ptr host_tensor_; +}; + +// Manages XlaTensorInfo objects. This class is also an Allocator, so that +// XlaTensorInfo objects can be deleted when their Tensor is deallocated. +class XlaTensorInfoManager : public AllocatorWrapper { + public: + // Creates a new XlaTensorInfoManager, delegating all DeallocateRaw calls to + // allocator. + XlaTensorInfoManager(Allocator* allocator) : AllocatorWrapper(allocator) {} + + // Returns the XlaTensorInfo for the given device memory pointer or nullptr if + // none exists. + const XlaTensorInfo* GetTensorInfo(const void* device_ptr) const; + // Returns the XlaTensorInfo for the device memory pointer extracted from + // tensor or nullptr if none exists. + const XlaTensorInfo* GetTensorInfo(const Tensor& tensor); + + // Returns the XlaTensorInfo for the given device memory pointer, creating one + // if necessary. + XlaTensorInfo* GetOrCreateTensorInfo(const Tensor& tensor); + // Returns the XlaTensorInfo for the device memory pointer extracted from + // tensor, creating one if necessary. + XlaTensorInfo* GetOrCreateTensorInfo(const void* device_ptr); + + // Allocator interface + void DeallocateRaw(void* ptr) override; + + private: + mutable mutex lock_; + // The managed tensor infos. The mapped value is a unique_ptr so that returned + // references are stable over rehashes. + std::unordered_map> tensor_infos_ + GUARDED_BY(lock_); +}; +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 782bf82d4149968d5e5fbfb93bbd4ff1dcd75494..1c5a8f8e695cb2922f118f231082ebb53cb2bc9b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -86,7 +86,10 @@ tf_xla_py_test( # ArgMax needs CustomCall on CPU, which is not available in normal # (not precompiled) TensorFlow. The flag below excludes the CPU # backend. - disabled_backends = "cpu", + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -98,7 +101,7 @@ tf_xla_py_test( tf_xla_py_test( name = "binary_ops_test", - size = "small", + size = "medium", srcs = ["binary_ops_test.py"], shard_count = 5, tags = [ @@ -315,6 +318,8 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], + # Functions are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -550,6 +555,8 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + # Stack ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -576,6 +583,8 @@ tf_xla_py_test( name = "tensor_array_ops_test", size = "small", srcs = ["tensor_array_ops_test.py"], + # TensorArray ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 30a6d3a74d64f90ad33062df6d1e16e3a575bd63..ba7b9bacd2b794c74409d517a9c05bfbb14a845f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -71,7 +71,7 @@ class BinaryOpsTest(XLATestCase): expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([3, 3, -1.5, -8, 44], dtype=dtype), np.array([2, -2, 7, -4, 0], dtype=dtype), expected=np.array( @@ -108,57 +108,57 @@ class BinaryOpsTest(XLATestCase): [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( - gen_math_ops._reciprocal_grad, + gen_math_ops.reciprocal_grad, np.array([4, -3, -2, 1], dtype=dtype), np.array([5, -6, 7, -8], dtype=dtype), expected=np.array([-80, 54, -28, 8], dtype=dtype)) self._testBinary( - gen_math_ops._sigmoid_grad, + gen_math_ops.sigmoid_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-60, -36, -14, 0], dtype=dtype)) self._testBinary( - gen_math_ops._rsqrt_grad, + gen_math_ops.rsqrt_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-160, -81, -28, -4], dtype=dtype)) self._testBinary( - gen_math_ops._sqrt_grad, + gen_math_ops.sqrt_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([0.625, 1, 1.75, 4], dtype=dtype)) self._testBinary( - gen_nn_ops._softplus_grad, + gen_nn_ops.softplus_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array( [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype)) self._testBinary( - gen_nn_ops._softsign_grad, + gen_nn_ops.softsign_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array( [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype)) self._testBinary( - gen_math_ops._tanh_grad, + gen_math_ops.tanh_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-75, -48, -21, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._elu_grad, + gen_nn_ops.elu_grad, np.array([1, 2, 3, 4, 5, 6], dtype=dtype), np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) self._testBinary( - gen_nn_ops._selu_grad, + gen_nn_ops.selu_grad, np.array([1, 2, 3, 4, 5, 6], dtype=dtype), np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype), expected=np.array( @@ -166,20 +166,20 @@ class BinaryOpsTest(XLATestCase): 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype)) self._testBinary( - gen_nn_ops._relu_grad, + gen_nn_ops.relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype)) self._testBinary( - gen_nn_ops._relu6_grad, + gen_nn_ops.relu6_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype), np.array( [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._softmax_cross_entropy_with_logits, + gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype), expected=[ @@ -191,7 +191,7 @@ class BinaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) self._testBinary( - gen_nn_ops._sparse_softmax_cross_entropy_with_logits, + gen_nn_ops.sparse_softmax_cross_entropy_with_logits, np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=dtype), np.array([2, 1, 7], dtype=np.int32), @@ -207,7 +207,7 @@ class BinaryOpsTest(XLATestCase): def testIntOps(self): for dtype in self.int_types: self._testBinary( - gen_math_ops._truncate_div, + gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) @@ -232,11 +232,16 @@ class BinaryOpsTest(XLATestCase): expected=np.right_shift(lhs, rhs)) if dtype in [np.int8, np.int16, np.int32, np.int64]: - lhs = np.array([-1, -5, -3, -14], dtype=dtype) - rhs = np.array([5, 0, 1, 11], dtype=dtype) - self._testBinary( - bitwise_ops.right_shift, lhs, rhs, - expected=np.right_shift(lhs, rhs)) + lhs = np.array([-1, -5, -3, -14, -2], dtype=dtype) + rhs = np.array([5, 0, 1, 11, 36], dtype=dtype) + # HLO has saturating shift behavior. + bits = np.ceil( + np.log(np.iinfo(dtype).max - np.iinfo(dtype).min) / np.log(2)) + expected = [ + np.right_shift(l, r) if r < bits else np.sign(l) + for l, r in zip(lhs, rhs) + ] + self._testBinary(bitwise_ops.right_shift, lhs, rhs, expected=expected) def testNumericOps(self): for dtype in self.numeric_types: @@ -255,12 +260,18 @@ class BinaryOpsTest(XLATestCase): np.array([[1], [2]], dtype=dtype), dtype(7), expected=np.array([[8], [9]], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), + np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), + expected=np.array( + [1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) self._testBinary( math_ops.subtract, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([-9, -18], dtype=dtype)) + np.array([1, 2, 100], dtype=dtype), + np.array([10, 20, -1], dtype=dtype), + expected=np.array([-9, -18, 101], dtype=dtype)) self._testBinary( math_ops.subtract, dtype(5), @@ -369,7 +380,7 @@ class BinaryOpsTest(XLATestCase): expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), expected=np.array( @@ -378,7 +389,7 @@ class BinaryOpsTest(XLATestCase): # Test inf/nan scenarios. self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( @@ -418,19 +429,19 @@ class BinaryOpsTest(XLATestCase): lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) self._testBinary( - gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) + gen_math_ops.reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) self._testBinary( - gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) + gen_math_ops.sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + gen_math_ops.rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( - gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) + gen_math_ops.sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) self._testBinary( - gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) + gen_math_ops.tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) def testComplexMath(self): for dtype in self.complex_types: @@ -538,7 +549,7 @@ class BinaryOpsTest(XLATestCase): if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( - gen_math_ops._floor_div, + gen_math_ops.floor_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) @@ -554,12 +565,12 @@ class BinaryOpsTest(XLATestCase): def _testRemainder(self, dtype): """Test cases for remainder operators.""" self._testBinary( - gen_math_ops._floor_mod, + gen_math_ops.floor_mod, np.array([3, 3, -1, -8], dtype=dtype), np.array([2, -2, 7, -4], dtype=dtype), expected=np.array([1, -1, 6, 0], dtype=dtype)) self._testBinary( - gen_math_ops._truncate_mod, + gen_math_ops.truncate_mod, np.array([3, 3, -1, -8], dtype=dtype), np.array([2, -2, 7, -4], dtype=dtype), expected=np.array([1, 1, -1, 0], dtype=dtype)) @@ -668,6 +679,11 @@ class BinaryOpsTest(XLATestCase): np.array([[10], [7], [2]], dtype=np.float32), np.float32(7), expected=np.array([[False], [False], [True]], dtype=np.bool)) + self._testBinary( + less_op, + np.array([[10], [7], [2], [-1]], dtype=np.int64), + np.int64(7), + expected=np.array([[False], [False], [True], [True]], dtype=np.bool)) for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: self._testBinary( @@ -686,6 +702,80 @@ class BinaryOpsTest(XLATestCase): np.float32(7), expected=np.array([[False], [True], [True]], dtype=np.bool)) + def testS64Comparisons(self): + for op in [(lambda x, y: x < y), (lambda x, y: x <= y), + (lambda x, y: x >= y), (lambda x, y: x > y)]: + lhs = np.array( + [ + np.int64(0x000000007FFFFFFF), + np.int64(0x000000007FFFFFFF), + np.int64(0x0000000080000000), + np.int64(0x0000000080000000), + np.int64(0x0000000080000001), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFF), + np.int64(0x0000000100000000), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x7ffffffefff00010), + np.int64(0x7ffffffefff00010), + np.int64(-1), + np.int64(-1) + ], + dtype=np.int64) + rhs = np.array( + [ + np.int64(0x000000007FFFFFFE), + np.int64(0x000000007FFFFFFF), + np.int64(0x000000007FFFFFFF), + np.int64(0x0000000080000000), + np.int64(0x0000000080000001), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFF0001), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFF), + np.int64(0x0000000100000001), + np.int64(0x0000000100000002), + np.int64(0x0000000100000003), + np.int64(0x0000000200000001), + np.int64(0x0000000200000002), + np.int64(0x0000000200000003), + np.int64(0x0000000300000001), + np.int64(0x0000000300000002), + np.int64(0x0000000300000003), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000001), + np.int64(-2), + np.int64(-1) + ], + dtype=np.int64) + expected = np.array([op(l, r) for l, r in zip(lhs, rhs)], dtype=np.bool) + self._testBinary(op, lhs, rhs, expected=expected) + def testBroadcasting(self): """Tests broadcasting behavior of an operator.""" @@ -1045,6 +1135,20 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) + def splitvOp(x, y): # pylint: disable=invalid-name + return array_ops.split(value=y, num_or_size_splits=[2, 3], axis=x) + for axis in [1, -1]: + self._testBinary( + splitvOp, + np.int32(axis), + np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + dtype=dtype), + expected=[ + np.array([[0, 1], [5, 6]], dtype=dtype), + np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + def testTile(self): for dtype in self.numeric_types: self._testBinary( diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 81734082d9aab86f8bc763681265ef64ef32bd31..f10973e19f1945515b776cf86349445ed7334629 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -301,7 +301,7 @@ class ConcatOffsetTest(XLATestCase): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) ans = sess.run(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 538fa8e8e570b83ed681ecc0501285520cabdecb..3bc41b7cfd72bec7572097f8c53eef314a4369f6 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -426,7 +426,7 @@ class ResizeBilinearTest(XLATestCase): with self.test_session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) - resized = gen_image_ops._resize_bilinear_grad( + resized = gen_image_ops.resize_bilinear_grad( grads, np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), align_corners=True) diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 2d8236e2cbdfafb35626cd582ee39b1f917aec7f..f9d87c2d1cfe5c1a7487e124c971a54ffcfede15 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.contrib.compiler import jit @@ -436,5 +437,55 @@ class XlaCompilationTest(test.TestCase): self.assertTrue(InLabels(labels, "_XlaLaunch")) +class ElementWiseFusionTest(test.TestCase): + + # Runs a simple test with the input jit_level and fusion_only flag. + def simpleTest(self, arg0, arg1, global_jit_level): + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = global_jit_level + + with session_lib.Session(config=config) as sess: + a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1") + a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2") + # Two element-wise ops. We need at least two ops since single + # element clusters are not passed to XLA in fusion_only mode. + a3 = a1 * a2 + a4 = a3 + a1 + # A matmul to break XLA clustering. + a5 = math_ops.matmul(a4, a1) + # Two more element-wise ops. + a6 = a5 - a4 + a7 = a6 + a2 + + run_metadata = config_pb2.RunMetadata() + output = sess.run( + a7, { + a1: arg0, + a2: arg1 + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = RunMetadataLabels(run_metadata) + count = sum("_XlaLaunch(" in x for x in labels) + + return output, count + + def testElementWiseClustering(self): + arg0 = np.random.rand(2, 2).astype(np.float32) + arg1 = np.random.rand(2, 2).astype(np.float32) + os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true" + tf_op, tf_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.OFF) + self.assertEqual(0, tf_count) + + tfef_op, tfef_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.ON_1) + self.assertEqual(2, tfef_count) + + self.assertAllClose(tf_op, tfef_op, rtol=1e-1) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 5d8d89224d4a778d84803811710bb095872e86b2..69bd8f7230d4394c45764d02a88fb0ec097c5756 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -115,11 +115,11 @@ class LRNTest(XLATestCase): out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) with ops.device(CPU_DEVICE): - expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, - depth_radius, bias, alpha, beta) + expected = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) with self.test_scope(): - actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, - depth_radius, bias, alpha, beta) + actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) expected_val = expected.eval() actual_val = actual.eval() self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index eb48fe555a0b182ea7983cbd8c3b217d56350408..4eed903963a34a253ea5c409782d9a89a97a4fdf 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test # MaxPoolGrad. def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): del outputs # Unused by average-pooling gradients. - return gen_nn_ops._avg_pool3d_grad( + return gen_nn_ops.avg_pool3d_grad( inputs.get_shape().as_list(), output_gradients, ksize=ksize, @@ -263,7 +263,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[1, 3, 3, 3, 1], ksize=[1, 1, 1], strides=[1, 1, 1], @@ -272,7 +272,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_1_6_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 3, 6, 3], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -281,7 +281,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_1_7_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 5, 7, 3], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -290,7 +290,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_2_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 2, 2, 2, 3], ksize=[2, 2, 2], strides=[2, 2, 2], @@ -299,7 +299,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 2, 4, 1], ksize=[1, 1, 1], strides=[1, 1, 1], @@ -308,7 +308,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding2_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 2, 4, 1], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -317,7 +317,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding2_2_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 5, 2, 4, 3], ksize=[2, 2, 2], strides=[2, 2, 2], @@ -326,7 +326,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding3_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[1, 3, 3, 7, 1], ksize=[3, 3, 3], strides=[1, 1, 1], diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 7c19a99c4eb4be3ca34b3ce949216e557b0a681d..fe270af3d636c0824621f36360ce9e7d14d8fc91 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -292,8 +292,15 @@ class PoolGradTest(XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" - def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format): + def _VerifyOneTest(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -304,9 +311,19 @@ class PoolGradTest(XLATestCase): strides: The stride dimensions padding: Padding type. data_format: The data format we use to run the pooling operation. + pool_grad_grad_func: Second-order gradient function, if available. """ total_size = np.prod(input_sizes) - x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) + # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally + # maximal at 16 bits. Switch to np.random.randn when resolved. + x = np.arange(1, total_size + 1, dtype=np.float32) + x *= (np.random.randint(2, size=total_size) * 2 - 1) # Flip signs randomly + # Verify some specifically interesting values... + x[np.random.choice(total_size)] = np.inf + x[np.random.choice(total_size)] = -np.inf + # TODO(b/74222344): Fix nan handling for max pool grad. + # x[np.random.choice(total_size)] = np.nan + x = x.reshape(input_sizes) with self.test_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). @@ -323,6 +340,8 @@ class PoolGradTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -342,18 +361,36 @@ class PoolGradTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) xla_inputs = inputs xla_outputs = outputs xla_output_gradients = output_gradients + xla_output_grad_gradients = output_grad_gradients xla_ksize = ksize xla_strides = strides if data_format == "NCHW": xla_inputs = NHWCToNCHW(inputs) xla_outputs = NHWCToNCHW(outputs) xla_output_gradients = NHWCToNCHW(output_gradients) + xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients) xla_ksize = NHWCToNCHW(ksize) xla_strides = NHWCToNCHW(strides) actual_input_gradients = pool_grad_func( @@ -366,22 +403,54 @@ class PoolGradTest(XLATestCase): data_format=data_format) if data_format == "NCHW": actual_input_gradients = NCHWToNHWC(actual_input_gradients) - actual = sess.run(actual_input_gradients, { + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + xla_inputs, + xla_outputs, + xla_output_grad_gradients, + ksize=xla_ksize, + strides=xla_strides, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + actual_grad_gradients = NCHWToNHWC(actual_grad_gradients) + actual_input_gradients_vals = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, output_gradients: output_gradient_vals }) - # Compare the Tensorflow and XLA results. self.assertAllClose( - expected_input_gradient_vals.flatten(), - actual.flatten(), + expected_input_gradient_vals, + actual_input_gradients_vals, rtol=1e-4, atol=1e-6) - self.assertShapeEqual(actual, inputs) - - def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + self.assertShapeEqual(actual_input_gradients_vals, inputs) + + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + + def _VerifyValues(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling function. Args: @@ -391,12 +460,20 @@ class PoolGradTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ for data_format in GetTestConfigs(): - self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format) - - def _TestPooling(self, forward_op, backward_op): + self._VerifyOneTest( + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=pool_grad_grad_func) + + def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None): # VALID padding self._VerifyValues( forward_op, @@ -404,7 +481,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 3, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding self._VerifyValues( @@ -413,7 +491,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, non square window self._VerifyValues( @@ -422,7 +501,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 2, 1], ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # VALID padding, uneven stride self._VerifyValues( @@ -431,14 +511,16 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) self._VerifyValues( forward_op, backward_op, input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 4 input self._VerifyValues( @@ -447,7 +529,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 4], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 8 input self._VerifyValues( @@ -456,10 +539,14 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 8, 8, 8], ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) def testMaxPool(self): - self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad) + self._TestPooling( + nn_ops.max_pool, + gen_nn_ops.max_pool_grad, + pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad) def testAvgPool(self): # Wrapper around AvgPoolGrad that ignores extra arguments needed by @@ -467,7 +554,7 @@ class PoolGradTest(XLATestCase): def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding, data_format): del outputs # Unused by average-pooling gradients. - return gen_nn_ops._avg_pool_grad( + return gen_nn_ops.avg_pool_grad( inputs.get_shape().as_list(), output_gradients, ksize=ksize, @@ -483,7 +570,7 @@ class PoolGradTest(XLATestCase): def testMaxPoolKernelSmallerThanStrideValid(self): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 7, 7, 1], ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1], @@ -492,7 +579,7 @@ class PoolGradTest(XLATestCase): def testMaxPoolKernelSmallerThanStrideSame(self): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 3, 3, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], @@ -500,7 +587,7 @@ class PoolGradTest(XLATestCase): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 4, 4, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e72dd4eea9f127e1df96ab166103c4c16372adb6..e53efc3091d8935e745122af29abd7b8063b1d01 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) { return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } -constexpr std::array kAllXlaTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}}; +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64, DT_INT64}}; // An OpTestBuilder is a graph builder class that takes as input an operator to // test, its inputs and attributes, and builds a graph that executes the diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 965fdf684b973498d0b3c3cde17711cca7279705..2c084b04fa2f67ad0d86508109522d7bead206eb 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -30,8 +31,13 @@ from tensorflow.python.platform import googletest class ReduceOpsTest(XLATestCase): - def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, - rtol=1e-4, atol=1e-4): + def _testReduction(self, + tf_reduce_fn, + np_reduce_fn, + dtype, + test_inputs, + rtol=1e-4, + atol=1e-4): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: @@ -41,16 +47,16 @@ class ReduceOpsTest(XLATestCase): index = array_ops.placeholder(dtypes.int32) out = tf_reduce_fn(a, index) result = sess.run(out, {a: test_input, index: [0]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=0), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol) result = sess.run(out, {a: test_input, index: [1]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=1), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) result = sess.run(out, {a: test_input, index: [-1]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=1), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Invalid reduction dim'): @@ -60,7 +66,7 @@ class ReduceOpsTest(XLATestCase): errors_impl.InvalidArgumentError, 'Invalid reduction dim'): sess.run(out, {a: test_input, index: [2]}) - FLOAT_DATA = [ + REAL_DATA = [ np.zeros(shape=(2, 0)), np.zeros(shape=(0, 30)), np.arange(1, 7).reshape(2, 3), @@ -74,7 +80,7 @@ class ReduceOpsTest(XLATestCase): np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_REAL_DATA = [x for x in REAL_DATA if np.size(x) > 0] NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), @@ -83,8 +89,7 @@ class ReduceOpsTest(XLATestCase): ] def testReduceSumF32(self): - self._testReduction(math_ops.reduce_sum, np.sum, np.float32, - self.FLOAT_DATA) + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA) def testReduceSumC64(self): self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, @@ -92,7 +97,7 @@ class ReduceOpsTest(XLATestCase): def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, - self.FLOAT_DATA) + self.REAL_DATA) def testReduceProdC64(self): self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, @@ -100,31 +105,44 @@ class ReduceOpsTest(XLATestCase): def testReduceMin(self): - def reference_min(inp, axis): + def reference_min(dtype, inp, axis): """Wrapper around np.amin that returns +infinity for an empty input.""" if inp.shape[axis] == 0: - return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + if np.issubdtype(dtype, np.floating): + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + np.iinfo(dtype).max) return np.amin(inp, axis) - self._testReduction(math_ops.reduce_min, reference_min, np.float32, - self.FLOAT_DATA) + for dtype in set(self.all_types).intersection( + [np.float32, np.int32, np.int64]): + self._testReduction(math_ops.reduce_min, + functools.partial(reference_min, dtype), dtype, + self.REAL_DATA) def testReduceMax(self): - def reference_max(inp, axis): + def reference_max(dtype, inp, axis): """Wrapper around np.amax that returns -infinity for an empty input.""" if inp.shape[axis] == 0: - return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf')) + if np.issubdtype(dtype, np.floating): + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + float('-inf')) + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + np.iinfo(dtype).min) return np.amax(inp, axis) - self._testReduction(math_ops.reduce_max, reference_max, np.float32, - self.FLOAT_DATA) + for dtype in set(self.all_types).intersection( + [np.float32, np.int32, np.int64]): + self._testReduction(math_ops.reduce_max, + functools.partial(reference_max, dtype), dtype, + self.REAL_DATA) def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, - self.NONEMPTY_FLOAT_DATA) + self.NONEMPTY_REAL_DATA) def testReduceMeanC64(self): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index c013f4b50a4cf95be8028248c52b10b1c3be2bd3..92518aadc4bf5c601cfb4192c093799784b6aa72 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -75,11 +75,11 @@ class SpaceToBatchTest(XLATestCase): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) - x_tf = gen_array_ops._space_to_batch( + x_tf = gen_array_ops.space_to_batch( placeholder, paddings, block_size=block_size) self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) # inputs = batch_to_space(outputs) - x_tf = gen_array_ops._batch_to_space( + x_tf = gen_array_ops.batch_to_space( placeholder, paddings, block_size=block_size) self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 2b9c2279737ccee531d488d27ccdb0cafa1dc8fc..94342f9567ca71274609e63b0482d55637c98d51 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -34,33 +34,33 @@ class StackOpTest(XLATestCase): with self.test_session(), self.test_scope(): size = array_ops.placeholder(dtypes.int32) v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, v) + h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): with self.test_session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True) + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, x, swap_memory=True) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose(a, c1.eval({x: a})) def testMultiStack(self): with self.test_session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, v) + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_push_v2(h1, v) with ops.control_dependencies([c1]): - c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar") - c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + c1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="bar") + c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) with ops.control_dependencies([c2]): - c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + c2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) r = c1 + c2 self.assertAllClose(9.0, r.eval({v: 4.0})) @@ -69,15 +69,15 @@ class StackOpTest(XLATestCase): with self.test_session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, v1) + c1 = gen_data_flow_ops.stack_push_v2(h1, v1) with ops.control_dependencies([c1]): - c2 = gen_data_flow_ops._stack_push_v2(h2, v2) + c2 = gen_data_flow_ops.stack_push_v2(h2, v2) with ops.control_dependencies([c2]): - pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0}) self.assertAllClose(out1, 4.0) @@ -86,17 +86,17 @@ class StackOpTest(XLATestCase): def testCloseStack(self): with self.test_session() as sess, self.test_scope(): size = array_ops.placeholder(dtypes.int32) - h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_close_v2(h) + h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {size: 5}) def testPushCloseStack(self): with self.test_session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, v) + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_close_v2(h) + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {v: [[4.0, 5.0]]}) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index a62925a1818da00cb0a9e82e1281db20fb38b208..7624d6e4b2e2ece6a61155743fc8b866f6903f32 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -338,7 +338,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0 = ta.write(0, [[4.0, 5.0]]) # Test reading wrong datatype. - r0_bad = gen_data_flow_ops._tensor_array_read_v3( + r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): r0_bad.eval() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 7e1f5c76ed65946363cc3c113ab1a9862f87b289..e924fe1e61454aefda622a5a46a0e483d26db5c1 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import contextlib +import os import random import re @@ -44,6 +45,8 @@ flags.DEFINE_string('test_device', None, flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') flags.DEFINE_string('disabled_manifest', None, 'Path to a file with a list of tests that should not run.') +flags.DEFINE_string('tf_xla_flags', None, + 'Value to set the TF_XLA_FLAGS environment variable to') class XLATestCase(test.TestCase): @@ -71,14 +74,14 @@ class XLATestCase(test.TestCase): self._all_types = set( [dtype.as_numpy_dtype for dtype in self._all_tf_types]) - self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ dtype.as_numpy_dtype for dtype in self.complex_tf_types ]) - self._numeric_types = set( - self.int_types | self._float_types | self.complex_types) + self._numeric_types = set(self._int_types | self._float_types + | self.complex_types) # Parse the manifest file, if any, into a regex identifying tests to # disable @@ -97,6 +100,8 @@ class XLATestCase(test.TestCase): disabled_tests = [] disabled_method_types = [] for l in manifest_file.read().splitlines(): + if not l: + continue entry = comments_re.sub('', l).strip().split(' ') if len(entry) == 1: disabled_tests.append(entry[0]) @@ -113,6 +118,9 @@ class XLATestCase(test.TestCase): for name in types]) manifest_file.close() + if FLAGS.tf_xla_flags is not None: + os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags + @property def all_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) @@ -130,6 +138,11 @@ class XLATestCase(test.TestCase): name = '{}.{}'.format(type(self).__name__, self._testMethodName) return self._float_tf_types - self._method_types_filter.get(name, set()) + @property + def int_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._int_types - self._method_types_filter.get(name, set()) + @property def numeric_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index fb82c2601c432cee425a46a3b6dc2c55febeda87..eb20ca501c80b01c76198e1ad54173f1c601714d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -58,6 +58,15 @@ xla_proto_library( ], ) +xla_proto_library( + name = "host_compute_metadata_proto", + srcs = ["host_compute_metadata.proto"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "tf2xla", srcs = ["tf2xla.cc"], @@ -149,6 +158,7 @@ cc_library( ":common", ":dump_graph", ":functionalize_control_flow", + ":host_compute_metadata_proto", ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 82923722c54d235716b9138d95a75a441df924ca..de1008803d69fefa415c7bdbe6c27a62e625b417 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -37,7 +37,7 @@ Status BackwardsConstAnalysis(const Graph& g, }; Status status; - std::unordered_set must_be_const; + std::unordered_set must_be_const; auto visit = [&status, &metadata_ops, &must_be_const, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -55,8 +55,10 @@ Status BackwardsConstAnalysis(const Graph& g, compile_time_const_args->at(index) = true; return; } - for (Node* pred : node->in_nodes()) { - must_be_const.insert(pred); + for (const Edge* pred : node->in_edges()) { + if (!pred->IsControlEdge()) { + must_be_const.insert(pred->src()); + } } return; } diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 9d125f8d499863cfaa0e26b5b633ca02914d1b7d..992b12c06db5efc0ae54284d0ea77017c1c79aca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -79,5 +79,24 @@ TEST(ConstAnalysisTest, TopologicalOrder) { } } +TEST(ConstAnalysisTest, DontFollowControlDependencies) { + Scope root = Scope::NewRootScope(); + + Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + Output c1 = + ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1}); + Output add = ops::Add(root, arg1, c1); + Output reshape = ops::Reshape(root, arg1, add); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(2, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + + EXPECT_EQ(const_args, std::vector({false, true})); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md index 91351421bcacd26c41b5c9f98ea833730e4aef30..20179b67991d3d23d678cf1df2642e029ea037fd 100644 --- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -186,6 +198,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -198,6 +211,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md index b9bdb829d773825005a8921f48d28b6892d8f0cd..55f0538dba7c1941dfea88e0631cd299e51f76d0 100644 --- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -183,6 +195,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -195,6 +208,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 058a1f2621c64a735bd9d9c9d0ae007f93aa4dea..b20c1ffc7d8956f3f5530ee63e9b711a26439be5 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -130,7 +130,7 @@ Status GraphCompiler::Compile() { // Set up inputs from outputs of previous nodes. for (auto* e : n->in_edges()) { if (e->IsControlEdge()) continue; - Node* src = e->src(); + const Node* src = e->src(); TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto b/tensorflow/compiler/tf2xla/host_compute_metadata.proto new file mode 100644 index 0000000000000000000000000000000000000000..43ab371a217e6c4521a160715104c96e3c8782c6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/host_compute_metadata.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +package tensorflow.tf2xla; +option cc_enable_arenas = true; +option java_outer_classname = "Tf2XlaProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.tf2xla"; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// TensorMetadata indicates the type and shape of a Tensor that is +// part of a host compute transfer. +message TensorMetadata { + DataType type = 1; + TensorShapeProto shape = 2; +} + +// HostTransferMetadata describes a transfer either from host to device +// or device to host. It has a key that is unique to the computation, +// and metadata about the list of tensors being transferred. +message HostTransferMetadata { + // The key used to identify this transfer. + string key = 1; + + // For each Tensor being transferred, its type and shape. + repeated TensorMetadata metadata = 2; +} + +// HostComputeMetadata describes all the sends and recvs +// from all host compute transfer ops in a computation. +message HostComputeMetadata { + // Metadata about each device_to_host transfer + repeated HostTransferMetadata device_to_host = 1; + + // Metadata about each host_to_device transfer + repeated HostTransferMetadata host_to_device = 2; +} diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d2fa933cf9c085f92b2f442827a94d72938e4bb2..0bbfe86de389ff6063b1f9604003f35b41d28e3b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -93,6 +93,7 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -154,6 +155,22 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "if_op", + srcs = ["if_op.cc"], + hdrs = ["if_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. tf_kernel_library( diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index a249b1869f547f8e5aa725f9f5cf391b10429928..931175be1111ed5f70afbdf351ee53c59c1367de 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -118,30 +118,24 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); - - auto grad_backprop = ctx->Input(0); - auto activations = ctx->Input(1); - auto scale = ctx->Input(2); - auto mean = ctx->Input(3); - auto var = ctx->Input(4); - - TensorShape input_shape = ctx->InputShape(0); - int feature_index = - GetTensorFeatureDimIndex(input_shape.dims(), data_format_); - + xla::ComputationBuilder* const b = ctx->builder(); DataType input_dtype = ctx->input_type(0); DataType scale_dtype = ctx->input_type(2); - xla::PrimitiveType input_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); - xla::PrimitiveType scale_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). - grad_backprop = b->ConvertElementType(grad_backprop, scale_type); - activations = b->ConvertElementType(activations, scale_type); + auto grad_backprop = + XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + auto activations = + XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + auto scale = ctx->Input(2); + auto mean = ctx->Input(3); + auto var = ctx->Input(4); + + const int input_dims = ctx->InputShape(0).dims(); + const int feature_index = + GetTensorFeatureDimIndex(input_dims, data_format_); xla::ComputationDataHandle x_backprop; xla::ComputationDataHandle scale_backprop; @@ -156,7 +150,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = b->GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. - std::vector reduction_dims(input_shape.dims() - 1); + std::vector reduction_dims(input_dims - 1); std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, 0); std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), @@ -165,9 +159,14 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + // epsilon)) // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) - offset_backprop = - b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), - *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(scale_dtype); + auto converted = + XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -175,17 +174,21 @@ class FusedBatchNormGradOp : public XlaOpKernel { b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) - auto scratch2 = b->Reduce( - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), - XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), - reduction_dims); + auto mul = + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(0, + XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index cbade79e85eed10ecb5ead7151ee778c86a0de37..569950c2dfaeb61028049a263a962dfa54a62e09 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -184,9 +184,7 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace") - .CompileTimeConstInput("crops") - .CompileTimeConstInput("block_shape"), +REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"), BatchToSpaceOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index c667b4e3e326b776faba49387760abbd582fcc68..ed33b8ed2e823f313a9a7fe220390bc617288405 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -103,10 +103,15 @@ class BiasAddGradOp : public XlaOpKernel { std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), feature_dim + 1); - xla::ComputationDataHandle result = ctx->builder()->Reduce( - ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), reduce_dims); - ctx->SetOutput(0, result); + xla::ComputationBuilder* const b = ctx->builder(); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 81cea6d376d02c956a5257c5475fe5c10b83deb9..c0ee0c9c2ea849a692bee70bba36d32335eed9b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -58,7 +58,7 @@ xla::ComputationDataHandle CreateExpandedZero( // Create a mask for depthwise convolution that will make a normal convolution // produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// depthwise filter this returns a [2, 2, 3, 6] tensor // 1 1 0 0 0 0 1 1 0 0 0 0 // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 @@ -166,6 +166,10 @@ xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); return builder->Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with + // ExpandedZero guarantees that only one element is non zero, so there + // cannot be accumulated precision error. builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), *ctx->GetOrCreateAdd(dtype), {expanded_filter_shape.dims() - 2}), diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 453a32c494b42e9922bc35fc526f3306530054fd..99470d70e709ddb5593c5eaae061bb897befc168 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -247,6 +247,8 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { const TensorShape gradient_shape = ctx->InputShape(0); xla::ComputationDataHandle input = ctx->Input(1); const DataType data_type = ctx->input_type(1); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(data_type); xla::ComputationDataHandle input_min = ctx->Input(2); xla::ComputationDataHandle input_max = ctx->Input(3); @@ -265,15 +267,23 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { ctx->SetOutput(0, output0); xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); + xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes); + xla::ComputationDataHandle reduce1 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); xla::ComputationDataHandle output1 = - b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); + xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes); + xla::ComputationDataHandle reduce2 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); xla::ComputationDataHandle output2 = - b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + XlaHelpers::ConvertElementType(b, reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eefbe55c815d80a608bdf62d454a69d722adb158 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -0,0 +1,226 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/if_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr)); + then_branch_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr)); + else_branch_ = *name_attr; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); +} + +// TODO(b/35949885): There is duplication here with the handling of the +// while_op. Refactor the common code out/rework. +void XlaIfOp::Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + + OP_REQUIRES(ctx, cond_type_ == DT_BOOL, + errors::InvalidArgument( + "Condition argument must be a boolean for XLA compilation")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), + errors::InvalidArgument( + "Condition argument must be a scalar for XLA compilation")); + + VLOG(1) << "Building If: " << input_types_.size() << " inputs"; + + std::vector inputs(input_types_.size()); + std::vector arguments(input_types_.size()); + for (int i = 0; i < input_types_.size(); ++i) { + XlaCompiler::Argument& arg = arguments[i]; + DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + + arg.type = resource->type(); + arg.shape = resource->shape(); + OP_REQUIRES(ctx, arg.initialized, + errors::Unimplemented("Uninitialized arguments: ", arg.name)); + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + arg.name = resource->name(); + VLOG(2) << "Resource " << resource->name() + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input_types_[i]; + arg.shape = ctx->InputShape(i + 1); + inputs[i] = ctx->Input(i + 1); + VLOG(2) << "Arg type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString(); + } + } + + // Compile both branches of the conditional. + XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.resolve_compile_time_constants = false; + options.return_updated_values_for_all_resources = true; + options.is_entry_computation = false; + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult then_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + XlaCompiler::CompilationResult else_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + } + } + + // Check that both branches have identical input shapes. + OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape then_input_shape = then_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape else_input_shape = else_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), + errors::InvalidArgument( + "Input shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_input_shape), " vs. ", + xla::ShapeUtil::HumanString(else_input_shape))); + + // Check that both branches have identical output shapes. + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(then_result.xla_output_shape, + else_result.xla_output_shape), + errors::InvalidArgument( + "Output shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", + xla::ShapeUtil::HumanString(else_result.xla_output_shape))); + + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString(then_result.xla_output_shape); + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. + OP_REQUIRES(ctx, + then_result.resource_updates.size() == + else_result.resource_updates.size(), + errors::FailedPrecondition( + "Different number of resources in then and else branch")); + for (int i = 0; i < then_result.resource_updates.size(); ++i) { + const auto& lhs = then_result.resource_updates[i]; + const auto& rhs = else_result.resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + OP_REQUIRES( + ctx, equal, + errors::FailedPrecondition( + "Mismatch in resource of then and else branch for resource ", i)); + } + + xla::ComputationDataHandle outputs = + b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, + b->Tuple(inputs), *else_result.computation); + // Sets non-variable outputs. + for (int i = 0; i < output_types_.size(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; + } + } + ctx->SetOutput(i, output_handle); + } + } + + // Updates the values of any resource variables modified by the conditional + // bodies. + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (int i = 0; i < result->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = result->resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + if (update.modified) { + int pos = result->outputs.size() + i; + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + b->GetTupleElement(outputs, pos), b)); + } + VLOG(2) << "If variable: pos: " << update.input_index + << " name: " << resource->name() + << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + } + } + VLOG(1) << "Done building If"; +} + +REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f9bc98a198a72dcc0594e61971713bf890ce30b6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional conditional primitive. +// +// The outputs of the then/else branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in then/else bodies may read from and write to resource +// variables. +// Resource variables may be passed as arguments to the then/else function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the then/else bodies output. This ensures the then/else bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaIfOp : public XlaOpKernel { + public: + explicit XlaIfOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp); + + NameAttrList then_branch_; + NameAttrList else_branch_; + DataType cond_type_; + DataTypeVector input_types_; + DataTypeVector output_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index f22f384256a8ddd8c05de4a1322aba741dc4d7fd..5eeda79a935e8194a596d322b52add27846d378c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -180,9 +180,13 @@ class AdjustContrastOpV2 : public XlaOpKernel { DataType type = context->input_type(0); - auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), - /*computation=*/*context->GetOrCreateAdd(type), + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); + auto output = XlaHelpers::ConvertElementType(b, reduce, type); output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index d096415087e47a73503a06526ab133ac34803c5d..c177f08d9c4687bb13b98a4328bb3960519799c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -29,21 +29,22 @@ class L2LossOp : public XlaOpKernel { explicit L2LossOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); + std::vector dims(ctx->InputShape(0).dims()); + std::iota(dims.begin(), dims.end(), 0); DataType dtype = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - - auto zero = XlaHelpers::Zero(b, dtype); - auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - const xla::Computation& add = *ctx->GetOrCreateAdd(dtype); - - std::vector dims(input_shape.dims()); - std::iota(dims.begin(), dims.end(), 0); + xla::ComputationBuilder* const b = ctx->builder(); // output = sum(t ** 2) / 2 - auto x = ctx->Input(0); - ctx->SetOutput(0, b->Div(b->Reduce(b->Mul(x, x), zero, add, dims), two)); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto t = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto square = b->Mul(t, t); + auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); + ctx->SetOutput(0, b->Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 759d1a1a2d996d4f5deb1774be7014bb6de30f40..1cfee3070f384af0a7441a9c860c530dd1b42187 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -47,12 +47,17 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. - auto squared = builder->Mul(input, input); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto scale = builder->Pow( builder->Add(builder->ConstantR0(bias_), @@ -130,12 +135,17 @@ class LRNGradOp : public XlaOpKernel { // dyi *= out_grads[j] // grads[k] += dyi - auto squared = builder->Mul(in_image, in_image); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = builder->Add(builder->ConstantR0(bias_), @@ -146,11 +156,15 @@ class LRNGradOp : public XlaOpKernel { builder->Div(out_image, norm)), in_grads); - auto dy_reduced = builder->ReduceWindow( - dy, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto converted_dy = + XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto dy_reduce = builder->ReduceWindow( + converted_dy, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto dy_reduced = + XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); xla::ComputationDataHandle gradients = builder->Add( builder->Mul(in_image, dy_reduced), diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4fb5dd4e06c7c70591262c0d63a91c383a2a6e0..5f635dd1bc6122cfcac8163baafd95b13f157715 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -35,8 +35,11 @@ namespace { // Superclass of pooling ops. class PoolingOp : public XlaOpKernel { public: - PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) - : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims, + const DataType reduction_type) + : XlaOpKernel(ctx), + num_spatial_dims_(num_spatial_dims), + reduction_type_(reduction_type) { if (ctx->num_inputs() == 1) { std::vector ksize_int; std::vector stride_int; @@ -63,12 +66,10 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } // Method that builds an initial value to use in reductions. - virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) = 0; + virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0; // The reduction operation to apply to each window. - virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) = 0; + virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0; // A post-processing operation to apply on the outputs of the ReduceWindow. virtual xla::ComputationDataHandle PostProcessOutput( @@ -76,9 +77,6 @@ class PoolingOp : public XlaOpKernel { DataType dtype, const TensorShape& input_shape) = 0; void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - std::vector ksize = ksize_; std::vector stride = stride_; if (ctx->num_inputs() != 1) { @@ -106,16 +104,20 @@ class PoolingOp : public XlaOpKernel { stride.clear(); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); } + const TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), " dimensions")); - const DataType type = input_type(0); - xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( - input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, - stride, padding_); - ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); + xla::ComputationBuilder* const b = ctx->builder(); + auto input = + XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); + auto reduce = ctx->builder()->ReduceWindow( + input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + ctx->SetOutput(0, + PostProcessOutput(ctx, pooled, input_type(0), input_shape)); } protected: @@ -124,21 +126,21 @@ class PoolingOp : public XlaOpKernel { std::vector stride_; xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; + DataType reduction_type_; }; class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ctx->input_type(0)) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::MinValue(b, data_type); + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + return XlaHelpers::MinValue(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateMax(dtype); + const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateMax(reduction_type_); } xla::ComputationDataHandle PostProcessOutput( @@ -209,15 +211,17 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( } // Build a matrix of all 1s, with the same width/height as the input. + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto ones = ctx->builder()->Broadcast( - XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); + XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto counts = ctx->builder()->ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), dtype), - *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, + auto reduce = ctx->builder()->ReduceWindow( + ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); + auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); return ctx->builder()->Div(output, counts, window_dims); } @@ -226,16 +230,16 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::Zero(b, data_type); + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + return XlaHelpers::Zero(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateAdd(dtype); + const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateAdd(reduction_type_); } xla::ComputationDataHandle PostProcessOutput( @@ -455,14 +459,12 @@ class AvgPoolGradOp : public XlaOpKernel { gradients_shape, filter_shape, out_backprop_shape, stride_, padding_, data_format_, &dims)); + // The input gradients are computed by a convolution of the output gradients + // and the filter, with some appropriate padding. See the comment at the top + // of conv_grad_ops.h for details. + xla::ComputationBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - - // The input gradients are computed by a convolution of the output - // gradients - // and the filter, with some appropriate padding. See the comment at - // the top of conv_grad_ops.h for details. - DataType dtype = input_type(1); - + auto dtype = input_type(1); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; @@ -483,17 +485,18 @@ class AvgPoolGradOp : public XlaOpKernel { padding->set_interior_padding(dims.spatial_dims[i].stride - 1); } - auto zero = XlaHelpers::Zero(ctx->builder(), dtype); - auto padded_gradients = - ctx->builder()->Pad(out_backprop_div, zero, padding_config); + auto zero = XlaHelpers::Zero(b, dtype); + auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); - xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( - padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, + auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto in_backprop = b->ReduceWindow( + XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), ksize_, /* window_strides=*/ones, xla::Padding::kValid); - - ctx->SetOutput(0, in_backprop); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); } protected: @@ -525,5 +528,172 @@ class AvgPool3DGradOp : public AvgPoolGradOp { REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), AvgPool3DGradOp); +class MaxPoolGradGradOp : public XlaOpKernel { + public: + MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + + OP_REQUIRES(ctx, ksize_.size() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES(ctx, stride_.size() == num_dims(), + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims(), " dimensions")); + + const TensorShape tensor_in_shape = ctx->InputShape(0); + const TensorShape tensor_out_shape = ctx->InputShape(1); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // For maxpooling, tensor_in should have num_dims() dimensions. + OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_in must be ", num_dims(), + "-dimensional")); + OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_out must be ", num_dims(), + "-dimensional")); + // For maxpooling, out_backprop should have num_dims() dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), + errors::InvalidArgument("out_backprop must be ", num_dims(), + "-dimensional")); + + // What we want to compute: + // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad) + // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad. + // + // In the regular TF op, this amounts to selecting for each window the + // incoming backprop value from xs_grad_grad that corresponds to the maximal + // value in the corresponding window of x. + // + // TODO(b/73062247): What we really want is a ReduceWindow with different + // arrays for index selection vs return value selection--a select-to-gather. + // + // Here, we implement a bitwise hack: we use the hi 16 bits of input for + // separate max pooling alongside each of the hi and lo 16 bits of + // out_backprop packed into 16 lo bits, which we then glue back together at + // the end to get a full 32 bits of gradient. + // + // This could select the wrong backprop value for two x values that are + // equally maximal up to the first 16 bits, in which case we are taking the + // latter. + // + // Note that in principle we could use 32 separate maxpools to recover each + // of 32 bits of the gradient while preserving 31 bits of input for the max + // pooling criteria; here, we just truncate to the first 16 bits of input. + + auto input = ctx->Input(0); + auto out_backprop = ctx->Input(2); + + auto b = ctx->builder(); + + auto sixteen = b->ConstantR0(16); + // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 + auto in_hi = b->BitcastConvertType( + b->ConvertElementType(b->ConvertElementType(input, xla::BF16), + xla::F32), + xla::U32); + auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); + auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. + + auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + // We will reduce by taking the maximal value up to 16 bits (ignoring the lo + // 16 bits of packed-in hi/lo backprop value). + auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); + { + // F32 parameters to satisfy lowering type restriction for reduce opcode. + const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); + auto lhs = rb->Parameter(0, scalar, "lhs"); + auto rhs = rb->Parameter(1, scalar, "rhs"); + auto sixteen = rb->ConstantR0(16); + auto lhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); + // Must use a F32 comparison, because S32 would not work for negatives. + rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), + rb->BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); + } + auto reduce = rb->BuildAndNoteError(); + xla::Padding xla_padding = + (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + auto pooled_hi = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto pooled_lo = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto grads_hi = + b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = b->ShiftRightLogical( + b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + sixteen); + auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + + xla::PrimitiveType element_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); + ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + } + + protected: + const int num_spatial_dims_; + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_ = FORMAT_NHWC; +}; + +class MaxPool2DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool2DGradGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradGradV2") + .TypeConstraint("T", DT_FLOAT) + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradGradOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 03b13b2924f4b81c1017804c91d5ffb81c44ea0b..812d258cd1677e18ef49952044126c76a2f55b19 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -27,7 +27,13 @@ namespace { class SumOp : public XlaReductionOp { public: - explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit SumOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { @@ -39,11 +45,13 @@ REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: - explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit ProdOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - return XlaHelpers::One(builder, input_type(0)); + return XlaHelpers::One(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -58,13 +66,12 @@ REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), class MinOp : public XlaReductionOp { public: - explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MinOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MaxValue(type)); + return XlaHelpers::MaxValue(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -78,13 +85,12 @@ REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: - explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MaxOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MinValue(type)); + return XlaHelpers::MinValue(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -98,8 +104,14 @@ REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: - explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MeanOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { @@ -121,7 +133,8 @@ REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), class AllOp : public XlaReductionOp { public: - explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AllOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { @@ -139,7 +152,8 @@ REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: - explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AnyOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 9aca6d8fedf92f176b3b7b40c5961d4a2e557a8a..f3181f0dadc2d3f45abb145e009e2663c10490f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -33,12 +33,12 @@ namespace tensorflow { // xla::ComputationBuilder. class XlaReductionOp : public XlaOpKernel { public: - explicit XlaReductionOp(OpKernelConstruction* ctx); + XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); ~XlaReductionOp() override {} - // Return the base case for the reduction. Defaults to zero. + // Return the base case for the reduction. virtual xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder); + xla::ComputationBuilder* builder) = 0; // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired @@ -63,6 +63,9 @@ class XlaReductionOp : public XlaOpKernel { private: // True if the number of dimensions should be maintained. bool keep_dims_; + + protected: + DataType reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4b5d09eb9fd4110cdc4221099ff55767e9132540..64fe765ae9a945c58ea60bc157b1520c83b0d8e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -24,19 +24,15 @@ limitations under the License. namespace tensorflow { -XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { +XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, + DataType reduction_type) + : XlaOpKernel(ctx), reduction_type_(reduction_type) { const DataType dt = BaseType(input_type(0)); OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); } -// Return the base case for the reduction. Defaults to zero. -xla::ComputationDataHandle XlaReductionOp::InitialValue( - xla::ComputationBuilder* builder) { - return XlaHelpers::Zero(builder, input_type(0)); -} - // Unless BuildFinalizer is overridden the reduction has no // finalizer. xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( @@ -100,36 +96,26 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { string desc = ctx->op_kernel().name(); - // Call virtual method to get the initial value. - const xla::ComputationDataHandle initial = InitialValue(ctx->builder()); + xla::ComputationBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::ComputationBuilder r(ctx->builder()->client(), - strings::StrCat(desc, "-reduction")); + xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction")); xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - // Make two scalar parameters of the desired type for the lambda. - xla::ComputationDataHandle rx = - r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - xla::ComputationDataHandle ry = - r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); - - auto data = ctx->Input(0); + TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); + auto data = b->ConvertElementType(ctx->Input(0), type); + // Call virtual method to get the initial value. + auto initial = b->ConvertElementType(InitialValue(b), type); + // Make two scalar parameters of the desired type for the lambda. + auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); - xla::ComputationDataHandle reduce = - ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); - xla::ComputationDataHandle finalized = - BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); - - xla::ComputationDataHandle result; - if (keep_dims_) { - result = ctx->builder()->Reshape(finalized, final_shape); - } else { - result = finalized; - } + auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ee4a94164c4a43828eb4feedbfa9d1a9e231ef8f..4cfa28a0ce3d7d1f24196ef6ef2775f840b2bcf1 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -66,7 +66,7 @@ class ScanOp : public XlaOpKernel { -input_shape.dims(), ", ", input_shape.dims(), "), but got ", axis)); - DataType dtype = ctx->input_type(0); + DataType dtype = XlaHelpers::SumAccumulationType(ctx->input_type(0)); if (input_shape.num_elements() == 0) { // Exit early if there is nothing to compute. @@ -91,7 +91,6 @@ class ScanOp : public XlaOpKernel { std::swap(padding[axis].first, padding[axis].second); } - xla::ComputationDataHandle input = ctx->Input(0); xla::ComputationDataHandle init; const xla::Computation* reducer; if (sum_) { @@ -102,7 +101,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = builder->ReduceWindowWithGeneralPadding( - ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, + *reducer, window_dims, window_strides, padding); + output = + XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 80d6df6c48b0141734dcee1c2a3c413926931feb..498342a98881df0c6ff50007eacc1d5ef6196b57 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -83,7 +83,9 @@ class UnsortedSegmentSum : public XlaOpKernel { DataType dtype_; }; -REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum); +REGISTER_XLA_OP( + Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), + UnsortedSegmentSum); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 750a4c2dec8154f97f307978b3d8884271292279..aa47cb799f1f3d01f6fcb01ff9f2e410f7f0ac5a 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -42,9 +42,8 @@ class SoftmaxOp : public XlaOpKernel { const DataType type = input_type(0); auto logits = ctx->Input(0); - xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationBuilder* const b = ctx->builder(); const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = @@ -52,21 +51,20 @@ class SoftmaxOp : public XlaOpKernel { // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - xla::ComputationDataHandle softmax; - if (log_) { - // softmax = shifted_logits - log(sum(exp(shifted_logits))) - auto log_sum_exp = - b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type), - add_func, {kClassDim})); - softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - } else { - // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - auto exp_shifted = b->Exp(shifted_logits); - auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func, - {kClassDim}); - softmax = b->Div(exp_shifted, sum_exp, {kBatchDim}); - } - + auto exp_shifted = b->Exp(shifted_logits); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto softmax = + log_ + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + : b->Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -82,7 +80,6 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, const xla::ComputationDataHandle& logits, const xla::ComputationDataHandle& labels) { const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); const int kBatchDim = 0; const int kClassDim = 1; @@ -100,8 +97,12 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, auto exp_shifted_logits = b->Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) - auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), - add_func, {kClassDim}); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = b->Log(sum_exp); @@ -110,9 +111,13 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - xla::ComputationDataHandle loss = b->Reduce( - b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), - XlaHelpers::Zero(b, type), add_func, {kClassDim}); + auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = b->Mul(b->Neg(labels), sub); + auto sum = + b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 79c435c90a1f57250be90c2c2523bf3d7d231461..43c15e753805352875034dfd2c70a2a1ed9a4114 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -111,27 +111,24 @@ class SplitVOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); + const TensorShape input_shape = ctx->InputShape(0); const TensorShape index_shape = ctx->InputShape(2); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index)); - int32 split_dim; - OP_REQUIRES(ctx, index_shape.dims() == 0, - errors::InvalidArgument("split_dim input to Split Op must be a " - "scalar")); - split_dim = literal_index.Get({}); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig)); + int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() + : split_dim_orig; + OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("-input rank(-", input_shape.dims(), + ") <= split_dim < input rank (", + input_shape.dims(), "), but got ", + split_dim_orig)); xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() > 0, errors::InvalidArgument("Can't split a 0 dimensional input")); - OP_REQUIRES( - ctx, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); - OP_REQUIRES( ctx, num_split > 0, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index b10880de77e6b9811008076cd4a959c284e558d1..5bb773d97fc5ce90dabceeefd5c29d916597f5ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -239,6 +239,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomUniform") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformOp); @@ -272,6 +273,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomNormal") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 5ec05c4121e059ad2b1307376766a41916fe61ae..86263d847ae02d50e70dafb0129b2664c522f2a3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -600,6 +600,48 @@ Status XlaCompiler::BuildArguments( return Status::OK(); } +Status XlaCompiler::CompileSingleOp( + const XlaCompiler::CompileOptions& options, string const& name, + OpKernelContext* ctx, const std::vector& args, + CompilationResult* result) { + // TODO(b/74182462): We implement this by creating a new dummy Graph including + // _Arg nodes, and let CompileGraph walk it. This could be optimized. + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Status status; + // First create the actual node we care about computing. + Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + TF_RETURN_IF_ERROR(status); + + // Create dummy _Arg nodes. Link these to `node` and also via a control + // dependency edge to the _SOURCE node. + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + Status status = NodeBuilder(name, "_Arg") + .ControlInput(graph->source_node()) + .Attr("T", ctx->input_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(node, 0, main_node, i); + } + + // Similarly with return values, create dummy _Retval nodes fed by `node`. + for (int64 i = 0; i < ctx->num_outputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + Status status = NodeBuilder(name, "_Retval") + .Input(main_node, i) + .Attr("T", ctx->expected_output_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + } + + return CompileGraph(options, name, std::move(graph), args, result); +} + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -674,6 +716,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); + // Copy the host transfer metadata to the result. + for (const auto& send : host_compute_sends_) { + *result->host_compute_metadata.add_device_to_host() = send.second; + } + for (const auto& recv : host_compute_recvs_) { + *result->host_compute_metadata.add_host_to_device() = recv.second; + } + // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); @@ -708,4 +758,59 @@ Status XlaCompiler::GetChannelHandle(const string& key, return Status::OK(); } +namespace { + +void SetTransfer(const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes, + tf2xla::HostTransferMetadata* transfer) { + transfer->set_key(key); + CHECK(types.size() == shapes.size()); + for (int i = 0; i < types.size(); ++i) { + tf2xla::TensorMetadata* metadata = transfer->add_metadata(); + metadata->set_type(types[i]); + shapes[i].AsProto(metadata->mutable_shape()); + } +} + +} // namespace + +Status XlaCompiler::SetDeviceToHostMetadata( + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { + if (host_compute_sends_.find(key) != host_compute_sends_.end()) { + return errors::InvalidArgument( + "Duplicate calls to SetDeviceToHostMetadata with key ", key); + } + tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; + SetTransfer(key, types, shapes, &transfer); + return Status::OK(); +} + +Status XlaCompiler::GetDeviceToHostShapes( + const string& key, std::vector* shapes) const { + const auto iter = host_compute_sends_.find(key); + if (iter == host_compute_sends_.end()) { + return errors::InvalidArgument( + "No host compute send shapes registered for key ", key); + } + shapes->clear(); + for (int i = 0; i < iter->second.metadata_size(); ++i) { + TensorShape shape(iter->second.metadata(i).shape()); + shapes->push_back(shape); + } + return Status::OK(); +} + +Status XlaCompiler::SetHostToDeviceMetadata( + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { + if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { + return errors::InvalidArgument( + "Duplicate calls to SetHostToDeviceMetadata with key ", key); + } + tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; + SetTransfer(key, types, shapes, &transfer); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c4449bc4be06daff856eff70c6d89be6ddbcf0ee..a6747bbe72e161b2ece55697825cce0e71145a5c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" @@ -216,6 +217,10 @@ class XlaCompiler { // containing both constant and non-constant results. std::vector outputs; + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla::HostComputeMetadata host_compute_metadata; + // Resources whose values were updated by the computation, ordered // by return value position. Resource updates follow the non-constant // results in the outputs of XLA computation. @@ -284,6 +289,14 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + // Compiles a single Op, given by an OpKernelContext, into an + // xla::Computation. Similar to CompileFunction but takes a single Op as + // input. + Status CompileSingleOp(const CompileOptions& options, string const& name, + OpKernelContext* ctx, + const std::vector& args, + CompilationResult* result); + // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. @@ -296,6 +309,22 @@ class XlaCompiler { // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + // Sets the shapes and types for the device to host transfer associated with + // 'key'. + Status SetDeviceToHostMetadata(const string& key, + gtl::ArraySlice types, + gtl::ArraySlice shapes); + + // Gets the shapes the device to host transfer associated with 'key'. + Status GetDeviceToHostShapes(const string& key, + std::vector* shapes) const; + + // Sets the shapes and types for the host to device transfer associated with + // 'key'. + Status SetHostToDeviceMetadata(const string& key, + gtl::ArraySlice types, + gtl::ArraySlice shapes); + const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } @@ -359,6 +388,9 @@ class XlaCompiler { std::unordered_map channels_; + std::unordered_map host_compute_sends_; + std::unordered_map host_compute_recvs_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f048662953e20b2a612271e2daeef6e370c4822a..3b0b2f06ebae4af918cbe6fb8a384004c1858998 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -273,4 +274,20 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, return Status::OK(); } +DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + if (dtype == DT_BFLOAT16) { + return DT_FLOAT; + } + return dtype; +} + +xla::ComputationDataHandle XlaHelpers::ConvertElementType( + xla::ComputationBuilder* const builder, + const xla::ComputationDataHandle& operand, + const DataType new_element_type) { + xla::PrimitiveType convert_to; + TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); + return builder->ConvertElementType(operand, convert_to); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 2a027db4c839c917f3a7acd27184792d157356bf..68ab93b64a5fa87ad99e0f44d84f6473fc8bbebd 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -107,6 +107,18 @@ class XlaHelpers { const xla::ComputationDataHandle& on_value, const xla::ComputationDataHandle& off_value, xla::ComputationDataHandle* one_hot); + + // Certain DataTypes should use increased precision DataTypes when performing + // reductions. This function remaps a given DataType to a higher precision + // DataType if needed. + static DataType SumAccumulationType(const DataType& dtype); + + // A helper for creating a ConvertElementType xla op given a DataType rather + // than the xla::PrimitiveType. + static xla::ComputationDataHandle ConvertElementType( + xla::ComputationBuilder* const builder, + const xla::ComputationDataHandle& operand, + const DataType new_element_type); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c7cb69215fb051b7f87c3be3b0b419b9c1b8998c..cd13db4d300bb5bba21a734173b6afb9223539d8 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -52,6 +52,7 @@ xla_proto_library( visibility = ["//visibility:public"], deps = [ ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:session_proto", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 46ee4e64c9ae7ca111d9d04bedcb74ff02a42386..ea75ad32d5df7bbadd37e89de6144b264ab6d5d1 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -121,10 +122,31 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 2D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 1D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && + std::is_same::value>::type> + Array(std::initializer_list values) + : Array(ToInt64Vector({values.size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + values_[idx] = static_cast(it1); + ++idx; + } + CHECK(idx == num_elements()); + } + + // Creates a 2D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. + template ::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list> values) : Array(ToInt64Vector({values.size(), values.begin()->size()})) { @@ -155,10 +177,13 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 3D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 3D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list>> values) @@ -196,10 +221,13 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 4D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 4D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list< std::initializer_list>>> diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index d30e78ecde45cfcfcfdaac6c13c9d87ab5630c57..a17e81f44832f272fd93dce9f854042b4a84fde4 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -53,10 +53,13 @@ class Array2D : public Array { Array2D(std::initializer_list> values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array2D(std::initializer_list> values) : Array(values) {} @@ -100,14 +103,16 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 n1, int64 n2) { auto array = MakeUnique>(n1, n2); int64 count = n1 * n2; - NativeT step = (count > 1) ? (to - from) / (count - 1) : 0.0f; + NativeT step = + static_cast((count > 1) ? (to - from) / (count - 1) : 0); auto set = [&array, n1, n2](int64 index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64 i = 0; i < count - 1; ++i) { - set(i, static_cast(from + i * step)); + set(i, (static_cast(from) + + static_cast(i) * static_cast(step))); } - set(count - 1, to); + set(count - 1, static_cast(to)); return array; } } // namespace xla diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index e5eb235d45d160d486d1499db665ed14a8509043..0e9a0722ae43e1dc6ecddde9cbc3daf1db058840 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -57,10 +57,13 @@ class Array3D : public Array { values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array3D( std::initializer_list>> diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index cff70e54bad0116bdd08674b626b3bf99dc89e1f..a75fffc605aa0df3e1e2eeb6d3129718cbbba0e4 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -82,10 +82,13 @@ class Array4D : public Array { values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array4D(std::initializer_list>>> diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 02356699a25e47be50eb15872df4c9c302fc289b..5094e5ce6786bb56da408ea6ec83f786be422b38 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -74,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d15ccb0c28522c647617153aaa8e738d029dfaba..5ce3c45528cfa36315977f7feac920ffd2272894 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -177,6 +177,22 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } +StatusOr> Client::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr data, + Execute(computation, arguments, execution_options, execution_profile)); + + const Shape* shape_with_output_layout = nullptr; + if (execution_options && execution_options->has_shape_with_output_layout()) { + shape_with_output_layout = &execution_options->shape_with_output_layout(); + } + return Transfer(*data, shape_with_output_layout); +} + StatusOr Client::LoadSnapshot(const SessionModule& module) { LoadComputationSnapshotRequest request; *request.mutable_module() = module; @@ -231,6 +247,41 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } +StatusOr> Client::Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + ExecuteGraphRequest request; + *request.mutable_computation() = computation.proto(); + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { + *request.mutable_execution_options() = *execution_options; + } + for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; + *request.add_arguments() = argument->handle(); + } + + ExecuteResponse response; + VLOG(1) << "making execute request: " << request.ShortDebugString(); + Status s = stub_->ExecuteGraph(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + if (execution_profile != nullptr) { + *execution_profile = response.profile(); + // TODO(b/74197823): Get execution stats for the graph and VLOG(1) them. + } + + return MakeUnique(stub_, response.output()); +} + StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteParallelRequest request; diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index c28380b689c7a0e16bf0bcbf15003f4aa15e42a7..ec87646ebf3bfffc70aa1a8597fb2053a7fbe059 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" @@ -57,6 +58,21 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global + // data that was produced from the execution. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // A struct to represent a computation instance to be executed. // * If execution_options.device_handles is not empty, the computation is // executed on the devices associated with the handles by partitioning the @@ -137,6 +153,17 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and transfers the result + // to the client as a literal. Parameters are defined the same as for + // Execute() and Transfer(). + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // Unregister the memory for the given GlobalData on the device. Status Unregister(const GlobalData& data); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index c7e2c4367b89ca2112022fa40449ae3ebe28463e..59662c95ac15e7c23790c5b5ff5d75a694613aeb 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -39,16 +39,15 @@ CompileOnlyClient::CompileAheadOfTime( return compiler_service_->CompileAheadOfTime(service_instances, options); } -int64 CompileOnlyClient::PointerSizeForTriple( - tensorflow::StringPiece target_triple) { - llvm::Triple triple(llvm::Triple::normalize( - llvm::StringRef(target_triple.data(), target_triple.size()))); - if (triple.isArch64Bit()) { +int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { + llvm::Triple llvm_triple( + llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); + if (llvm_triple.isArch64Bit()) { return 8; - } else if (triple.isArch32Bit()) { + } else if (llvm_triple.isArch32Bit()) { return 4; } else { - CHECK(triple.isArch16Bit()); + CHECK(llvm_triple.isArch16Bit()); return 2; } } diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 2a6e02649d15bc9fd47a893c41f9c8a62ac076c6..39d02f0863f78d4094f2cc4805f534713fb7e929 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -408,7 +408,7 @@ ComputationDataHandle ComputationBuilder::Reshape( ComputationDataHandle ComputationBuilder::Collapse( const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dims_to_collapse) { + tensorflow::gtl::ArraySlice dimensions) { if (!first_error_.ok()) { return ComputationDataHandle(); } @@ -416,8 +416,8 @@ ComputationDataHandle ComputationBuilder::Collapse( // Don't support out-of-order collapse here. // Checks that the collapsed dimensions are in order and consecutive. for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dims_to_collapse.size(); ++i) { - if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { + i < dimensions.size(); ++i) { + if (dimensions[i] - 1 != dimensions[i - 1]) { NoteError(InvalidArgument( "Collapsed dimensions are not in order and consecutive.")); return ComputationDataHandle(); @@ -434,9 +434,9 @@ ComputationDataHandle ComputationBuilder::Collapse( VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dims_to_collapse, ","); + << tensorflow::str_util::Join(dimensions, ","); - if (dims_to_collapse.size() <= 1) { + if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. return operand; @@ -444,7 +444,7 @@ ComputationDataHandle ComputationBuilder::Collapse( std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { + if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape->dimensions(i)); } else { new_sizes.back() *= original_shape->dimensions(i); @@ -753,13 +753,13 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, } void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape, + const Shape& shape_with_layout, const string& outfeed_config) { OpRequest op_request; OutfeedRequest* request = op_request.mutable_outfeed_request(); request->set_outfeed_config(outfeed_config); *request->mutable_operand() = operand; - *request->mutable_shape() = shape; + *request->mutable_shape() = shape_with_layout; RunOpAndNoteError(&op_request); } @@ -868,6 +868,14 @@ ComputationDataHandle ComputationBuilder::Or( return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); } +// TODO(b/65209188): Create a dedicated lowering for Xor +ComputationDataHandle ComputationBuilder::Xor( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return Or(And(Not(lhs), rhs, broadcast_dimensions), + And(lhs, Not(rhs), broadcast_dimensions)); +} + ComputationDataHandle ComputationBuilder::Not( const ComputationDataHandle& operand) { return UnaryOp(UNOP_NOT, operand); @@ -1382,15 +1390,16 @@ ComputationDataHandle ComputationBuilder::BatchNormInference( ComputationDataHandle ComputationBuilder::BatchNormGrad( const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& mean, const ComputationDataHandle& var, + const ComputationDataHandle& batch_mean, + const ComputationDataHandle& batch_var, const ComputationDataHandle& grad_output, float epsilon, int64 feature_index) { OpRequest op_request; BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); *request->mutable_operand() = operand; *request->mutable_scale() = scale; - *request->mutable_mean() = mean; - *request->mutable_variance() = var; + *request->mutable_mean() = batch_mean; + *request->mutable_variance() = batch_var; *request->mutable_grad_output() = grad_output; request->set_epsilon(epsilon); request->set_feature_index(feature_index); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 377b6716399ea87b12bd0bd8a9486d4476e3cbf0..2141ebc2065a1a80d2fe820a7b6fe15434c89e28 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -512,6 +512,10 @@ class ComputationBuilder { const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle Xor( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle Not(const ComputationDataHandle& operand); ComputationDataHandle ShiftLeft( @@ -872,7 +876,7 @@ class ComputationBuilder { Window* window); // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation binop, + ComputationDataHandle UnaryOp(UnaryOperation unop, const ComputationDataHandle& operand); // Internal helper method that does the building for an arbitrary binary op. diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 804e34f5e75ce2d153ac7627b94a543fda88e810..6e3c5cb484b8f1ef053fa287a4d462aeb886e530 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -76,4 +76,35 @@ ExecutableBuildOptions::generate_hlo_graph() const { return generate_hlo_graph_; } +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_optimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { + return dump_optimized_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { + return dump_per_pass_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { + hlo_profile_ = enabled; + return *this; +} + +tensorflow::gtl::optional ExecutableBuildOptions::hlo_profile() const { + return hlo_profile_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 3a52dbac9adb155ad9a7d91a8102707f70fe2fbf..11f10983606fe02b1edb11a260edde8e5f9a726f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -57,15 +58,36 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_generate_hlo_graph(string regex); const tensorflow::gtl::optional& generate_hlo_graph() const; + // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs + // to (as in DebugOptions). + ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_per_pass_hlo_proto_to() const; + + // If true, specifies that we should record an HLO profile during execution + // and log it after execution (as in DebugOptions). If nullopt the default is + // used. + ExecutableBuildOptions& set_hlo_profile(bool enabled); + tensorflow::gtl::optional hlo_profile() const; + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; private: + tensorflow::gtl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; + tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; }; diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 91396f055fe4a3ecbd436139be9470e2a35e1c63..30594243dcf51d2b5312b9dcb2bea7d0cd78524d 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -265,6 +265,24 @@ StatusOr> LocalClient::Compile( updated_options)); } +StatusOr> LocalClient::Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options) { + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + local_service_->CompileExecutable( + computation, argument_layouts, updated_options)); + return WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); +} + StatusOr> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index b52a30f5a0b92e0094e6b0de3241c10a5a909cad..98ee7c62c94be7c618cedd3dc12ecbfc812ee180 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -69,7 +69,7 @@ class LocalExecutable { // of the computation. tensorflow::Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, const Backend& backend); + const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. @@ -123,6 +123,15 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); + // Build and return a LocalExecutable object. The executable is compiled using + // the given XlaComputation, argument layouts and options. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..cc5f551c9c1a7b59426f3490e5e671f341543f34 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -0,0 +1,91 @@ +# Description: +# The new XLA client libraries. +# +# This is NOT YET ready to use. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:lib", + ], +) + +# TODO(b/74197823): Replace computation_builder with xla_builder. +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + deps = [ + ":xla_computation", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..fcaf393b6b1db6e8335eb84cf00a19c543df1087 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -0,0 +1,964 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { + +using tensorflow::strings::StrCat; + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 built_counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = built_counter++; + return id; +} + +// Returns true if an instruction with the given opcode can be the root of the +// computation. +bool CanBeRoot(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kSend: + case HloOpcode::kOutfeed: + case HloOpcode::kTrace: + return false; + default: + return true; + } +} + +} // namespace + +StatusOr XlaBuilder::GetShape(const XlaOp& op) const { + TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); + return instr->shape(); +} + +StatusOr XlaOp::GetShape() const { + TF_RET_CHECK(builder_ != nullptr); + return builder_->GetShape(*this); +} + +XlaBuilder::XlaBuilder(const string& computation_name) + : name_(computation_name) {} + +XlaBuilder::~XlaBuilder() {} + +void XlaBuilder::NoteError(const Status& error) { + CHECK(!error.ok()); + if (die_immediately_on_error_) { + LOG(FATAL) << "error building computation: " << error; + } + + if (first_error_.ok()) { + first_error_ = error; + first_error_backtrace_.CreateCurrent(/*skip_count=*/1); + } +} + +StatusOr XlaBuilder::GetProgramShape(int64* root_id) { + TF_RET_CHECK(root_id != nullptr); + ProgramShape program_shape; + + // Not all instructions can be roots. Walk backwards from the last added + // instruction until a valid root is found. + int64 index = instructions_.size() - 1; + for (; index >= 0; index--) { + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instructions_[index].opcode())); + if (CanBeRoot(opcode)) { + break; + } + } + if (index < 0) { + return FailedPrecondition("no root instruction was found"); + } + *root_id = instructions_[index].id(); + *program_shape.mutable_result() = instructions_[index].shape(); + + // Check that the parameter numbers are continuous from 0, and add parameter + // shapes and names to the program shape. + const int64 param_count = parameter_numbers_.size(); + for (int64 i = 0; i < param_count; i++) { + program_shape.add_parameters(); + program_shape.add_parameter_names(); + } + for (const HloInstructionProto& instr : instructions_) { + // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So + // to verify continuity, we just need to verify that every parameter is in + // the right range. + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + const int64 index = instr.parameter_number(); + TF_RET_CHECK(index >= 0 && index < param_count) + << "invalid parameter number: " << index; + *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameter_names(index) = instr.name(); + } + } + return program_shape; +} + +StatusOr XlaBuilder::GetProgramShape() { + int64 root_id; + return GetProgramShape(&root_id); +} + +StatusOr XlaBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + HloComputationProto entry; + entry.set_name(name_); + + { + int64 root_id; + ProgramShape program_shape; + TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id)); + entry.mutable_program_shape()->Swap(&program_shape); + entry.set_root_id(root_id); + } + + for (auto& instruction : instructions_) { + entry.add_instructions()->Swap(&instruction); + } + + const int64 id = GetUniqueId(); + entry.set_id(id); + XlaComputation computation(id); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_program_shape() = entry.program_shape(); + for (auto& e : embedded_) { + module->add_computations()->Swap(&e.second); + } + module->add_computations()->Swap(&entry); + + // Clear data held by this builder. + this->instructions_.clear(); + this->embedded_.clear(); + this->parameter_numbers_.clear(); + + return std::move(computation); +} + +StatusOr XlaBuilder::InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + for (int64 dim : broadcast_dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); +} + +StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand) { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + + CHECK(ShapeUtil::IsScalar(operand_shape) || + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = + ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); + + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand_shape)) { + return InDimBroadcast(broadcast_shape, operand, {}); + } + + // Do explicit broadcast for degenerate broadcast. + std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand_shape.dimensions(i)); + } else { + TF_RET_CHECK(operand_shape.dimensions(i) == 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand shape: " + << operand_shape << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + Reshape(ShapeUtil::MakeShape(operand_shape.element_type(), + reshaped_dimensions), + operand)); + // Broadcast 'reshape' up to the larger size. + return InDimBroadcast(broadcast_shape, reshaped_operand, + broadcast_dimensions); +} + +XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferUnaryOpShape(unop, operand_shape)); + return AddInstruction(std::move(instr), unop, {operand}); + }()); +} + +XlaOp XlaBuilder::BinaryOp( + HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBinaryOpShape( + binop, lhs_shape, rhs_shape, broadcast_dimensions)); + + const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); + const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + + if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { + const bool should_broadcast_lhs = lhs_rank < rhs_rank; + XlaOp from = should_broadcast_lhs ? lhs : rhs; + const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; + + std::vector to_size; + for (int64 size : instr.shape().dimensions()) { + to_size.push_back(size); + } + for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); + from_dim++) { + int64 to_dim = broadcast_dimensions[from_dim]; + to_size[to_dim] = from_shape.dimensions(from_dim); + } + + const Shape& broadcasted_shape = + ShapeUtil::MakeShape(from_shape.element_type(), to_size); + TF_ASSIGN_OR_RETURN( + XlaOp broadcasted_operand, + InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); + + updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; + updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; + } + + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), updated_lhs)); + } + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), updated_rhs)); + } + + return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); + }()); +} + +XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferTernaryOpShape( + triop, lhs_shape, rhs_shape, ehs_shape)); + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + XlaOp updated_ehs = ehs; + if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(lhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + // lhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), lhs)); + } + if (!ShapeUtil::IsTuple(rhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + // rhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), rhs)); + } + if (!ShapeUtil::IsTuple(ehs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + // ehs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_ehs, + AddBroadcastSequence(instr.shape(), ehs)); + } + } + return AddInstruction(std::move(instr), triop, + {updated_lhs, updated_rhs, updated_ehs}); + }()); +} + +XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = literal.shape(); + *instr.mutable_literal() = literal.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kConstant); + }()); +} + +XlaOp XlaBuilder::Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + std::vector operand_shapes; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); + operand_shapes.push_back(shape); + } + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferCallShape( + operand_shape_ptrs, + /*to_apply=*/computation.GetProgramShape())); + + // Add called computation. + instr.add_called_computation_ids( + computation.proto().entry_computation_id()); + for (const HloComputationProto& e : computation.proto().computations()) { + embedded_.insert({e.id(), e}); + } + + return AddInstruction(std::move(instr), HloOpcode::kCall, operands); + }()); +} + +XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, + const string& name) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) { + return InvalidArgument("parameter %lld already registered", + parameter_number); + } + parameter_numbers_.insert(parameter_number); + instr.set_parameter_number(parameter_number); + instr.set_name(name); + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kParameter); + }()); +} + +XlaOp XlaBuilder::Broadcast( + const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + const Shape& shape, + ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); + + // The client-level broadcast op just appends dimensions on the left (adds + // lowest numbered dimensions). The HLO broadcast instruction is more + // flexible and can add new dimensions anywhere. The instruction's + // dimensions field maps operand dimensions to dimensions in the broadcast + // output, so to append dimensions on the left the instruction's dimensions + // should just be the n highest dimension numbers of the output shape where + // n is the number of input dimensions. + const int64 operand_rank = ShapeUtil::Rank(operand_shape); + std::vector dimensions(operand_rank); + for (int i = 0; i < operand_rank; ++i) { + dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + } + return InDimBroadcast(shape, operand, dimensions); + }()); +} + +StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); +} + +XlaOp XlaBuilder::Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferReshapeShape( + operand_shape, dimensions, new_sizes)); + XlaOp transposed = IsIdentityPermutation(dimensions) + ? operand + : Transpose(operand, dimensions); + return Reshape(shape, transposed); + }()); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + std::vector dimensions(shape.dimensions_size()); + std::iota(dimensions.begin(), dimensions.end(), 0); + return Reshape(operand, dimensions, new_sizes); + }()); +} + +XlaOp XlaBuilder::Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return UnimplementedOp(); +} + +void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false) { + return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); +} + +XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { + return UnimplementedOp(); +} + +void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, + int64 cost_estimate_ns, const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Complex( + const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); +} + +XlaOp XlaBuilder::Conj(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return Or(And(Not(lhs), rhs, broadcast_dimensions), + And(lhs, Not(rhs), broadcast_dimensions)); +} + +XlaOp XlaBuilder::Not(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNot, operand); +} + +XlaOp XlaBuilder::ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::Abs(const XlaOp& operand) { + return UnaryOp(HloOpcode::kAbs, operand); +} + +XlaOp XlaBuilder::Atan2( + const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); +} + +XlaOp XlaBuilder::Exp(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExp, operand); +} + +XlaOp XlaBuilder::Floor(const XlaOp& operand) { + return UnaryOp(HloOpcode::kFloor, operand); +} + +XlaOp XlaBuilder::Ceil(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCeil, operand); +} + +XlaOp XlaBuilder::Round(const XlaOp& operand) { + return UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} + +XlaOp XlaBuilder::Log(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog, operand); +} + +XlaOp XlaBuilder::Sign(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSign, operand); +} + +XlaOp XlaBuilder::Cos(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCos, operand); +} + +XlaOp XlaBuilder::Sin(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSin, operand); +} + +XlaOp XlaBuilder::Tanh(const XlaOp& operand) { + return UnaryOp(HloOpcode::kTanh, operand); +} + +XlaOp XlaBuilder::Real(const XlaOp& operand) { + return UnaryOp(HloOpcode::kReal, operand); +} + +XlaOp XlaBuilder::Imag(const XlaOp& operand) { + return UnaryOp(HloOpcode::kImag, operand); +} + +XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { + return UnaryOp(HloOpcode::kIsFinite, operand); +} + +XlaOp XlaBuilder::Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferTransposeShape(operand_shape, permutation)); + for (int64 dim : permutation) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); + }()); +} + +XlaOp XlaBuilder::Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Sort(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSort, operand); +} + +XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Neg(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNegate, operand); +} + +XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, + const XlaOp& max) { + return TernaryOp(HloOpcode::kClamp, min, operand, max); +} + +XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, + const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Reduce( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return UnimplementedOp(); +} + +void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { + return UnimplementedOp(); +} + +StatusOr XlaBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + const int64 handle = instructions_.size(); + instr.set_id(handle); + instr.set_opcode(HloOpcodeString(opcode)); + if (instr.name().empty()) { + instr.set_name(StrCat(instr.opcode(), ".", handle)); + } else { + // Append the handle to make sure the name is unique. + instr.set_name(StrCat(instr.name(), ".", handle)); + } + for (const auto& operand : operands) { + TF_RET_CHECK(operand.builder_ != nullptr); + TF_RET_CHECK(operand.builder_ == this) + << "Do not add XlaOp from builder " << operand.builder_->name() + << " to builder " << this->name(); + instr.add_operand_ids(operand.handle()); + } + + *instr.mutable_metadata() = metadata_; + if (sharding_) { + *instr.mutable_sharding() = *sharding_; + } + + instructions_.push_back(instr); + + XlaOp op(handle, this); + return op; +} + +StatusOr XlaBuilder::LookUpInstruction( + const XlaOp& op) const { + TF_RET_CHECK(op.builder_ == this); + if (op.handle() >= instructions_.size() || op.handle() < 0) { + return InvalidArgument("no XlaOp value %lld", op.handle()); + } + return &instructions_[op.handle()]; +} + +XlaOp XlaBuilder::UnimplementedOp() { + NoteError(Unimplemented("Op not yet implemented")); + return {}; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..c5c35159e06e1cc2d9f75a5b41f025773c3d685d --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -0,0 +1,896 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(b/74197823): Replace computation_builder.h with this file. +// +// This is NOT YET ready to use. + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class XlaBuilder; + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +// +// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. +class XlaOp { + public: + XlaOp() : handle_(0), builder_(nullptr) {} + + StatusOr GetShape() const; + + private: + XlaOp(int64 handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64 handle() const { return handle_; } + friend class XlaBuilder; + + int64 handle_; + XlaBuilder* builder_; // Not owned. +}; + +// A convenient interface for building up computations. +// +// Thread-compatible. +// +// TODO(b/74197823): Replace xla::ComputationBuilder with this one. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + XlaBuilder(const string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + ~XlaBuilder(); + + // Returns the computation name. + const string& name() const { return name_; } + + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Enqueues a "retrieve parameter value" instruction for a parameter that was + // passed to the computation. + XlaOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + // Enqueues a constant with the value of the given literal onto the + // computation. + XlaOp ConstantLiteral(const Literal& literal); + + // Enqueues a constant onto the computation. Methods are templated on the + // native host type (NativeT) which corresponds to a specific XLA + // PrimitiveType as given in the following table: + // + // Native Type PrimitiveType + // ----------------------------- + // bool PRED + // int32 S32 + // int64 S64 + // uint32 U32 + // uint64 U64 + // float F32 + // double F64 + // + // Note: not all primitive types defined in xla_data.proto have a + // corresponding native type yet. + template + XlaOp ConstantR0(NativeT value); + template + XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(const tensorflow::core::Bitmap& values); + template + XlaOp ConstantR2( + std::initializer_list> values); + template + XlaOp ConstantFromArrayWithLayout(const Array& values, + const Layout& layout); + template + XlaOp ConstantFromArray(const Array& values); + template + XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); + template + XlaOp ConstantR2FromArray2D(const Array2D& values); + template + XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); + template + XlaOp ConstantR3FromArray3D(const Array3D& values); + template + XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); + template + XlaOp ConstantR4FromArray4D(const Array4D& values); + + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template + XlaOp ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. + // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. + // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. + // The strides parameter determines the stride over the slice + XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + // Enqueues a slice operation in a given dimension, taking all other + // dimensions as they are; e.g. if dimno is 1 from start_index 2 to + // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand + // for: + // + // array[:, 2:4:1, :] + XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const XlaOp& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + + // Enqueues a tuple-creation instruction onto the computation. + XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + + // Enqueues a tuple-element-get instruction onto the computation. + XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. + XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + + // Enqueues a general dot instruction onto the computation. + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). + XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. + XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. + XlaOp Infeed(const Shape& shape, const string& config = ""); + + // Enqueues an outfeed instruction onto the computation. This instruction + // generates outgoing data transfers for the given data. + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + XlaOp CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues a complex compose instruction onto the computation. + XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + XlaOp Conj(const XlaOp& operand); + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. + XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Element-wise logical operators + XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Not(const XlaOp& operand); + + XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + // Convenience wrapper around the above that reduces all the dimensions in the + // operand shape. + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + + // Enqueues a windowed reduce instruction onto the computation. + XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). + XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Returns the sum of the operand value across all replicas. All replicas + // supply one input to the sum and all replicas receive the resulting sum. + XlaOp CrossReplicaSum(const XlaOp& operand); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + + // Enqueues an abs instruction onto the computation. + XlaOp Abs(const XlaOp& operand); + + // Enqueues a atan2 instruction onto the computation. + XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an exp instruction onto the computation. + XlaOp Exp(const XlaOp& operand); + + // Enqueues a floor instruction onto the computation. + XlaOp Floor(const XlaOp& operand); + + // Enqueues a ceil instruction onto the computation. + XlaOp Ceil(const XlaOp& operand); + + // Enqueues a round instruction onto the computation, rounding to nearest even + // with half-way cases rounding away from zero. + XlaOp Round(const XlaOp& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + XlaOp Log(const XlaOp& operand); + + // Enqueues a sign instruction onto the computation. + XlaOp Sign(const XlaOp& operand); + + // Enqueues a cosine instruction onto the computation. + XlaOp Cos(const XlaOp& operand); + + // Enqueues a sine instruction onto the computation. + XlaOp Sin(const XlaOp& operand); + + // Enqueues a tanh instruction onto the computation. + XlaOp Tanh(const XlaOp& operand); + + // Enqueues a real-part instruction onto the computation. + XlaOp Real(const XlaOp& operand); + + // Enqueues an imaginary-part instruction onto the computation. + XlaOp Imag(const XlaOp& operand); + + // Enqueues a float32 sqrt instruction onto the computation. + // (float32 is specified as there is an implicit float32 0.5f constant + // exponent). + XlaOp SqrtF32(const XlaOp& operand); + + // Enqueues a float32 square instruction onto the computation. + // (float32 is specified as there is an implicit float32 2.0f constant + // exponent). + XlaOp SquareF32(const XlaOp& operand); + + // Enqueues a lhs^rhs computation onto the computation. + XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + XlaOp IsFinite(const XlaOp& operand); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. + XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a float32 reciprocal instruction onto the computation. + // (float32 is specified as there is an implicit float32 -1.0f constant + // exponent). + // + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. + XlaOp ReciprocalF32(const XlaOp& operand); + + // Enqueues a negate instruction onto the computation. + XlaOp Neg(const XlaOp& operand); + + // Enqueues a transpose instruction onto the computation. + XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + XlaOp Sort(const XlaOp& operand); + + // Enqueues a clamp instruction onto the computation. + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + + // Enqueues a map instruction onto the computation. + XlaOp Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. + XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. Returns values in the semi-open interval [a, b). + XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + + // Enqueues a while node onto the computation. + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + + // Enqueues a conditional node onto the computation. + XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + + // Enqueues a ReducePrecision node onto the computation. + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + + // Enqueues a Gather node onto the computation. + XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + + // Enqueues a Send node onto the computation, to send the given operand to + // a Recv instruction that shares the same channel handle. + void Send(const XlaOp& operand, const ChannelHandle& handle); + + // Enqueues a Recv node onto the computation. The data comes from a Send + // instruction that shares the same channel handle and its shape must + // be the same as the given shape. + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on parameters with index greater than or equal to + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. + StatusOr IsConstant(const XlaOp& operand, int64 num_parameters = 0); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // uses the input `mean` and `variance` instead as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch. + // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + + // Builds the computation with the requested operations, or returns a non-ok + // status. + StatusOr Build(); + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr GetProgramShape(); + + private: + StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands = {}); + + // Notes that the error occurred by: + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to Build()) + // * dying if die_immediately_on_error_ is true + void NoteError(const Status& error); + + XlaOp NoteErrorOrReturn(StatusOr&& op) { + if (!op.ok()) { + NoteError(op.status()); + return XlaOp(); + } + return op.ConsumeValueOrDie(); + } + + // Helper method that creates an empty op and notes error. + XlaOp UnimplementedOp(); + + StatusOr LookUpInstruction(const XlaOp& op) const; + + // Internal helper method that does the building for an arbitrary unary op. + XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs); + + StatusOr InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + StatusOr AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + StatusOr Reshape(const Shape& shape, const XlaOp& operand); + + // Returns the (inferred) result for the program shape for the current + // computation and fills the root_id in the pointer. + StatusOr GetProgramShape(int64* root_id); + + string name_; // Name to use for the built computation. + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tensorflow::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + std::vector instructions_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + tensorflow::gtl::FlatSet parameter_numbers_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional sharding_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; +}; + +template +XlaOp XlaBuilder::ConstantR0(NativeT value) { + return ConstantLiteral(*Literal::CreateR0(value)); +} + +template +XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); +} + +inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR2( + std::initializer_list> values) { + return ConstantLiteral(*Literal::CreateR2(values)); +} + +template +XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, + const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantFromArray(const Array& values) { + return ConstantLiteral(*Literal::CreateFromArray(values)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { + return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { + return ConstantFromArray(values); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { + return ConstantFromArray(values); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..85d4227ba4d8d04b1d2ba8b1d24922b13bd9cae5 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace op = xla::testing::opcode_matchers; + +using ::testing::HasSubstr; + +// TODO(b/74197823): Move the tests to service/. +class XlaBuilderTest : public ::testing::Test { + protected: + StatusOr> BuildHloModule(XlaBuilder* b) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto(proto)); + return HloModule::CreateFromProto(proto, config); + } + + // Returns the name of the test currently being run. + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } +}; + +TEST_F(XlaBuilderTest, OnePlusTwo) { + XlaBuilder b(TestName()); + b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + b.Add(x, b.ConstantR0(1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); +} + +TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { + XlaBuilder b(TestName()); + const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); + const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); + auto x = b.Parameter(0, x_shape, "x"); + auto y = b.Parameter(1, y_shape, "y"); + auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); +} + +TEST_F(XlaBuilderTest, XPlusX) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + b.Add(x, x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); +} + +TEST_F(XlaBuilderTest, ShapeInferenceError) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); +} + +TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { + XlaBuilder b_call("add"); + b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("parameter 0 already registered")); +} + +TEST_F(XlaBuilderTest, Call) { + XlaBuilder b_call("the_only_to_apply"); + auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + b_call.Add(p0, p1); + TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = b.ConstantR0(1); + auto two = b.ConstantR0(2); + b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), + op::Call(op::Constant(), op::Constant()))); +} + +TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + b.Add(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // Expected: + // + // x: f32[1,2,3] y: f32[1,2,1] + // | | + // | reshape: f32[1,2] + // | | + // | broadcast: f32[1,2,3] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // The binary operation has in-dim broadcast and degenerate broadcast, should + // first do the in-dim broadcast then convert the degnerate broadcast into a + // reshape and a broadcast. + // + // Expected: + // + // x: f32[2,3] y: f32[2,1,4] + // | | + // broadcast: f32[2,3,4] reshape: f32[2,4] + // | | + // | broadcast: f32[2,3,4] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { + XlaBuilder b1("b1"); + auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + XlaBuilder builder("main"); + builder.Add(p0, p0); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Do not add XlaOp from builder b1 to builder main")); +} + +TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Parameter())); +} + +TEST_F(XlaBuilderTest, ReshapeHasTranspose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); +} + +TEST_F(XlaBuilderTest, Transpose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + b.Transpose(x, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Transpose(op::Parameter())); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +TEST_F(XlaBuilderTest, Xor) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Xor(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + LOG(ERROR) << module->ToString(); + EXPECT_THAT(root, + op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), + op::And(op::Parameter(0), op::Not(op::Parameter(1))))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc new file mode 100644 index 0000000000000000000000000000000000000000..3681792eeea081f87ee055e79ba841b4917a428d --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" + +#include + +namespace xla { + +const ProgramShape& XlaComputation::GetProgramShape() const { + return proto_.program_shape(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h new file mode 100644 index 0000000000000000000000000000000000000000..5b89747fdd4f91e82c7ebc7aa10c5a914100a0c8 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +// +// TODO(b/74197823): Replace xla::Computation with this one. +class XlaComputation { + public: + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + const ProgramShape& GetProgramShape() const; + const HloModuleProto& proto() const { return proto_; } + + private: + XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} + HloModuleProto* mutable_proto() { return &proto_; } + friend class XlaBuilder; + + int64 unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index a3b4286f4c12bf39a44c63dd6e7d303a46a418c3..7b6ae311c1099dccb8dceb2f49743c1b185cd5ab 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 823da43b5ab2e9c8e80181efc993735877a2c363..13675b7d0074592043b7e12de0aad948a3e9848f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -223,7 +223,7 @@ Status Literal::CopySliceFromInternal( Literal::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](const std::vector& indexes) { + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -248,6 +248,28 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } +Status Literal::CopyElementFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + std::vector Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector elements; @@ -343,7 +365,7 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { #undef COPY_ELEMENTS default: return Unimplemented( - "Unhandled primitive type %s", + "Copying a Literal object with element type %s is not implemented.", PrimitiveType_Name(subshape().element_type()).c_str()); } } @@ -491,7 +513,10 @@ Status Literal::CopySliceFrom(const Literal& src_literal, default: break; } - return Unimplemented("Unhandled primitive type %d", shape().element_type()); + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); } /* static */ Literal Literal::Zero(PrimitiveType primitive_type) { @@ -808,9 +833,10 @@ std::unique_ptr Literal::Slice( DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; int64 dimension = limit_indices[dnum] - start_indices[dnum]; - CHECK_GT(dimension, 0); + CHECK_GE(dimension, 0) << "dnum = " << dnum; result_dimensions.push_back(dimension); } const auto result_shape = @@ -903,7 +929,7 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, case U64: return StrCat(Get(multi_index, shape_index)); case F16: - return StrCat(Get(multi_index, shape_index)); + return StrCat(static_cast(Get(multi_index, shape_index))); case F32: return StrCat(Get(multi_index, shape_index)); case BF16: @@ -953,7 +979,8 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, return StrCat( GetSparseElement(sparse_element_number, shape_index)); case F16: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); case F32: return StrCat( GetSparseElement(sparse_element_number, shape_index)); @@ -997,6 +1024,36 @@ StatusOr Literal::GetIntegralAsS64( } } +Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set(multi_index, value); + break; + case U8: + Set(multi_index, value); + break; + case S32: + Set(multi_index, value); + break; + case S64: + Set(multi_index, value); + break; + case U32: + Set(multi_index, value); + break; + case U64: + Set(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + tensorflow::gtl::ArraySlice Literal::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); @@ -1328,8 +1385,9 @@ void Literal::EachCellAsString( } namespace { -template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const Literal& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1339,11 +1397,18 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); + dest_data[i] = converter(src_data[i]); } return result_literal; } +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + template std::unique_ptr ConvertToC64(const Literal& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); @@ -1394,8 +1459,8 @@ StatusOr> ConvertIfDestTypeMatches( return ConvertToC64(src_literal); // Other types are not yet supported. default: - return InvalidArgument( - "Unimplemented: Convert from type %s to type %s", + return Unimplemented( + "Converting from type %s to type %s is not implemented.", PrimitiveType_Name(src_literal.shape().element_type()).c_str(), PrimitiveType_Name(primitive_dest_type).c_str()); } @@ -1406,6 +1471,9 @@ StatusOr> ConvertIfDestTypeMatches( StatusOr> Literal::Convert( PrimitiveType primitive_dest_type) const { TF_RET_CHECK(ShapeUtil::IsArray(shape())); + if (shape().element_type() == primitive_dest_type) { + return CloneToUnique(); + } switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ @@ -1424,10 +1492,37 @@ StatusOr> Literal::Convert( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return InvalidArgument("Unimplemented: Convert from type %s to type %s", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); + } +} + +StatusOr> Literal::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } + return Convert(dest_shape.element_type()); } + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralView::Create(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique(); + *converted = Literal::MoveIntoTuple(&elements); + return std::move(converted); } template diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index d5ae3fd72322fe243f0156dfbe236b6d62ab8c9d..a96a76fbb4e1a46e225d33b715f073c05fe6275a 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -262,6 +262,11 @@ class Literal { tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + // Returns a vector containing the tuple elements of this Literal as separate // Literals. This Literal must be tuple-shaped and can be a nested tuple. The // elements are moved into the new Literals; no data is copied. Upon return @@ -333,6 +338,17 @@ class Literal { StatusOr> Convert( PrimitiveType primitive_dest_type) const; + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -469,6 +485,11 @@ class Literal { StatusOr GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const; + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + // Returns an identity matrix (rank 2) with the given row and column count. template static std::unique_ptr MakeIdentityR2(int64 size); @@ -1269,7 +1290,7 @@ Status Literal::Populate(const FnType& generator) { int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](const std::vector& indexes) { + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index ee2f4fe87440428c7364fe2924003c5124f4eaa2..7627762074b6132655c58690a7fffbaf2717e279 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -30,6 +30,7 @@ limitations under the License. namespace xla { namespace { +using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -214,11 +215,11 @@ TEST_F(LiteralUtilTest, CreateSparse) { std::vector expected_values = {8, 9, 7, 10}; EXPECT_EQ(literal->sparse_indices()->data(), - tensorflow::gtl::ArraySlice( - expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), - expected_values.size()), - tensorflow::gtl::ArraySlice(expected_values)); + ArraySlice(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ( + ArraySlice(literal->data().data(), expected_values.size()), + ArraySlice(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -290,7 +291,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format on std::vector> seen; literal->EachCellAsString( - [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { + [&seen](ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -622,11 +623,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { // clang-format on auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, original->Get( - {indices[2], indices[3], indices[0], indices[1]})); - }); + reshape->EachCell([&](ArraySlice indices, float value) { + EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); + }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { @@ -863,7 +863,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](const std::vector& indexes) { + auto init_proc = [&](ArraySlice indexes) { source->Set(indexes, ++seqnr); return true; }; @@ -879,7 +879,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](const std::vector& indexes) { + auto check_proc = [&](ArraySlice indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); @@ -1067,7 +1067,7 @@ TEST_F(LiteralUtilTest, Populate) { primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); auto literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1079,7 +1079,7 @@ TEST_F(LiteralUtilTest, Populate) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](const std::vector& indexes) { + auto check_function = [&](ArraySlice indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; @@ -1232,15 +1232,15 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { EXPECT_EQ(*conv, *c64); EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64->Convert(S32).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1702,7 +1702,7 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(half{2.0})); + tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( Literal::CreateSparse( dimensions, indices, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b5354131c94930b75ea66036ddb61ecd3993414f..8f231d1a12d92ecd93908771019c1440da6855e3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -141,6 +141,33 @@ bool GetIntAttr(PyObject* o, const char* field, int64* result) { return true; } +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, + const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } + if (!PyString_Check(attr)) { + string message = tensorflow::strings::Printf("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr).c_str()); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + } } %} @@ -216,6 +243,7 @@ tensorflow::ImportNumpy(); PyExc_RuntimeError, $1.ToString().c_str()); return NULL; } + Py_INCREF(Py_None); $result = Py_None; } @@ -819,16 +847,32 @@ tensorflow::ImportNumpy(); if ($input == Py_None) { $1 = NULL; } else { - PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); - if (!o) { + if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { + build_options.set_generate_hlo_graph(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { + build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { + build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + + PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); + if (o == NULL) { return NULL; } if (o != Py_None) { - if (!PyString_Check(o)) { - PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); + if (!PyBool_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); return NULL; } - build_options.set_generate_hlo_graph(PyString_AsString(o)); + build_options.set_hlo_profile(o == Py_True); } Py_DECREF(o); diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 3d87480728aab1d4ebbc71c6c7504d37cae5edaf..eec48479c929ab0823fef342fc284bfdc4b1f339 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -170,8 +170,7 @@ static string PyObjectCppStr(PyObject* o) { return ExtractStringAndDecref(s); } -// Safely returns a repr of the given Python object o as a C++ string. -static string PyObjectCppRepr(PyObject* o) { +string PyObjectCppRepr(PyObject* o) { PyObject* r = PyObject_Repr(o); return ExtractStringAndDecref(r); } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index adfcc3b8588dce01718bb19dea936bace483be4d..9656cb1c31c39dbe54293700c2765d0723255657 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -107,6 +107,9 @@ void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { std::copy(source.begin(), source.end(), dest); } +// Safely returns a repr of the given Python object o as a C++ string. +string PyObjectCppRepr(PyObject* o); + // Workarounds for Python 2 and 3 interop PyObject* LongToPyIntOrPyLong(long x); // NOLINT diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 3b8ec851d5aa032ebbf4f6cfc7e12f5a03539cbd..e548d420f4614d3b3fff6034f9a174d553ebea66 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -30,9 +30,9 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api -# Most functions are snake_case for consistency with other modules, -# whereas method names of ComputationBuilder and LocalComputation are -# CamelCase for consistency with XLA. +# Most functions are snake_case for consistency with other modules, whereas +# method names of ComputationBuilder and LocalComputation are CamelCase for +# consistency with XLA. # pylint: disable=invalid-name @@ -123,24 +123,34 @@ _BINARY_OPS = [ 'Pow', ] + XLA_ELEMENT_TYPE_TO_DTYPE = { - xla_data_pb2.F32: np.dtype(np.float32), - xla_data_pb2.F64: np.dtype(np.float64), - xla_data_pb2.S32: np.dtype(np.int32), - xla_data_pb2.S64: np.dtype(np.int64), - xla_data_pb2.U32: np.dtype(np.uint32), - xla_data_pb2.U64: np.dtype(np.uint64), - xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.PRED: np.dtype('bool'), + xla_data_pb2.S8: np.dtype('int8'), + xla_data_pb2.S16: np.dtype('int16'), + xla_data_pb2.S32: np.dtype('int32'), + xla_data_pb2.S64: np.dtype('int64'), + xla_data_pb2.U8: np.dtype('uint8'), + xla_data_pb2.U16: np.dtype('uint16'), + xla_data_pb2.U32: np.dtype('uint32'), + xla_data_pb2.U64: np.dtype('uint64'), + xla_data_pb2.F16: np.dtype('float16'), + xla_data_pb2.F32: np.dtype('float32'), + xla_data_pb2.F64: np.dtype('float64'), + xla_data_pb2.C64: np.dtype('complex64'), xla_data_pb2.TUPLE: np.dtype(np.object), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(v): k - for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} +DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et + for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] class LocalBuffer(object): @@ -310,6 +320,9 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None + self.dump_optimized_hlo_proto_to = None + self.dump_per_pass_hlo_proto_to = None + self.hlo_profile = False def transfer_to_infeed(value, replica_number=None): diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a9acdae380af5b7f9efb3d08302fc717108f5e40..ad3a28e11939d6259ebd75d544a950ba7abd741f 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -30,29 +30,23 @@ limitations under the License. namespace xla { -/* static */ std::unique_ptr> ReferenceUtil::TransposeArray2D( - const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); - for (int64 w = 0; w < operand.width(); ++w) { - for (int64 h = 0; h < operand.height(); ++h) { - (*result)(w, h) = operand(h, w); - } - } - - return result; -} - -/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( - const Array2D& lhs, const Array2D& rhs) { +namespace { + +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { CHECK_EQ(lhs.width(), rhs.height()); int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = MakeUnique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). - __xla_cpu_runtime_EigenSingleThreadedMatMulF32( + impl_fn( /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, k, /*transpose_lhs=*/0, @@ -60,22 +54,24 @@ namespace xla { return result; } +} // namespace + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} + /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = MakeUnique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). - __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( @@ -188,18 +184,6 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); } -/* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, - const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { - std::vector dim_lengths{static_cast(operand.size())}; - return ReduceWindow1DGeneric( - operand, init, reduce_func, window, stride, - xla::MakePadding(dim_lengths, window, stride, padding)); -} - /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& operand, float init, @@ -239,23 +223,28 @@ ReferenceUtil::ReduceWindow1DAdd( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; - return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, - padding); + std::vector dim_lengths{static_cast(operand.size())}; + return ReduceWindow1DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); } -/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, + const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { std::vector dim_lengths{operand.height(), operand.width()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0], window_counts[1]); @@ -271,7 +260,7 @@ ReferenceUtil::ReduceWindow1DAdd( if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && i0_base + i0_win < operand.n1() && i1_base + i1_win < operand.n2()) { - val += operand(i0_base + i0_win, i1_base + i1_win); + val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win)); } } } @@ -281,6 +270,17 @@ ReferenceUtil::ReduceWindow1DAdd( return result; } +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( + const Array2D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + std::vector dim_lengths{operand.height(), operand.width()}; + return ReduceWindow2DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -472,7 +472,7 @@ ReferenceUtil::SelectAndScatter4DGePlus( i3_base + i3_win < operand.n4()) { float tmp = operand(i0_base + i0_win, i1_base + i1_win, i2_base + i2_win, i3_base + i3_win); - if (tmp >= val) { + if (tmp > val) { val = tmp; scatter_0 = i0_base + i0_win; scatter_1 = i1_base + i1_win; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 3ec96f2f38b8f91e1549419b60481327fa9bbd5f..28d6a8c3fe85fa4179bf2f41c82ad4eb93a045fe 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -39,10 +39,22 @@ namespace xla { class ReferenceUtil { public: // Returns the result of a transpose operation on the input matrix. - static std::unique_ptr> TransposeArray2D( - const Array2D& operand); + template + static std::unique_ptr> TransposeArray2D( + const Array2D& operand) { + auto result = MakeUnique>(operand.width(), operand.height()); + for (int64 w = 0; w < operand.width(); ++w) { + for (int64 h = 0; h < operand.height(); ++h) { + (*result)(w, h) = operand(h, w); + } + } + + return result; + } // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( @@ -187,9 +199,10 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); - static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); + static std::unique_ptr> ReduceWindow2DGeneric( + const Array2D& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, @@ -215,6 +228,7 @@ class ReferenceUtil { // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. + // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, const tensorflow::gtl::ArraySlice& window, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 37ca1b893abdd914ca21ddaa970b821f157810f9..da16976d06ad516644113e8e727ce6b24b6bb26a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -106,6 +106,7 @@ tf_cc_test( ":bfloat16_normalization", ":bfloat16_support", ":hlo", + ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -126,7 +127,10 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_dataflow_analysis", + ":hlo_dce", ":hlo_pass", + ":tuple_simplifier", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -146,6 +150,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], ) @@ -618,6 +623,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -984,6 +990,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1060,6 +1067,38 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_module_group_metadata", + srcs = ["hlo_module_group_metadata.cc"], + hdrs = ["hlo_module_group_metadata.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_module_group_util", + srcs = ["hlo_module_group_util.cc"], + hdrs = ["hlo_module_group_util.h"], + deps = [ + ":hlo", + ":hlo_module_group_metadata", + ":hlo_reachability", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], @@ -1091,6 +1130,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1128,6 +1168,19 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_creation_utils", + srcs = ["hlo_creation_utils.cc"], + hdrs = ["hlo_creation_utils.h"], + deps = [ + ":hlo", + ":shape_inference", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + ], +) + cc_library( name = "batchnorm_expander", srcs = ["batchnorm_expander.cc"], @@ -1136,7 +1189,6 @@ cc_library( ":hlo", ":hlo_pass", ":hlo_query", - ":shape_inference", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1148,6 +1200,20 @@ cc_library( ], ) +cc_library( + name = "gather_expander", + srcs = ["gather_expander.cc"], + hdrs = ["gather_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1175,9 +1241,9 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_pass", ":hlo_query", - ":shape_inference", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1211,6 +1277,53 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_expander_test", + srcs = ["gather_expander_test.cc"], + deps = [ + ":gather_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:test_macros_header", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "conditional_simplifier", + srcs = ["conditional_simplifier.cc"], + hdrs = ["conditional_simplifier.h"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "conditional_simplifier_test", + srcs = ["conditional_simplifier_test.cc"], + deps = [ + ":conditional_simplifier", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1233,6 +1346,7 @@ tf_cc_test( ":while_loop_simplifier", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:lib", "//tensorflow/core:test", ], ) @@ -2347,6 +2461,24 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + ], +) + +tf_cc_test( + name = "hlo_proto_util_test", + srcs = ["hlo_proto_util_test.cc"], + deps = [ + ":hlo", + ":hlo_proto", + ":hlo_proto_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -2445,7 +2577,9 @@ cc_library( deps = [ ":call_inliner", ":hlo", + ":hlo_creation_utils", ":tuple_util", + "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4391462c1cbb3b4333d2a129efbd53e9faf02db7..f9fabd8a35bcee2253b30fc5ad9e5fee545f06eb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -26,10 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -122,6 +122,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBitcastConvert(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleConcatenate(HloInstruction* concatenate) override; @@ -300,7 +302,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplication on platforms where it causes a slowdown. + // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; }; @@ -381,13 +383,9 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { auto* c1 = lhs->mutable_operand(1); auto* c2 = rhs; - TF_ASSIGN_OR_RETURN( - Shape sum_of_constants_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2)); - auto* sum_of_constants = - computation_->AddInstruction(HloInstruction::CreateBinary( - sum_of_constants_shape, HloOpcode::kAdd, c1, c2)); + TF_ASSIGN_OR_RETURN(auto* sum_of_constants, + MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), @@ -411,6 +409,13 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleBitcastConvert( + HloInstruction* bitcast) { + // Eliminate bitcast converts between same shape. + ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. if (copy->operand(0)->opcode() == HloOpcode::kCopy) { @@ -631,32 +636,23 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (lhs->opcode() == HloOpcode::kDivide && rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - const Shape a_times_d_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, - lhs->operand(0), rhs->operand(1))); - auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( - a_times_d_shape, HloOpcode::kMultiply, lhs->mutable_operand(0), - rhs->mutable_operand(1))); - TF_ASSIGN_OR_RETURN( - const Shape b_times_c_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, - lhs->operand(1), rhs->operand(0))); - auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), - rhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c)); + TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply, + lhs->mutable_operand(0), + rhs->mutable_operand(1))); + TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, + lhs->mutable_operand(1), + rhs->mutable_operand(0))); + TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, + a_times_d, b_times_c)); + + return ReplaceInstruction(divide, new_divide); } // (A / B) / C => A / (B * C) if (lhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(const Shape b_times_c_shape, - ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs->operand(1), rhs)); - auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + TF_ASSIGN_OR_RETURN( + auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, @@ -665,11 +661,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // A / (B / C) => (A*C) / B if (rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(const Shape a_times_c_shape, - ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs, rhs->operand(1))); - auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - a_times_c_shape, HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs, + rhs->mutable_operand(1))); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, @@ -1128,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { auto operand = broadcast->mutable_operand(0); + auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. - if (std::is_sorted(broadcast->dimensions().begin(), - broadcast->dimensions().end()) && + if (std::is_sorted(dims.begin(), dims.end()) && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " @@ -1149,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, - broadcast->dimensions())); + broadcast, + HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -1164,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { if (merely_inserts_or_deletes_1_sized_dimensions && deleted_indices.empty()) { std::reverse(inserted_indices.begin(), inserted_indices.end()); - auto dims = broadcast->dimensions(); for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } @@ -1208,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return user->ReplaceAllUsesWith(new_broadcast); } } + return Status::OK(); + } + + // Merge two consecutive broadcasts into a single one. + if (operand->opcode() == HloOpcode::kBroadcast) { + std::vector new_dimensions; + for (auto dim : operand->dimensions()) { + new_dimensions.push_back(dims[dim]); + } + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast( + broadcast->shape(), operand->mutable_operand(0), new_dimensions)); } return Status::OK(); } @@ -1302,17 +1307,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { padding_dimension->set_edge_padding_high(0); } } - TF_ASSIGN_OR_RETURN(Shape nonzero_pad_shape, - ShapeInference::InferPadShape(pad->operand(0)->shape(), - pad->operand(1)->shape(), - nonzero_padding)); + + TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, + MakePadHlo(pad->mutable_operand(0), + pad->mutable_operand(1), nonzero_padding)); // Copy the layout from the original pad instructions. The new pad and the // slice instruction should all have the same layout. - TF_RETURN_IF_ERROR( - LayoutUtil::CopyLayoutBetweenShapes(pad->shape(), &nonzero_pad_shape)); - HloInstruction* nonzero_pad = computation_->AddInstruction( - HloInstruction::CreatePad(nonzero_pad_shape, pad->mutable_operand(0), - pad->mutable_operand(1), nonzero_padding)); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + pad->shape(), nonzero_pad->mutable_shape())); // Second, construct the slice instruction to perform the negative padding. std::vector start_indices; @@ -1325,7 +1327,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (padding_dimension.edge_padding_low() < 0) { start = -1 * padding_dimension.edge_padding_low(); } - int64 end = nonzero_pad_shape.dimensions(i); + int64 end = nonzero_pad->shape().dimensions(i); if (padding_dimension.edge_padding_high() < 0) { end += padding_dimension.edge_padding_high(); } @@ -1334,16 +1336,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { strides.push_back(1); } - // Verify that the slice shape matches the pad shape. TF_ASSIGN_OR_RETURN( - Shape inferred_slice_shape, - ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices, - end_indices, strides)); - TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); + HloInstruction * slice, + MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); + + // Verify that the slice shape matches the pad shape. + TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape())); - std::unique_ptr slice = HloInstruction::CreateSlice( - pad->shape(), nonzero_pad, start_indices, end_indices, strides); - return ReplaceWithNewInstruction(pad, std::move(slice)); + return ReplaceInstruction(pad, slice); } return Status::OK(); @@ -1616,6 +1616,14 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } + + // If any dimension of update is 0, elide the DynamicUpdateSlice. This + // optimization becomes invalid should we later prefer to warn about out of + // bound indices. + if (ShapeUtil::HasZeroElements(update->shape())) { + return ReplaceInstruction(dynamic_update_slice, + dynamic_update_slice->mutable_operand(0)); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 43315f5cdc7afbe79039420320f4a0d0535e11f1..c48196e861a559a5abfa360841ec70b39356fa2b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs AlgebraicSimplications. +// A pass which performs algebraic simplifications. class AlgebraicSimplifier : public HloPassInterface { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to @@ -57,10 +57,10 @@ class AlgebraicSimplifier : public HloPassInterface { bool is_layout_sensitive_; ValidBitcastCallback valid_bitcast_callback_; - // Enable dot simplication on platforms where it is profitable. + // Enable dot simplification on platforms where it is profitable. bool enable_dot_strength_reduction_; - // Enable convolution simplication on platforms where it is profitable. + // Enable convolution simplification on platforms where it is profitable. bool enable_conv_simplification_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 667ae01993ebf0feeab89e0b5afaf7c7c8c99ab9..3b80a827bf0b5f1041c7351be0943bf1ad8c8afe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -35,6 +35,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" +using ::testing::ElementsAre; + namespace xla { namespace { @@ -2462,6 +2464,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* input_array = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, input_array, {1})); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root->dimensions(), ElementsAre(2)); +} + +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + // The initial dimensions go to places 0 and 2 in the 3-dim array, + // and to places 1 and 3 in the 4-dim array, + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, param0, {0, 2})); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -2800,6 +2851,29 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { {/*m=*/1, /*k=*/16, /*n=*/1}, // }; +// Test that DynamicUpdateSlice update param with any dimension equal to zero +// gets removed. +TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + HloComputation::Builder builder(TestName()); + const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); + HloInstruction* const operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, dslice_shape, "operand")); + const Shape update_shape = ShapeUtil::MakeShape(F32, {0}); + HloInstruction* const update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + HloInstruction* const start_indices = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0}))); + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + dslice_shape, operand, update, start_indices)); + const HloComputation* const computation = + module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), operand); +} + INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 4e80679c11dfdf7fdf8077a9f354139a4cab6803..4f819a743c48f30df8dde00ece72a0b4e1748802 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -34,40 +34,54 @@ StatusOr AllocationTracker::Register( std::unique_ptr shaped_buffer, const string& tag) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(std::move(shaped_buffer), tag); + std::vector> replicated_buffers; + replicated_buffers.emplace_back(std::move(shaped_buffer)); + return RegisterInternal(std::move(replicated_buffers), tag); +} + +StatusOr AllocationTracker::RegisterReplicatedBuffers( + std::vector> replicated_buffers, + const string& tag) { + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "RegisterReplicatedBuffers"; + return RegisterInternal(std::move(replicated_buffers), tag); } StatusOr AllocationTracker::RegisterInternal( - std::unique_ptr shaped_buffer, const string& tag) { + std::vector> replicated_buffers, + const string& tag) { VLOG(2) << "RegisterInternal(" - << "tag: \"" << tag << "\" " - << "shaped_buffer: " << *shaped_buffer; - if (shaped_buffer->platform() != backend_->platform()) { - return InvalidArgument( - "AllocationTracker for platform %s cannot register buffer from " - "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer->platform()->Name().c_str()); + << "tag: \"" << tag << "\" with " << replicated_buffers.size() + << " shaped_buffers."; + for (const auto& shaped_buffer : replicated_buffers) { + VLOG(2) << "shaped_buffer:" << *shaped_buffer; + if (shaped_buffer->platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer->platform()->Name().c_str()); + } } int64 handle = next_handle_++; - std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); - for (const ShapeIndex& index : shape_indices) { - AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal()); + for (auto& shaped_buffer : replicated_buffers) { + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal()); + } + handle_to_shaped_buffers_[handle].emplace_back(std::move(shaped_buffer)); } + GlobalDataHandle result; result.set_handle(handle); - - handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); - VLOG(2) << "handle: " << handle; - return result; } @@ -75,23 +89,35 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; - TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); - std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); - for (const ShapeIndex& index : shape_indices) { - TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + for (const auto& shaped_buffer : replicated_buffers) { + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); + } } + return Reset(data); +} - // Keep a nullptr as a tombstone for unregistered handles. This enables better - // error messages. That is, "handle has been deallocated" versus "handle does - // not exist". - handle_to_shaped_buffer_.at(data.handle()).reset(); - +Status AllocationTracker::Reset(const GlobalDataHandle& data) { + // Keep a nullptr as a tombstone for unregistered handles. This enables + // better error messages. That is, "handle has been deallocated" versus + // "handle does not exist". + auto it = handle_to_shaped_buffers_.find(data.handle()); + if (it == handle_to_shaped_buffers_.end()) { + return NotFound("no allocation record for global data handle: %lld", + data.handle()); + } + for (auto& shaped_buffer : it->second) { + shaped_buffer.reset(); + } return tensorflow::Status::OK(); } @@ -99,7 +125,11 @@ StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); - TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + // We only need to care about replica id 0 here, since the GlobalDataHandle is + // the same for all buffers across replicas. + const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); @@ -109,7 +139,7 @@ StatusOr> AllocationTracker::DeconstructTuple( TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { - return Unimplemented("deconstructing nested tuples not yet supported"); + return Unimplemented("Deconstructing nested tuples is not implemented."); } std::vector element_handles; @@ -122,37 +152,55 @@ StatusOr> AllocationTracker::DeconstructTuple( shaped_buffer->platform(), shaped_buffer->device_ordinal()); element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), /*index=*/{}); + std::vector> replicated_buffers; + replicated_buffers.emplace_back(std::move(element_buffer)); TF_ASSIGN_OR_RETURN( GlobalDataHandle element_handle, - RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + RegisterInternal(std::move(replicated_buffers), "deconstructed tuple")); element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr> AllocationTracker::Resolve( const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveForReplica( + const GlobalDataHandle& data, int replica_id) { + tensorflow::mutex_lock lock(mutex_); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + if (replica_id >= replicated_buffers.size()) { + return InvalidArgument( + "Requesting buffer for replica %d, but found buffers only for %lu " + "replicas.", + replica_id, replicated_buffers.size()); + } + return replicated_buffers[replica_id]; +} + +StatusOr> AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) << "resolve:" << data.handle(); - auto it = handle_to_shaped_buffer_.find(data.handle()); - if (it == handle_to_shaped_buffer_.end()) { + auto it = handle_to_shaped_buffers_.find(data.handle()); + if (it == handle_to_shaped_buffers_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - ShapedBuffer* shaped_buffer = it->second.get(); - - if (shaped_buffer == nullptr) { - return InvalidArgument("global data handle %lld was previously deallocated", - data.handle()); + std::vector replicated_buffers; + for (const auto& shaped_buffer : it->second) { + if (shaped_buffer == nullptr) { + return InvalidArgument( + "global data handle %lld was previously deallocated", data.handle()); + } + replicated_buffers.push_back(shaped_buffer.get()); } - return shaped_buffer; + return replicated_buffers; } void AllocationTracker::AddAllocationOrIncrementRefCount( diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 807af8694972083d097604a67ee46d2f73d9545a..038aee8541b297d6f91fe2b3bce7455fd9a7084e 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -43,10 +43,17 @@ class AllocationTracker { AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} // Registers a shaped buffer of device memory, and returns a corresponding - // handle that can be used for talking to XLA clients. + // handle that can be used for talking to XLA clients. The given shaped buffer + // will be treated as the buffer corresponding to the only replica. StatusOr Register( std::unique_ptr shaped_buffer, const string& tag); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr RegisterReplicatedBuffers( + std::vector> replicated_buffers, + const string& tag); + // Unregister the allocation for the given data handle. Status Unregister(const GlobalDataHandle& data); @@ -54,9 +61,17 @@ class AllocationTracker { StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to a shaped buffer, or provide an error - // status to say whether it was not found (or found, but found deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a vector of shaped buffers, one per + // replica, or provide an error status to say whether any of those buffers + // were not found (or found, but found deallocated). + StatusOr> Resolve( + const GlobalDataHandle& data); + + // Resolves a handle from an XLA client and replica id to a shaped buffer, or + // provide an error status to say whether it was not found (or found, but + // found deallocated). + StatusOr ResolveForReplica(const GlobalDataHandle& data, + int replica_id); private: // Data structure encapsulating single memory allocation on the device. @@ -74,13 +89,17 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // ShapedBuffer. - StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + StatusOr> ResolveInternal( + const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Internal helper which registers a shaped buffer. + // Internal helper which registers a vector of shaped buffers, one per + // replica. StatusOr RegisterInternal( - std::unique_ptr shaped_buffer, const string& tag) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + std::vector> replicated_buffers, + const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Resets the shaped buffers corresponding to the given handle. + Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Adds the given device address to the allocation tracker, or if it already // exists, then increment it's reference count. @@ -111,9 +130,10 @@ class AllocationTracker { tensorflow::gtl::FlatMap opaque_to_allocation_map_ GUARDED_BY(mutex_); - // A map from data handle to ShapedBuffer. - tensorflow::gtl::FlatMap> - handle_to_shaped_buffer_ GUARDED_BY(mutex_); + // A map from data handle to a vector of shaped buffers that represent the + // buffers for different replicas. + tensorflow::gtl::FlatMap>> + handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 84c9db32932becd9b701929b392efa4998d03067..38086bd7e121847be6b6b69415cfe87814e7fc24 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index cde990e176ddb57a8e93ecc3c60260b2dbae32a8..08d0152e3cfcfcb7ae1e85f72c2f7dc856f5e8b3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,6 +34,9 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); @@ -84,6 +87,25 @@ Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( return Status::OK(); } +namespace { + +// Returns whether hlo has users and all users are conversions from F32 to BF16. +bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) { + if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) { + return false; + } + for (const auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + return false; + } + return true; +} + +} // namespace + Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( HloInstruction* hlo) { std::vector bf16_to_f32_operands; @@ -104,22 +126,9 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( } } - bool fold_output_conversion = hlo->user_count() > 0 && - hlo->shape().element_type() == F32 && - bfloat16_support_->SupportsBF16Output(*hlo) && - hlo != computation_->root_instruction(); - if (fold_output_conversion) { - for (auto user : hlo->users()) { - if (user->opcode() == HloOpcode::kConvert && - user->shape().element_type() == BF16) { - continue; - } - // We should not change the output type if any user is not a conversion - // from F32 to BF16. - fold_output_conversion = false; - break; - } - } + const bool fold_output_conversion = + AllUsersAreF32ToBF16Converts(hlo) && + bfloat16_support_->SupportsBF16Output(*hlo); if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { if (has_other_f32_operands || @@ -147,6 +156,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kGetTupleElement || // hlo->opcode() == HloOpcode::kInfeed || // hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kSend || // + hlo->opcode() == HloOpcode::kSendDone || // + hlo->opcode() == HloOpcode::kRecv || // + hlo->opcode() == HloOpcode::kRecvDone || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -167,6 +180,52 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } +Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape()) || + !bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return DefaultAction(crs); + } + + // First use DefaultAction() to handle the operands. It can't handle + // tuple-shaped output. + TF_RETURN_IF_ERROR(DefaultAction(crs)); + + // Then do per-tuple-element handling on the output. + std::vector> per_tuple_element_gtes( + crs->operand_count()); + for (auto user : crs->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return Status::OK(); + } + per_tuple_element_gtes[user->tuple_index()].push_back(user); + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + // Fold conversions only when all the get-tuple-elements' users are + // conversions from F32 to BF16. + auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() { + for (auto gte : per_tuple_element_gtes[i]) { + if (!AllUsersAreF32ToBF16Converts(gte)) { + return false; + } + } + return true; + }; + if (!all_gte_users_are_bf16_convert()) { + continue; + } + + ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) + ->set_element_type(BF16); + for (auto gte : per_tuple_element_gtes[i]) { + TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); + } + } + + return Status::OK(); +} + StatusOr BFloat16ConversionFolding::Run(HloModule* module) { XLA_VLOG_LINES( 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index cb37759439debf41a305ec7dccaa548e1bf234cd..28e71c2054f59ba4d5d096bf7d898161877bb42f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -37,7 +37,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -47,7 +48,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -55,7 +57,8 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -206,4 +209,46 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } +TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* convert_a = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, a)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + HloInstruction* gte_a = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); + HloInstruction* gte_b = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 1)); + HloInstruction* convert_gte_b = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte_b)); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({gte_a, convert_gte_b})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_EQ(tuple->operand(0), gte_a); + EXPECT_EQ(tuple->operand(1), gte_b); + EXPECT_EQ(gte_a->shape().element_type(), F32); + EXPECT_EQ(gte_b->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0), a); + EXPECT_EQ(crs->operand(1), b); + EXPECT_EQ(a->shape().element_type(), BF16); + EXPECT_EQ(b->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {0}).element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index b032c040e8aff49f9e0fc1ff9a1c1e79ea4bb77f..14c54ddd135af024327f63418b410da1ed3c4fd4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -152,44 +152,64 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( std::vector operand_types(crs->operand_count()); std::vector output_types(crs->operand_count()); - bool has_f32 = false; - bool has_bf16 = false; - bool has_bf16_output = false; + int64 f32_count = 0; + int64 bf16_count = 0; + bool has_unsupported_bf16_operand = false; + bool has_unsupported_bf16_output = false; for (int64 i = 0; i < crs->operand_count(); ++i) { operand_types[i] = crs->operand(i)->shape().element_type(); output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); - if (operand_types[i] == F32 || output_types[i] == F32) { - has_f32 = true; + if (operand_types[i] == F32) { + f32_count += 1; } else if (operand_types[i] == BF16) { - has_bf16 = true; + bf16_count += 1; + if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + has_unsupported_bf16_operand = true; + } } - if (output_types[i] == BF16) { - has_bf16 = true; - has_bf16_output = true; + if (output_types[i] == F32) { + f32_count += 1; + } else if (output_types[i] == BF16) { + bf16_count += 1; + if (!bfloat16_support_->SupportsBF16Output(*crs)) { + has_unsupported_bf16_output = true; + } } } - for (int64 i = 0; i < crs->operand_count(); ++i) { + if (bf16_count == 0) { + return Status::OK(); + } + + auto should_convert_operand = [&](int64 i) { if (operand_types[i] != BF16) { - continue; + return false; } - if (bfloat16_support_->SupportsBF16Operand(*crs, i) && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { - continue; + if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + return true; } - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); - has_f32 = true; - } + if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return false; + } + return has_unsupported_bf16_operand || has_unsupported_bf16_output || + f32_count > 0; + }; - if (!has_bf16_output) { - return Status::OK(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (should_convert_operand(i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + f32_count += 1; + bf16_count -= 1; + } } - if (bfloat16_support_->SupportsBF16Output(*crs) && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + if (!has_unsupported_bf16_output && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + bf16_count == 0)) { return Status::OK(); } + std::vector materialized_users = crs->users(); std::vector output_elements(crs->operand_count()); auto original_shape = crs->shape(); for (int64 i = 0; i < crs->operand_count(); ++i) { @@ -209,7 +229,6 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - std::vector materialized_users = crs->users(); // Use the crs' shape temporarily, in order to pass checks in // ReplaceUseWith. *tuple->mutable_shape() = crs->shape(); @@ -221,41 +240,37 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( } Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { - std::vector bf16_operands; - std::vector f32_operands; - bool has_f32 = false; - bool has_bf16 = false; + int f32_count = 0; + int bf16_count = 1; for (int64 i = 0; i < hlo->operand_count(); ++i) { if (hlo->operand(i)->shape().element_type() == F32) { - f32_operands.push_back(i); - has_f32 = true; + f32_count += 1; } else if (hlo->operand(i)->shape().element_type() == BF16) { - bf16_operands.push_back(i); - has_bf16 = true; + bf16_count += 1; } } if (hlo->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (hlo->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; } std::vector bf16_called_comps; for (auto* comp : hlo->called_computations()) { bool comp_has_bf16 = false; if (comp->root_instruction()->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (comp->root_instruction()->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; comp_has_bf16 = true; } for (auto* param : comp->parameter_instructions()) { if (param->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (param->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; comp_has_bf16 = true; } } @@ -264,54 +279,69 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { } } - if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && - has_f32) { - // Resolve unsupported mixed precision. - // - // See if we can change everything to BF16. - if (hlo->called_computations().empty() && - hlo->shape().element_type() == BF16) { - bool can_use_bf16 = true; - for (int i : f32_operands) { - if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, - i) && - bfloat16_support_->SupportsBF16Operand(*hlo, i)) { - continue; - } - can_use_bf16 = false; - break; - } - if (can_use_bf16) { - for (int i : f32_operands) { - TF_RETURN_IF_ERROR( - InsertConvertBeforeOperand(hlo, i, BF16, computation_)); - } - return Status::OK(); - } - } - if (hlo->shape().element_type() == BF16) { - TF_RETURN_IF_ERROR( - ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); - } - for (int i : bf16_operands) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); - } - return ConvertCalledComputations(hlo, bf16_called_comps); - } - - for (int i : bf16_operands) { - if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Resolve unsupported BF16 operands. + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Operand(*hlo, i)) { TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + bf16_count -= 1; + f32_count += 1; } } + // Resolve unsupported BF16 output. if (hlo->shape().element_type() == BF16 && !bfloat16_support_->SupportsBF16Output(*hlo)) { TF_RETURN_IF_ERROR( ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + bf16_count -= 1; + f32_count += 1; } - return Status::OK(); + // Resolve unsupported mixed precision after resolving unsupported BF16 + // operands and output, because the numbers of BF16 operands/output and F32 + // operands/output may have changed. + if (bfloat16_support_->SupportsMixedPrecisions(*hlo) || bf16_count == 0 || + f32_count == 0) { + return Status::OK(); + } + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16) { + continue; + } + if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) || + bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i)) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + return ConvertCalledComputations(hlo, bf16_called_comps); } Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 66c3085842c4afe7ffc4d5891883e4cce9389d45..1afaefd9df9c5771fb9e134ae9050f3abb00ea4a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -41,13 +42,17 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kGetTupleElement) { return true; } + if (hlo.opcode() == HloOpcode::kDot) { + // Test that only the first operand of kDot supports BF16. + return operand_index == 0; + } return false; } bool SupportsBF16Output(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kSubtract || - hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { return true; } @@ -70,6 +75,10 @@ class BFloat16NormalizationTest : public HloTestBase { BFloat16Normalization normalization(&bfloat16_support_); StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); + + HloVerifier verifier(/*allow_mixed_precision=*/true); + EXPECT_IS_OK(verifier.Run(module).status()); + return result.ValueOrDie(); } }; @@ -166,7 +175,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); - Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {}); auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( @@ -245,4 +254,34 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +// Tests that the normalization should not cause unsupported mixed precision due +// to resolving unsupported BF16 operand. +TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(dot->shape().element_type(), F32); + EXPECT_EQ(dot->operand(0)->shape().element_type(), F32); + EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(dot->operand(1)->shape().element_type(), F32); + EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConvert); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 9246cb25d28b53008af092197a132db91236d61c..c26d2feef584faeff013a602409cdd58c2d44a5a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,10 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -66,33 +69,53 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); } + computations_visited_in_mutation_pass_.insert( + fusion->fused_instructions_computation()); } -void BFloat16Propagation::AdjustFusionParameters(HloInstruction* fusion) { - CHECK_EQ(fusion->fused_parameters().size(), fusion->operand_count()); - for (int64 i = 0; i < fusion->operand_count(); ++i) { - auto parameter = fusion->fused_parameter(i); - ShapeUtil::ForEachMutableSubshape( - parameter->mutable_shape(), - [&](Shape* subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { - return; - } - PrimitiveType operand_type = - ShapeUtil::GetSubshape(fusion->operand(i)->shape(), index) - .element_type(); - if (subshape->element_type() == operand_type) { - return; - } - CHECK(operand_type == F32 || operand_type == BF16); - subshape->set_element_type(operand_type); +void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( + HloInstruction* while_hlo) { + CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile); + + // We are depending on the while node itself having already been analyzed for + // whether it can output BF16 and this has been adjusted in the output shape, + // and now we're looking to update the body and condition computations to + // match the new output shape, as well as recursively process the whole while + // node even if the output shape was not modified. + HloComputation* body = while_hlo->while_body(); + auto body_root = body->root_instruction(); + HloComputation* condition = while_hlo->while_condition(); + + ShapeUtil::ForEachMutableSubshape( + body_root->mutable_shape(), + [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() != F32) { + return; + } + if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() == + BF16) { + subshape->set_element_type(BF16); changed_ = true; - VLOG(2) << "Fused parameter " << parameter->ToString() + VLOG(2) << "While body root " << body_root->ToString() << " at shape index " << index - << " adjusted to match operand in fusion " - << fusion->ToString(); - }); + << " changed to BF16 precision for while " + << while_hlo->ToString(); + } + }); + + auto body_insts = body->MakeInstructionPostOrder(); + for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend(); + ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); } + computations_visited_in_mutation_pass_.insert(body); + + auto condition_insts = condition->MakeInstructionPostOrder(); + for (auto inst_it = condition_insts.rbegin(); + inst_it != condition_insts.rend(); ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_mutation_pass_.insert(condition); } bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, @@ -106,14 +129,45 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, continue; } for (const HloUse& use : value->uses()) { + if (!ContainsKey(instructions_visited_in_mutation_pass_, + use.instruction)) { + // We don't know yet whether use.instruction will consume BF16 since it + // hasn't been visited. Although we visit instructions in reverse + // topological order, this is still possible because there may be + // unvisited instruction that alias the same buffer. In this case, we + // aggressively skip this use, and if this causes inconsistency (e.g., + // one use is in BF16 but another use is in F32), it will be resolved at + // the end of the BFloat16Propagation pass. + continue; + } + // Any visited user that can accept BF16 has already been updated if + // necessary, e.g., the output has been changed to BF16 if it propagates + // precision, or a called computation's parameters have been changed to + // BF16 for fusions or whiles. if (use.instruction->opcode() == HloOpcode::kFusion) { - auto fused_parameter = + const auto* fused_parameter = use.instruction->fused_parameter(use.operand_number); if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index) .element_type() != BF16) { return false; } continue; + } else if (use.instruction->opcode() == HloOpcode::kWhile) { + const auto* cond_parameter = + use.instruction->while_condition()->parameter_instruction( + use.operand_number); + if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index) + .element_type() != BF16) { + return false; + } + const auto* body_parameter = + use.instruction->while_body()->parameter_instruction( + use.operand_number); + if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index) + .element_type() != BF16) { + return false; + } + continue; } if (bfloat16_support_->EffectiveOperandPrecisionIsBF16( *use.instruction, use.operand_number)) { @@ -147,24 +201,40 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, void BFloat16Propagation::DetermineAndMutateInstructionPrecision( HloInstruction* hlo, bool skip_parameters) { - // We handle any fusion computation after the instruction is handled, because - // we need to know a fusion's output shape before propagating inside its fused - // computation. - auto cleaner = tensorflow::gtl::MakeCleanup([this, hlo] { - if (hlo->opcode() == HloOpcode::kFusion) { - DetermineAndMutateFusionComputationPrecision(hlo); - } - }); + // We handle any fusion computation or while body/condition after the + // instruction is handled, because we need to know the output shape of a + // fusion or while before propagating inside its computations. + bool postpone_processing_called_computations = false; + auto cleaner = tensorflow::gtl::MakeCleanup( + [this, hlo, &postpone_processing_called_computations] { + if (!postpone_processing_called_computations) { + if (hlo->opcode() == HloOpcode::kFusion) { + DetermineAndMutateFusionComputationPrecision(hlo); + } else if (hlo->opcode() == HloOpcode::kWhile) { + DetermineAndMutateWhileComputationsPrecision(hlo); + } + } + instructions_visited_in_mutation_pass_.insert(hlo); + }); + + if (hlo->opcode() == HloOpcode::kWhile && + (caller_counts_[hlo->while_condition()] > 1 || + caller_counts_[hlo->while_body()] > 1)) { + postpone_processing_called_computations = true; + return; + } // Do not change precision for instructions related to entry and exit of a // computation, and control flow, because this pass might break the interfaces // or assumptions for them. if (hlo->opcode() == HloOpcode::kInfeed || // hlo->opcode() == HloOpcode::kOutfeed || // - hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kSend || // + hlo->opcode() == HloOpcode::kSendDone || // + hlo->opcode() == HloOpcode::kRecv || // + hlo->opcode() == HloOpcode::kRecvDone || // hlo->opcode() == HloOpcode::kCustomCall || // hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kWhile || // hlo->opcode() == HloOpcode::kConditional || // (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { return; @@ -229,6 +299,357 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output( return true; } +void BFloat16Propagation::AdjustCalledComputationParameters( + HloInstruction* hlo) { + auto adjust_computation = + [this, hlo](HloComputation* computation, + tensorflow::gtl::ArraySlice operands) { + // Adjust parameters. + CHECK_EQ(operands.size(), computation->num_parameters()); + for (int64 i = 0; i < operands.size(); ++i) { + auto parameter = computation->parameter_instruction(i); + ShapeUtil::ForEachMutableSubshape( + parameter->mutable_shape(), + [this, i, hlo, &operands, parameter](Shape* subshape, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { + return; + } + PrimitiveType operand_type = + ShapeUtil::GetSubshape(operands[i]->shape(), index) + .element_type(); + if (subshape->element_type() == operand_type) { + return; + } + CHECK(operand_type == F32 || operand_type == BF16); + subshape->set_element_type(operand_type); + changed_ = true; + VLOG(2) << "Called computation parameter " + << parameter->ToString() << " at shape index " << index + << " adjusted to match operand in HLO " + << hlo->ToString(); + }); + } + }; + + switch (hlo->opcode()) { + case HloOpcode::kFusion: + adjust_computation(hlo->fused_instructions_computation(), + hlo->operands()); + break; + case HloOpcode::kWhile: + adjust_computation(hlo->while_condition(), hlo->operands()); + adjust_computation(hlo->while_body(), hlo->operands()); + break; + default: + break; + } +} + +void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { + auto adjust_computation = [this, hlo](HloComputation* computation, + const Shape& output_shape) { + // Adjust root. + HloInstruction* root = computation->root_instruction(); + ShapeUtil::ForEachMutableSubshape( + root->mutable_shape(), [this, hlo, root, &output_shape]( + Shape* subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { + return; + } + const PrimitiveType output_type = + ShapeUtil::GetSubshape(output_shape, index).element_type(); + if (subshape->element_type() == output_type) { + return; + } + CHECK(output_type == F32 || output_type == BF16); + subshape->set_element_type(output_type); + // It's possible that output_type is F32, but the root instruction's + // type is BF16; e.g., a fusion node's output was changed to BF16 + // initially but then adjusted back to F32, and the fusion computation + // is now being adjusted after the fusion node. + if (output_type == F32) { + for (const auto* value : + dataflow_->GetValueSet(root, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order so that called computation will be + // processed later. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the + // correctness of the adjustment for HLOs that will be + // processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + changed_ = true; + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index + << " adjusted to match output shape of " << hlo->ToString(); + }); + }; + + switch (hlo->opcode()) { + case HloOpcode::kFusion: + adjust_computation(hlo->fused_instructions_computation(), hlo->shape()); + break; + case HloOpcode::kWhile: + adjust_computation(hlo->while_condition(), hlo->shape()); + adjust_computation(hlo->while_body(), hlo->shape()); + break; + default: + break; + } +} + +bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited_computations) { + bool parameter_changed = false; + auto insts = computation->MakeInstructionPostOrder(); + // Do the adjustment on each instruction in the computation in reverse + // topological order. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + auto adjust_hlo_output = [this, hlo, ¶meter_changed]( + Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() != F32 && subshape->element_type() != BF16) { + return; + } + PrimitiveType type = BF16; + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + if (value->shape().element_type() == BF16) { + continue; + } + CHECK_EQ(value->shape().element_type(), F32); + type = F32; + break; + } + // It's possible that a user has been changed from BF16 to F32 + // during this final adjustment pass, so we need to check + // AllUsersConsumeBF16() again. + if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { + type = F32; + } + if (type == F32) { + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the correctness + // of the adjustment for HLOs that will be processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + if (type != subshape->element_type()) { + subshape->set_element_type(type); + VLOG(2) << "HloInstruction output at shape index " << index + << " adjusted to " << *subshape << ": " << hlo->ToString(); + if (hlo->opcode() == HloOpcode::kParameter) { + parameter_changed = true; + } + } + }; + ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + if (hlo->opcode() == HloOpcode::kWhile) { + // We need to run on the while body and condition repeatedly until a fixed + // point is reached, i.e., the parameters do not change any more. We may + // need more than one iteration because the while input and output alias + // each other, so changing one input parameter requires changing the + // corresponding output element and thus may transitively require changing + // another input parameter. A fixed point will be reached because the + // parameters can only be changed from BF16 to F32, not the other way + // around. + tensorflow::gtl::FlatSet visited_in_while; + while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), + &visited_in_while) || + ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), + &visited_in_while)) { + visited_in_while.clear(); + ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), + adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + } + visited_computations->insert(visited_in_while.begin(), + visited_in_while.end()); + } + } + // Now adjust parameters of called computations. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + AdjustCalledComputationParameters(*inst_it); + } + return parameter_changed; +} + +Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( + HloModule* module) { + std::list computations_topological_order = + module->MakeComputationPostOrder(); + tensorflow::gtl::FlatSet resolved; + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + if (ContainsKey(resolved, *comp_it)) { + continue; + } + ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved); + } + + // We could have changed a fusion computation's root shape to have a different + // precision than the fusion node's output, if the fusion root does not + // define a buffer (e.g., a tuple). Now we add conversions after such fusion + // roots to make them match the fusion output. If the fusion output is a + // (possibly nested) tuple, we first create get-tuple-elements, then convert + // the unmatching leaf nodes, and finally create a new tuple as the fusion + // computation's root. If tuples and get-tuple-elements are created, we will + // run tuple simplifier and dead code elimination at the end (dead code is not + // allowed in fusion computation). E.g., + // + // (1) (2) (3) + // a b a b a b + // |\ | |\ | |\ | + // \ add -> |add -> | add + // \ | \ | convert | + // tuple tuple \ | + // / \ tuple + // gte gte + // | | + // convert | + // \ / + // tuple + // (1) a is F32 but tuple is BF16 + // (2) after adding conversion + // (3) after tuple simplifier and DCE. + bool needs_tuple_simplifier = false; + for (auto computation : computations_topological_order) { + auto insts = computation->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + if (hlo->opcode() != HloOpcode::kFusion) { + continue; + } + auto fusion_computation = hlo->fused_instructions_computation(); + auto fusion_root = fusion_computation->root_instruction(); + if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) { + continue; + } + ShapeTree converted_outputs(hlo->shape()); + // Iterate through nodes in the shape tree in pre-order and initialize + // each non-root node with a corresponding get-tuple-element. For a leaf + // node, if its shape does not match the fusion output, create a + // conversion node to overwrite the node value. + for (auto it = converted_outputs.begin(); it != converted_outputs.end(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (output_index.empty()) { + output = fusion_root; + } else { + ShapeIndex parent_index = output_index; + parent_index.pop_back(); + output = fusion_computation->AddInstruction( + HloInstruction::CreateGetTupleElement( + subshape, converted_outputs.element(parent_index), + output_index.back())); + } + if (ShapeUtil::IsTuple(subshape)) { + continue; + } + if (!ShapeUtil::Compatible( + subshape, + ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) { + output = fusion_computation->AddInstruction( + HloInstruction::CreateConvert(subshape, output)); + } + } + // Iterate through nodes in the shape tree in reverse pre-order and create + // a tuple instruction for each non-leaf node where the elements are the + // values of its child nodes. + for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape& subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (!ShapeUtil::IsTuple(subshape)) { + continue; + } + std::vector elements( + ShapeUtil::TupleElementCount(subshape)); + ShapeIndex child_index = output_index; + for (int64 i = 0; i < elements.size(); ++i) { + child_index.push_back(i); + elements[i] = converted_outputs.element(child_index); + child_index.pop_back(); + } + output = fusion_computation->AddInstruction( + HloInstruction::CreateTuple(elements)); + } + fusion_computation->set_root_instruction(converted_outputs.element({})); + needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); + } + } + + // We may have converted some constants from F32 to BF16, so adjust the + // constant literals in such cases. We do this here instead of when the + // constant node's is changed because 1) the HloInstruction interface does not + // allow resetting the literal so we have to create a new kConstant + // instruction to replace the old one, which invalidates dataflow analysis, + // and 2) it's possible that a kConstant's output gets changed to BF16 at the + // beginning but later on adjusted back to F32, so converting literals here + // can avoid repeated conversions. + // + // TODO(b/73833576): Consider resetting literal in HloInstruction. + bool needs_dce = needs_tuple_simplifier; + for (auto computation : computations_topological_order) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConstant) { + continue; + } + if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { + TF_ASSIGN_OR_RETURN( + auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape(), + /*round_f32_to_bf16=*/true)); + auto new_constant = computation->AddInstruction( + HloInstruction::CreateConstant(std::move(converted_literal))); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); + needs_dce = true; + } + } + } + + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + } + if (needs_dce) { + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { + for (auto computation : module->computations()) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConvert) { + continue; + } + auto source = hlo->mutable_operand(0); + if (!ShapeUtil::Equal(source->shape(), hlo->shape())) { + continue; + } + const bool is_root = hlo == computation->root_instruction(); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source)); + if (is_root) { + computation->set_root_instruction(source); + } + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo)); + } + } + return Status::OK(); +} + // The algorithm first does a forward pass (parameters to root) to determine a // set of instructions to consider using bfloat16, then does a backward pass to // determine the precisions of those instructions according to the need of @@ -278,56 +699,11 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // It's possible that an instruction does not define a buffer, but the // defining instruction's shape has changed. So we need to adjust the output // shapes of instructions according to the HLO values they refer to. - for (auto comp_it = computations_topological_order.rbegin(); - comp_it != computations_topological_order.rend(); ++comp_it) { - auto insts = (*comp_it)->MakeInstructionPostOrder(); - // Do the adjustment on each instruction in the computation in reverse - // topological order. - for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - auto hlo = *inst_it; - auto adjust_buffer = [this, hlo](Shape* subshape, - const ShapeIndex& index) { - if (subshape->element_type() != F32 && - subshape->element_type() != BF16) { - return; - } - PrimitiveType type = BF16; - for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { - if (value->shape().element_type() == BF16) { - continue; - } - CHECK_EQ(value->shape().element_type(), F32); - type = F32; - break; - } - // It's possible that a user has been changed from BF16 to F32 - // during this final adjustment pass, so we need to check - // AllUsersConsumeBF16() again. - if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { - type = F32; - } - if (type == F32) { - for (const auto* value : - dataflow_->GetValueSet(hlo, index).values()) { - // We rely on the fact that this adjustment works in reverse - // topological order. Adding the value to - // values_that_must_be_kept_as_f32_ will ensure the correctness - // of the adjustment for HLOs that will be processed later. - values_that_must_be_kept_as_f32_.insert(value); - } - } - subshape->set_element_type(type); - }; - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_buffer); - } - // Now adjust parameters of fusions inside this computation. - for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - auto hlo = *inst_it; - if (hlo->opcode() == HloOpcode::kFusion) { - AdjustFusionParameters(hlo); - } - } - } + TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + + // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> + // BF16), so we remove them now. + TF_RETURN_IF_ERROR(RemoveNoopConversions(module)); return true; } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index aa81dde3b0e6964576a056cd0f579a8ec9540d64..1744e9db90aeff269daa91eb68a1d61bb0fc3035 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -38,7 +38,8 @@ namespace xla { // be bitwise identical to that without this pass; this is possible if the // backend already reduces precision to BF16 on some HLO instructions. // -// This pass will not modify the signature of any non-fusion computation. +// This pass will not modify the signature of a computation, unless it is a +// fusion computation or its only caller is a while. // // !!! WARNING !!! This pass can introduce mixed precision in individual HLOs, // which has two issues: @@ -92,11 +93,53 @@ class BFloat16Propagation : public HloPassInterface { bool skip_parameters); // Special handling in the mutation pass for fusion computations. + // + // Precondition: hlo->opcode() == kFusion void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion); - // Makes the fusion parameters match the precision of the actual parameters - // passed to the fusion node. - void AdjustFusionParameters(HloInstruction* fusion); + // Special handling in the mutation pass for while computations. + // + // Precondition: hlo->opcode() == kWhile + void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo); + + // The set of HloInstructions that have been visited in the mutation pass. + tensorflow::gtl::FlatSet + instructions_visited_in_mutation_pass_; + + // The set of HloComputations that have been visited in the mutation pass. + tensorflow::gtl::FlatSet + computations_visited_in_mutation_pass_; + + // *************************** + // Functions called by the final inconsistency resolving pass. + + // Adjusts the output shapes of HloInstructions such that if two + // HloInstructions have aliasing buffers in their outputs, they must have the + // same precision. + Status ResolveInconsistencyOfAliasingBuffers(HloModule* module); + + // Resolves inconsistency of aliasing buffers for the given computation, and + // recursively runs on a while instruction's condition and body until a fixed + // point is reached. + bool ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited_computations); + + // Makes the parameters of called computations match how they are called by + // the given HLO. + void AdjustCalledComputationParameters(HloInstruction* hlo); + + // Makes the root instructions of called computations match how they are used + // by the given HLO. + void AdjustCalledComputationRoot(HloInstruction* hlo); + + // *************************** + // Removes no-op conversions (same source and target shapes) that can be + // produced this pass. + Status RemoveNoopConversions(HloModule* module); + + // *************************** + // Functions called and state used by two or more passes. // Returns whether all uses of the given HloInstruction can consume BF16 // input. @@ -106,8 +149,10 @@ class BFloat16Propagation : public HloPassInterface { // The set of F32 HLO values that must be kept in F32. tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; - // *************************** - // State used by both passes. + // Mapping from each HloComputation to the number of callers to it in the + // module. Populated at the beginning of this pass. + tensorflow::gtl::FlatMap caller_counts_; + const BFloat16Support* bfloat16_support_; std::unique_ptr dataflow_; diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 4c86c6b26e362c62804b627036e6ba7c078b402a..88f83014164ff726a11e45e762b9c082cf12720d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,7 +69,7 @@ class BFloat16PropagationTest : public HloTestBase { // Returns whether the given HloInstruction's output element type is BF16 or // the only use of it is converting to BF16. - bool OutputsBF16(HloInstruction* inst) { + bool OutputsBF16(const HloInstruction* inst) { if (inst->shape().element_type() == BF16) { return true; } @@ -121,6 +122,41 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { EXPECT_FALSE(OutputsBF16(c)); } +// Tests that if a constant is converted to BF16 then its literal must also be +// converted. +TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Array2D array_a(4, 4); + array_a.FillUnique(1.0f); + Array2D array_b(4, 4); + array_b.FillUnique(10.0f); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromArray(array_a))); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromArray(array_b))); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(dot->operand(0))); + EXPECT_TRUE(OutputsBF16(dot->operand(1))); + EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); + EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); + LiteralTestUtil::ExpectEqual( + dot->operand(0)->literal(), + *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); + LiteralTestUtil::ExpectEqual( + dot->operand(1)->literal(), + *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); +} + // Tests that BF16 can be propagated through nested tuples. TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { auto builder = HloComputation::Builder(TestName()); @@ -287,6 +323,64 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { EXPECT_TRUE(OutputsBF16(b_f1)); } +// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion +// outputs are only used by a dot, and 3) one element of the tuple is used by +// an add in the fusion computation, then the propagation pass should create a +// convert in the fusion computation to keep the add's operand in F32 but change +// the fusion output to BF16. E.g., the following fusion computation +// (F32, F32) fusion_computation(F32 a, F32 b) +// = tuple(F32 a, F32 add(F32 a, F32 b)) +// will be changed to +// (BF16, BF16) fusion_computation(F32 a, F32 b) +// = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) +TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion0"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* tuple_f = + builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f)); + + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(gte0)); + EXPECT_TRUE(OutputsBF16(gte1)); + EXPECT_FALSE(OutputsBF16(a_f)); + EXPECT_FALSE(OutputsBF16(b_f)); + EXPECT_TRUE(OutputsBF16(add_f)); + auto new_fusion_root = comp_f->root_instruction(); + EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(new_fusion_root->operand(1), add_f); + EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0))); +} + // A select over tuples does not define the leaf buffers, so the types in // on_true and on_false must match, so that as long as one of them is F32, the // other must be F32 as well. @@ -332,4 +426,235 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { EXPECT_TRUE(OutputsBF16(xpose)); } +// Tests that BF16 is propagated properly through while computations. +TEST_F(BFloat16PropagationTest, PropagateThroughWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, tuple->shape(), "cond_param")); + auto cond_lhs = builder_cond.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond_param, 0)); + auto cond_rhs = builder_cond.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond_param, 1)); + // This add should prevent RHS from using BF16 + auto cond_add_rhs = builder_cond.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); + auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, tuple->shape(), "body_param")); + auto body_lhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + auto body_rhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 1)); + auto body_dot = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + builder_body.AddInstruction( + HloInstruction::CreateTuple({body_dot, body_rhs})); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple)); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(lhs)); + EXPECT_FALSE(OutputsBF16(rhs)); + EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_lhs)); + EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_TRUE(OutputsBF16(cond_lhs)); + EXPECT_FALSE(OutputsBF16(cond_rhs)); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(add1)); +} + +// Tests that BF16 is not propagated through multiple whiles that invoke the +// same computation as long as one while prevents the propagation. +TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add3 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({add2, add3})); + + // Condition computation for the first while. + auto builder_cond0 = HloComputation::Builder("cond0"); + auto cond0_param = builder_cond0.AddInstruction( + HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param")); + auto cond0_lhs = builder_cond0.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond0_param, 0)); + auto cond0_rhs = builder_cond0.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond0_param, 1)); + // This add should prevent RHS from using BF16 + auto cond0_add_rhs = + builder_cond0.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); + auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + builder_cond0.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); + + // Condition computation for the second while. + auto builder_cond1 = HloComputation::Builder("cond1"); + auto cond1_param = builder_cond1.AddInstruction( + HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param")); + auto cond1_lhs = builder_cond1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond1_param, 0)); + auto cond1_rhs = builder_cond1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond1_param, 1)); + // This add should prevent LHS from using BF16 + auto cond1_add_lhs = + builder_cond1.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); + auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + builder_cond1.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); + + // Body computation shared by both whiles. + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, tuple0->shape(), "body_param")); + auto body_lhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + auto body_rhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 1)); + auto body_dot = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + builder_body.AddInstruction( + HloInstruction::CreateTuple({body_dot, body_rhs})); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); + + auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_FALSE(OutputsBF16(body_dot)); + EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_lhs)); + EXPECT_FALSE(OutputsBF16(cond0_lhs)); + EXPECT_FALSE(OutputsBF16(cond0_rhs)); + EXPECT_FALSE(OutputsBF16(cond1_lhs)); + EXPECT_FALSE(OutputsBF16(cond1_rhs)); + EXPECT_TRUE(OutputsBF16(cond0_add_rhs)); + EXPECT_TRUE(OutputsBF16(cond1_add_lhs)); + EXPECT_EQ(computation->root_instruction(), dot); +} + +// Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 -> +// BF16 conversion), then it will remove that conversion. +TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "param")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1)); + HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( + bf16_shape, HloOpcode::kAdd, convert0, convert1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), add2); + EXPECT_EQ(add2->operand(0), gte0); + EXPECT_EQ(add2->operand(1), gte1); + EXPECT_EQ(gte0->shape().element_type(), BF16); + EXPECT_EQ(gte1->shape().element_type(), BF16); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 3fd9e24601f27633c8063e4574c7c4f91f30dcff..07b4b14b5ec1bdbc01345091105df69368b0b2fb 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -79,6 +79,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kBroadcast: case HloOpcode::kClamp: case HloOpcode::kConcatenate: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b1e693da9d5af4babe619b8796007f2da318f6a8..dbe45e932cdeed00e959355d5b3199d2e858148f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -48,6 +48,183 @@ using ::tensorflow::strings::HumanReadableNumBytes; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrAppend; +namespace { + +template +string ColocatedBufferSetsToString(const T& container, const char* title) { + string result; + StrAppend(&result, title, "\n"); + for (const auto& it : container) { + StrAppend(&result, "\t", it->ToString(), "\n"); + } + return result; +} + +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations) { + // Create a worklist of computations paired with whether the allocation must + // be thread-local. + std::deque> worklist; + worklist.push_back(std::make_pair(module->entry_computation(), + /*is_thread_local*/ false)); + + // Sets for quickly checking membership. Computations are returned in vectors + // for stable iteration. + FlatSet thread_local_set; + FlatSet global_set; + + while (!worklist.empty()) { + auto worklist_front = worklist.front(); + worklist.pop_front(); + const HloComputation* computation = worklist_front.first; + bool is_thread_local = worklist_front.second; + bool in_thread_local_set = thread_local_set.count(computation) > 0; + bool in_global_set = global_set.count(computation) > 0; + + // If the computation has already been added to the respective set, then + // nothing to do. + if ((is_thread_local && in_thread_local_set) || + (!is_thread_local && in_global_set)) { + continue; + } + + // If the computation has already been added to the other set this is an + // error condition because the global call to the computation (eg, + // while/call) may return a reference to one of the thread-local buffers to + // the calling computation which will become a dangling reference when the + // thread-local is deallocated with the call return. + if ((is_thread_local && in_global_set) || + (!is_thread_local && in_thread_local_set)) { + return InvalidArgument( + "computation %s has conflicting allocation requirements (global " + "and thread-local)", + computation->name().c_str()); + } + + if (is_thread_local) { + thread_local_set.insert(computation); + } else { + global_set.insert(computation); + } + + for (auto* instruction : computation->instructions()) { + for (HloComputation* subcomputation : + instruction->called_computations()) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kWhile: + // Call and while must be called from a computation with global + // allocations as they may return references to buffers inside the + // called computation which cannot be thread-local. + if (is_thread_local) { + return InvalidArgument( + "computation %s cannot contain call/while op because it " + "requires thread-local buffer allocations", + computation->name().c_str()); + } + worklist.push_back(std::make_pair(subcomputation, + false)); // Not thread local. + break; + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + // Map/reduce etc computations are always thread-local. + worklist.push_back(std::make_pair(subcomputation, + true)); // Thread local. + break; + default: + return InternalError( + "Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode()).c_str()); + } + } + } + } + + // Add the computations to the vectors in post order. + for (auto* computation : module->MakeComputationPostOrder()) { + if (thread_local_set.count(computation) > 0) { + thread_local_computations->push_back(computation); + } else if (global_set.count(computation) > 0) { + global_computations->push_back(computation); + } + // If the computation is not reachable from the entry computation, then it + // will not appear in either thread_local_set or global_set. We don't bother + // assigning buffers for these. + } + return Status::OK(); +} + +// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. + std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace + size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); @@ -115,6 +292,112 @@ BufferAllocationProto BufferAllocation::ToProto() const { return proto; } +std::pair> +BufferAllocation::ComputePeakMemoryLogicalBuffers() const { + if (HeapTraces().empty()) { + // Just return the largest LogicalBuffer in the allocation. + const LogicalBuffer* largest_buffer = nullptr; + int64 largest_size = 0; + for (const auto& pair : assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + int64 size = pair.second.size; + if (largest_buffer == nullptr) { + largest_buffer = buffer; + largest_size = size; + continue; + } + // Tie-break with LogicalBuffer::Id so the return value is stable relative + // to changing addresses. + if (size > largest_size || + ((size == largest_size) && (largest_buffer->id() > buffer->id()))) { + largest_buffer = buffer; + largest_size = size; + } + } + CHECK(largest_buffer != nullptr) + << "No logical buffers in allocation: " << ToString(); + return {largest_size, {largest_buffer}}; + } + + // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical + // buffers in this allocation. + tensorflow::gtl::FlatMap + id_to_buffer; + tensorflow::gtl::FlatMap buffer_sizes; + for (const auto& pair : assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + const OffsetSize& offset_size = pair.second; + id_to_buffer[buffer->id()] = buffer; + buffer_sizes[buffer] = offset_size.size; + } + + // Returns how much the given event increases the total size of live + // buffers. Can be negative. + auto memory_delta = [this, &id_to_buffer, &buffer_sizes]( + const HeapSimulatorTrace::Event& event) -> int64 { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + const int64 buffer_size = buffer_sizes.at(buffer); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + return buffer_size; + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Sharing a buffer does not change the live set size for the purposes of + // the heap simulator. Even though the shared-with buffer may be smaller, + // the entire allocation remains live. + return 0; + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + return -1 * buffer_size; + } + LOG(FATAL) << "Unknown event kind: " << event.kind(); + }; + + int64 total_max_live_size = 0; + std::vector live_buffers_vector; + for (const HeapSimulatorTrace& heap_trace : HeapTraces()) { + // First compute the size of the maximal live set. + int64 max_live_size = 0; + int64 live_size = 0; + for (const auto& event : heap_trace.events()) { + live_size += memory_delta(event); + if (max_live_size < live_size) { + max_live_size = live_size; + } + } + + // Next gather the set of logical buffers live at the earliest point of + // maximal live set size. + tensorflow::gtl::FlatSet live_buffers; + live_size = 0; + for (const auto& event : heap_trace.events()) { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + InsertOrDie(&live_buffers, buffer); + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Nothing to do. + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + CHECK(ContainsKey(live_buffers, buffer)); + live_buffers.erase(buffer); + } + + live_size += memory_delta(event); + if (live_size == max_live_size) { + break; + } + } + CHECK_EQ(live_size, max_live_size); + total_max_live_size += max_live_size; + + live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), + live_buffers.end()); + } + + // Stabily sort the live buffers. + std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + return {total_max_live_size, live_buffers_vector}; +} + string BufferAllocation::ToString() const { string output; Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); @@ -348,6 +631,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // Combines allocations of temporary buffers of the same color into one big // BufferAllocation. void BufferAssignment::CombineTempAllocations() { + VLOG(1) << "CombineTempAllocations()"; FlatMap combined_allocation_map; @@ -369,11 +653,16 @@ void BufferAssignment::CombineTempAllocations() { if (combined_it == combined_allocation_map.end()) { // We have found the first temp allocation of this color. Collect // the other temp allocations of the same color into it. + VLOG(1) << "Combined temp allocation for color " << color + << " is: " << temp_allocation; combined_allocation_map.emplace(color, temp_allocation); continue; } auto* combined_allocation = &combined_it->second; + VLOG(1) << "Combined allocation absorbing temp allocation: " + << temp_allocation; + // Each temp allocation is placed end-to-end, accounting for alignment. // The offset of each buffer in the combined allocation is computed from // the base offset of the allocation. @@ -387,6 +676,10 @@ void BufferAssignment::CombineTempAllocations() { const int64 size = buffer_offset_size.second.size; combined_allocation->AddAssignment(*buffer, base + offset, size); } + if (!temp_allocation.HeapTraces().empty()) { + CHECK_EQ(temp_allocation.HeapTraces().size(), 1); + combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front()); + } } // Replace all existing temporary allocations with the new combined // allocations. @@ -516,123 +809,13 @@ BufferAssignmentProto BufferAssignment::ToProto() const { for (const BufferAllocation& allocation : Allocations()) { BufferAllocationProto proto_allocation = allocation.ToProto(); proto.add_buffer_allocations()->Swap(&proto_allocation); - } - for (const HeapSimulatorTrace& trace : heap_simulator_traces_) { - *proto.add_heap_simulator_traces() = trace; - } - return proto; -} - -namespace { - -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). -Status GatherComputationsByAllocationType( - const HloModule* module, - std::vector* thread_local_computations, - std::vector* global_computations) { - // Create a worklist of computations paired with whether the allocation must - // be thread-local. - std::deque> worklist; - worklist.push_back(std::make_pair(module->entry_computation(), - /*is_thread_local*/ false)); - - // Sets for quickly checking membership. Computations are returned in vectors - // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; - - while (!worklist.empty()) { - auto worklist_front = worklist.front(); - worklist.pop_front(); - const HloComputation* computation = worklist_front.first; - bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; - - // If the computation has already been added to the respective set, then - // nothing to do. - if ((is_thread_local && in_thread_local_set) || - (!is_thread_local && in_global_set)) { - continue; + for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) { + *proto.add_heap_simulator_traces() = heap_trace; } - - // If the computation has already been added to the other set this is an - // error condition because the global call to the computation (eg, - // while/call) may return a reference to one of the thread-local buffers to - // the calling computation which will become a dangling reference when the - // thread-local is deallocated with the call return. - if ((is_thread_local && in_global_set) || - (!is_thread_local && in_thread_local_set)) { - return InvalidArgument( - "computation %s has conflicting allocation requirements (global " - "and thread-local)", - computation->name().c_str()); - } - - if (is_thread_local) { - thread_local_set.insert(computation); - } else { - global_set.insert(computation); - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* subcomputation : - instruction->called_computations()) { - switch (instruction->opcode()) { - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kWhile: - // Call and while must be called from a computation with global - // allocations as they may return references to buffers inside the - // called computation which cannot be thread-local. - if (is_thread_local) { - return InvalidArgument( - "computation %s cannot contain call/while op because it " - "requires thread-local buffer allocations", - computation->name().c_str()); - } - worklist.push_back(std::make_pair(subcomputation, - false)); // Not thread local. - break; - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kFusion: - // Map/reduce etc computations are always thread-local. - worklist.push_back(std::make_pair(subcomputation, - true)); // Thread local. - break; - default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); - } - } - } - } - - // Add the computations to the vectors in post order. - for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { - thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { - global_computations->push_back(computation); - } - // If the computation is not reachable from the entry computation, then it - // will not appear in either thread_local_set or global_set. We don't bother - // assigning buffers for these. } - return Status::OK(); + return proto; } -} // namespace - /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, @@ -1064,7 +1247,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } - assignment->heap_simulator_traces_.push_back(result.debug_trace); + VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString(); + allocation->AddHeapTrace(result.debug_trace); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -1085,7 +1269,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( if (colocated_set.empty()) { return; } - + VLOG(5) << ColocatedBufferSetsToString(colocated_set, + "Adding colocated buffer set"); // Find existing sets that overlap with at least one buffer from the // colocated_set. The resulting 'overlap_set_indices' will have at most // colocated_buffer_sets->size() entries, and will be in increasing order. @@ -1093,6 +1278,10 @@ void BufferAssigner::AddSetToColocatedBufferSets( for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + VLOG(5) << "Found overlap with existing set on buffer " + << buffer->ToString() << "\n" + << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], + "Overlapping set"); overlap_set_indices.push_back(index); break; } @@ -1104,6 +1293,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( colocated_buffer_sets->emplace_back(); colocated_buffer_sets->back().insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << "No overlap found, new group created"; return; } @@ -1115,6 +1305,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( first->insert(overlap_set.begin(), overlap_set.end()); } first->insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << ColocatedBufferSetsToString( + *first, "Result of the colocated buffer set merging"); // Remove overlap sets that we just merged. The offset accounts for the fact // that as elements are erased, the indices need to be adjusted. Keep in mind @@ -1125,67 +1317,6 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } -namespace { - -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. - std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. - int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - std::vector BufferAssigner::MergeColocatedBufferSets( const std::vector& colocated_buffer_sets, @@ -1208,26 +1339,35 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - for (auto& buffer_a : colocated_buffer_sets[i]) { - for (auto& buffer_b : colocated_buffer_sets[j]) { - // Do not merge if the set includes live outs or entry parameters. - if ((buffer_liveness.MaybeLiveOut(*buffer_a) && - is_entry_parameter(*buffer_b)) || - (buffer_liveness.MaybeLiveOut(*buffer_b) && - is_entry_parameter(*buffer_a))) { + // Do not merge if one of the sets includes live outs or entry parameters. + for (int64 key : {i, j}) { + for (auto& buffer : colocated_buffer_sets[key]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer)) { return true; } - // Do not merge if the buffers interfere with each other. + } + } + + // Colocated sets satisfy the invariant that all buffers within a set have + // the same size. That means we need to check whether the size is the same + // between the two sets, but also that it's enough to look at just one + // buffer within each set. + if (buffer_size(**colocated_buffer_sets[i].begin()) != + buffer_size(**colocated_buffer_sets[j].begin())) { + return true; + } + + // Do not merge if some pair of buffers interferes with each other. + for (auto& buffer_a : colocated_buffer_sets[i]) { + for (auto& buffer_b : colocated_buffer_sets[j]) { if (buffer_a->id() != buffer_b->id() && buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) { return true; } - // Do not merge if the buffer sizes are different. - if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) { - return true; - } } } + return false; }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 6b7fd0014d103ef0617afcc5cb3f663554a01aa4..3086d0e2ca0026547134285b8ceb357390fc7ece 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -192,6 +192,37 @@ class BufferAllocation { !is_thread_local(); } + // Add a heap trace which was used to assign slices to logical buffers in this + // allocation. A single BufferAllocation may include multiple heap traces + // in the case of the temporary block where there is a heap trace per + // computation. + void AddHeapTrace(const HeapSimulatorTrace& heap_trace) { + heap_traces_.push_back(heap_trace); + } + + // Return the set of heap traces used to assign slices to logical buffers in + // this allocation. + const std::vector HeapTraces() const { + return heap_traces_; + } + + // Compute and return the LogicalBuffers which are live at the point of peak + // memory usage for the given allocation. The point of peak memory usage is + // the point at which the total size of all live logical buffers is + // maximal. If peak memory is reached at multiple points, the set of logical + // buffers live at the earliest maximal point is returned. The vector is + // stabily asserted by LogicalBuffer::Index. + // + // The return value is a pair of total size of the logical buffers at peak, + // and the buffers themselves. + std::pair> + ComputePeakMemoryLogicalBuffers() const; + + // Get the number of bytes lost to fragmentation. This is equal to the + // difference between the size of the allocation and the size of the maximal + // live set. + int64 fragmentation_bytes() const { return fragmentation_bytes_; } + bool operator==(const BufferAllocation& other) const { return index_ == other.index_; } @@ -257,6 +288,9 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. tensorflow::gtl::FlatMap assigned_buffers_; + + int64 fragmentation_bytes_ = 0; + std::vector heap_traces_; }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. @@ -441,7 +475,6 @@ class BufferAssignment { LogicalBuffer::AlignmentFunction color_alignment_; Stats stats_; - std::vector heap_simulator_traces_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index cd73654b8f666c4b96c000235cc3ad2cd0a46c17..513a8785bbd52b0a3bfa3642bbfc62b1035ffb17 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -37,14 +37,16 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" namespace xla { - namespace { +using ::testing::UnorderedElementsAre; + // DFS visitor that collects the instructions referenced by a computation // without descending into nested computations, i.e., only from the operands. class InstructionListVisitor : public DfsHloVisitorWithDefault { @@ -101,6 +103,22 @@ class BufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr RunBufferAssignmentWithInstructionSequence( + HloModule* module, + tensorflow::gtl::ArraySlice instruction_sequence, + int64 alignment = 1) { + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[module->entry_computation()] = + std::vector(instruction_sequence.begin(), + instruction_sequence.end()); + return BufferAssigner::Run( + module, + xla::MakeUnique(module, module_sequence), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }) + .ConsumeValueOrDie(); + } + // Builds an x+1.0 computation to use in a Map. std::unique_ptr BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -1370,7 +1388,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto element_slices = assignment->GetAllSlices(select, /*index=*/{0}); EXPECT_EQ(2, element_slices.size()); EXPECT_THAT(element_slices, - ::testing::UnorderedElementsAre( + UnorderedElementsAre( assignment->GetUniqueSlice(tuple_param0, /*index=*/{0}) .ConsumeValueOrDie(), assignment->GetUniqueSlice(tuple_param1, /*index=*/{0}) @@ -1473,6 +1491,98 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { } } +TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + + // Trivially, the set of peak memory logical buffer(s) of an allocation with a + // single logical buffer should be exactly the logical buffer in that + // allocation. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + int64 peak_size; + std::vector peak_buffers; + + std::tie(peak_size, peak_buffers) = + mul_buffer.ComputePeakMemoryLogicalBuffers(); + EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(f32vec100_)); + ASSERT_EQ(peak_buffers.size(), 1); + EXPECT_EQ(peak_buffers[0]->instruction(), mul); +} + +TEST_F(BufferAssignmentTest, PeakBuffers) { + // Compute the peak liveness buffers of the following sequence: + // + // %param = ... + // %log = log(%param) + // %rev = reverse(%log) + // %neg = neg(%param) + // %concat = concat(%rev, %neg) + // ROOT %root = slice(concat) + // + // In the temporary block, the set of live buffers at peak memory use should + // be {%rev, %neg, %concat}. This occurs right at the concat itself. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "")); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); + auto rev = builder.AddInstruction( + HloInstruction::CreateReverse(f32vec100_, log, {0})); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + const Shape concat_shape = ShapeUtil::MakeShape(F32, {200}); + auto concat = builder.AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0)); + // Make the root tiny so no interior nodes can share its buffer. + auto root = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignmentWithInstructionSequence( + module.get(), {param, log, rev, neg, concat, root}); + + // The temporary buffer should hold the 4 interior instructions. + const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); + EXPECT_FALSE(buffer.IsInputOrOutput()); + EXPECT_TRUE(buffer.IsPreallocatedTempBuffer()); + ASSERT_EQ(buffer.assigned_buffers().size(), 4); + + int64 peak_size; + std::vector peak_buffers; + std::tie(peak_size, peak_buffers) = buffer.ComputePeakMemoryLogicalBuffers(); + + // The peak live set should be concat and its inputs. + EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {400}))); + ASSERT_EQ(peak_buffers.size(), 3); + std::vector peak_instructions; + for (const LogicalBuffer* logical_buffer : peak_buffers) { + peak_instructions.push_back(logical_buffer->instruction()); + } + EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1587,6 +1697,81 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); } +// Tests that two colocated buffer sets are not merged if an entry parameter +// buffer belongs to either of the colocation sets (b/73267882). +// +// %param --> %while.0 --> %mul --> %while.1 --> %broadcast +// +// %while.0 body just forwards the init value, so the loop carried variable +// remains the constant, whereas %while.1 changes the loop carried variable. +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %param.3 = s32[] parameter(0) + %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* param = + module->entry_computation()->parameter_instruction(0); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_param, + assignment->GetUniqueSlice(param, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The parameter slice is part of the while0's colocation set (init value), + // but not merged into the while1's colocation set. + EXPECT_EQ(slice_param, slice_while0); + EXPECT_NE(slice_param, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index dab73596e1639eed62151197048ee8d29570b20a..c83da9eddc8f8b156dd9acfc99b393bf844575da 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -72,8 +72,7 @@ CompileOnlyService::CompileAheadOfTime( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); - // TODO(b/63773457): Track DebugOptions in AotCompilationOptions. - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + const DebugOptions& debug_options = options.debug_options(); // Dump computation proto state if flag is set. const string& directory_path = debug_options.xla_dump_computations_to(); @@ -101,7 +100,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, *user_computation)); + &execution_options, user_computation)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1..0392d4af48a040c4a648f7bf9bf21a62ce03a990 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -86,4 +86,7 @@ Compiler::GetPlatformCompilers() { return compilers->at(platform->id()).get(); } +AotCompilationOptions::AotCompilationOptions() + : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 74fd24edf88d44b2dfdc87556b0af43987e69e08..b4b53ae2ed425a48de5bcb6ba5c37b5d37e1f371 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -79,11 +79,15 @@ class AotCompilationOptions { device_allocator_ = device_allocator; } + const DebugOptions& debug_options() const { return debug_options_; } + DebugOptions* mutable_debug_options() { return &debug_options_; } + protected: - AotCompilationOptions() = default; + AotCompilationOptions(); private: DeviceMemoryAllocator* device_allocator_ = nullptr; + DebugOptions debug_options_; }; // Abstract compiler interface that is subclassed for compilation on a @@ -123,7 +127,7 @@ class Compiler { // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses - // prior to calling this method because the some HLO passes are required for + // prior to calling this method because some HLO passes are required for // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..f35de080853f7ec986565cb2df1050946ac3f244 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +// Tries to replace a conditional with a call operation of the corresponding +// computation. If the given conditional has a constant predicate, tries to +// replace it with a call to its true/false computation as appropirate and then +// inline that computation. +// +// Returns true if it made a change to the graph. +static StatusOr TryRemoveConditional(HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + // Do not remove conditionals that contain side-effecting instructions or + // have control predecessors/successors in either true/false computation. + if (!conditional->parent()->IsRemovable(conditional) || + conditional->HasSideEffect()) { + VLOG(2) << "Not attempting to remove conditional as it is not removable or " + "has side effect: " + << conditional->ToShortString(); + return false; + } + + if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { + VLOG(2) << "Not attempting to remove conditional as its predicate is not a " + "compile-time constant: " + << conditional->ToShortString(); + return false; + } + + auto computation = conditional->parent(); + HloInstruction* call_op; + if (conditional->operand(0)->literal().Get({})) { + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(1)}, + conditional->true_computation())); + } else { + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(2)}, + conditional->false_computation())); + } + + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); + + return true; +} + +StatusOr ConditionalSimplifier::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "ConditionalSimplifier::Run(), before:\n" + module->ToString()); + bool changed = false; + + // Gather all the conditional ops in our module. We do this ahead of time so + // we don't have to worry about mutating the lists of computations or + // instructions as we iterate. + std::vector conditional_ops; + for (auto* comp : module->computations()) { + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } + } + } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); + changed |= result; + } + + XLA_VLOG_LINES(3, + "ConditionalSimplifier::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..063261e26d06e21a297e8e3c405898a17221b7ca --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// HLO pass that removes kConditional with a constant predicate, replacing them +// with their true or false computation as appropriate. +class ConditionalSimplifier : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "simplify-conditional"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..868348547d9f5cbdc7576c7fc0697d72c3a3e557 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ConditionalSimplifierTest : public HloVerifiedTestBase { + public: + // Makes a computation that contains a conditional with constant predicate. + HloComputation* MakeConditional(HloModule* module); +}; + +HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { + HloComputation::Builder builder(TestName()); + + // true_computation returns param+1. + HloComputation* true_computation; + { + HloComputation::Builder true_computation_builder(TestName() + + ".true_computation"); + auto param = + true_computation_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param")); + auto one = true_computation_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + + true_computation_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); + + true_computation = + module->AddEmbeddedComputation(true_computation_builder.Build()); + } + + // false_computation returns param+42. + HloComputation* false_computation; + { + HloComputation::Builder false_computation_builder(TestName() + + ".false_computation"); + auto param = false_computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), + "param")); + auto forty_two = false_computation_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + false_computation_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); + false_computation = + module->AddEmbeddedComputation(false_computation_builder.Build()); + } + + auto false_instrn = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "false_param")); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + + builder.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, + false_param, false_computation)); + + return module->AddEntryComputation(builder.Build()); +} + +TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { + HloComputation* computation = MakeConditional(&module()); + ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Constant())); +} + +TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { + HloComputation* computation = MakeConditional(&module()); + + auto* true_op = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + TF_ASSERT_OK( + true_op->AddControlDependencyTo(computation->root_instruction())); + + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + + auto* true_computation = conditional->true_computation(); + auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( + true_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))), + /*channel_id=*/0)); + true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + + auto* true_computation = conditional->true_computation(); + auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( + ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + auto* false_computation = conditional->false_computation(); + false_computation->AddInstruction( + HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cc195879a6bb490a9b49ad962aa9326cb51d9b0a..40519ecc799c8f0343294ad88009820dbd8535e9 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -58,6 +58,46 @@ bool ValueIsReadOnly(const HloValue& value) { return IsConstantValue(value) || IsEntryParameterValue(value); } +// Data structure describing the action which should be taken on parts of a +// computation buffers, with respect to the adding of special case copies. +struct SpecialCaseCopyPolicy { + // Insert a copy if the same buffer is found at multiple indices within the + // output tuple. + bool copy_root_replicated_buffers = false; + // If true, insert a copy if a buffer coming from a constant or a parameter + // is found wihtin the output tuple. + bool copy_parameters_and_constants = false; +}; + +SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, + HloModule* module, + HloComputation* computation) { + SpecialCaseCopyPolicy policy; + if (computation == module->entry_computation()) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + for (const CallSite& site : node.caller_callsites()) { + // The AddCopiesForConditional() already adds copies, but the copy remover + // removes them, so we re-add them by returning the policy here. But really + // the copy remover should not be removing them. + if (site.instruction()->opcode() == HloOpcode::kConditional) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + } + return policy; +} + +bool ShouldCopyRootValue(const HloValue& value, + const SpecialCaseCopyPolicy& policy) { + if (policy.copy_parameters_and_constants) { + return IsConstantValue(value) || + value.defining_instruction()->opcode() == HloOpcode::kParameter; + } + return false; +} + // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in // 'indices_to_copy'. Add control edges from the respective kCopy instructions // in deep copy of 'from' to the respective kCopy instruction in the deep copy @@ -282,6 +322,29 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// We add copies for all the indices of the true and false computaiton roots, +// in order to resolve interference. We later rely on the CopyRemover to drop +// the unnecessary ones. +Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, + HloInstruction* conditional) { + VLOG(2) << "Adding copies for kConditional instruction " + << conditional->name(); + TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); + + for (HloComputation* computation : + {conditional->true_computation(), conditional->false_computation()}) { + HloInstruction* root = computation->root_instruction(); + std::vector users = root->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + computation->DeepCopyInstruction(root)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); + } + computation->set_root_instruction(deep_copy); + } + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -309,6 +372,9 @@ Status AddCopiesToResolveInterference(HloModule* module) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR( + AddCopiesForConditional(*alias_analysis, instruction)); } } } @@ -557,6 +623,7 @@ class CopyRemover { auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; if (LiveRangeBefore(a, b)) { VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " << b.value->ToShortString(); @@ -571,7 +638,7 @@ class CopyRemover { VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); VLOG(3) << "Source buffer values: " << ValueListToString(src); - VLOG(3) << "Dest buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); // A kCopy instruction copies an HLO value from a source buffer and // defines an HLO value in a destination buffer. Most generally, the @@ -747,16 +814,16 @@ class CopyRemover { // updated as copies are removed. bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { if (a.uses.empty()) { - VLOG(2) << "Empty uses"; + VLOG(2) << "Empty uses for " << *a.value; return ordering_.IsDefinedBefore(*a.value, *b.value); } for (const HloUse* use : a.uses) { - VLOG(2) << "use: " << *use; - VLOG(2) << "is before:" << *b.value; + VLOG(2) << "Checking use " << *use << " against " << *b.value; if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Not before"; + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; return false; } + VLOG(2) << "Use " << *use << " is before " << *b.value; } return true; } @@ -892,7 +959,6 @@ Status RemoveUnnecessaryCopies( CopyRemover copy_remover(*alias_analysis, ordering, module); XLA_VLOG_LINES(3, copy_remover.ToString()); - tensorflow::gtl::FlatSet existing_copies; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && @@ -901,7 +967,6 @@ Status RemoveUnnecessaryCopies( } } } - return Status::OK(); } @@ -921,7 +986,7 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { // Identify which shape indices of which instructions need to be copied. Store // these results in 'instructions_to_copy'. - std::unordered_map> instructions_to_copy; + HloInstructionMap> instructions_to_copy; auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, const ShapeIndex& index) { auto it = instructions_to_copy.find(instruction); @@ -957,7 +1022,8 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { } TF_RET_CHECK(node.context() == CallContext::kSequential); - const bool is_entry = computation == module->entry_computation(); + SpecialCaseCopyPolicy policy = + GetSpecialCaseCopyPolicy(node, module, computation); HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. @@ -970,27 +1036,26 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { for (const HloBuffer* buffer : buffers_at_index) { buffer_seen_before |= !seen.insert(buffer).second; } - if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { - VLOG(2) << "Index " << index << " of root of computation " + if (buffers_at_index.size() > 1 || + (buffer_seen_before && policy.copy_root_replicated_buffers)) { + VLOG(2) << "Index " << index << " of computation " << computation->name() << " (" << root->name() << ") has ambiguous or non-distinct buffer. Copying."; add_index_to_copy(root, index); } }); - // For entry instructions, mark any parameter or constant values. - if (is_entry) { - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (ValueIsReadOnly(*value)) { - VLOG(2) << "Root of entry computation (" << root->name() - << ") has constant or entry parameter value at index " - << index << ". Copying."; - add_index_to_copy(root, index); - } + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ShouldCopyRootValue(*value, policy)) { + VLOG(2) << "Root of (" << root->name() << ") of computation(" + << computation->name() + << ") has constant or parameter value at index " << index + << ". Copying."; + add_index_to_copy(root, index); } } } @@ -1012,7 +1077,6 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { instruction->parent()->set_root_instruction(deep_copy); } } - return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 32be0b0c968f2d24f460fc8377c458f2da282112..0faa9e9c41063c5f7576ef5cbd873e8a84a73c28 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -105,9 +105,11 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -514,7 +516,6 @@ cc_library( cc_library( name = "runtime_matvec", - srcs = ["runtime_matvec.cc"], hdrs = ["runtime_matvec.h"], copts = runtime_copts(), deps = [ @@ -669,6 +670,22 @@ cc_library( ], ) +tf_cc_test( + name = "ir_emission_utils_test", + srcs = ["ir_emission_utils_test.cc"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "cpu_layout_assignment", srcs = ["cpu_layout_assignment.cc"], @@ -771,6 +788,31 @@ cc_library( ], ) +tf_cc_test( + name = "parallel_task_assignment_test", + srcs = ["parallel_task_assignment_test.cc"], + deps = [ + ":cpu_executable", + ":parallel_task_assignment", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 387806e24aad0d5f28cb104507ef6cc136ffd779..e43777c5e5e8afcf08e1e334c8847f6b94d0d047 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" @@ -66,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -260,6 +262,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); + pipeline.AddPass(); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -275,6 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); } pipeline.AddPass( [](const HloInstruction& dot, @@ -314,7 +318,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Note this is not run for AOT because it would bring in thread pool // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). - // TODO(29630486) Support multi-threaded AOT. + // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 482e04052d5a914eab0e5bff2c7a83f3b698052f..0fc5a746bbbc7685ff5d4647111a750e7d7b1c19 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -30,7 +30,6 @@ bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // - hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 595c3f55b321f47e2312b93e0c238c7637495d77..6ed1cd31b18f6360bdd7fd41bd5be2e657b310a5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); } -TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { +TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); @@ -94,8 +94,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); - EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Fusion()); + EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); } TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { @@ -244,35 +243,33 @@ class OpcodeFusionTest : public InstructionFusionTest { } }; -TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) { +TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4}); Shape result_shape = ShapeUtil::MakeShape(F32, {4}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); - // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand - // is a parameter, so create an operand between the parameter and bitcast. HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); - HloInstruction* bitcast2 = builder.AddInstruction( - HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1)); + HloInstruction* reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1)); builder.AddInstruction( - HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2)); + HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp, + module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { +TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); - Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8}); + Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); @@ -280,11 +277,11 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { HloInstruction::CreateParameter(1, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); - HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary( - bitcast_shape, HloOpcode::kBitcast, broadcast2)); + HloInstruction* reshape3 = builder.AddInstruction( + HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, bitcast3, param1, {4, 4})); + dynamic_slice_shape, reshape3, param1, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); @@ -293,7 +290,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), - {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast, + {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 40ace963270e8cead47cc731cc326351178dff7d..9a3bd68c80c6e8bcdb231c63ba025d1f73619eb7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -31,6 +31,8 @@ XfeedManager* GetXfeedManager() { return manager; } +extern const char* const kEigenMatMulF16SymbolName = + "__xla_cpu_runtime_EigenMatMulF16"; extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = @@ -40,6 +42,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedMatMulF16SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 2141dfe1cedd6f9674acc348152574b4fd30895b..e61d6ea28b633398863357541e056ee887582f9c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -41,11 +41,13 @@ namespace runtime { // the actual symbol. // 2. When using ahead-of-time compilation, the linker can resolve the name // because it is a symbol in the cpu_runtime library. +extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; extern const char* const kEigenSingleThreadedConvF16SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index cfe7c9c3af0be109ac8a86753e880e2bcbceba41..8b1e20d79e90fcc32e985ffb855a1a10cdd2f2b9 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -715,6 +715,11 @@ tensorflow::Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. + // This routine assumes that the dot operation is not in a parallelized + // enclosing computation. + CHECK( + dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); + const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -919,6 +924,12 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { llvm::Type* float_type; const char* fn_name; switch (type) { + case F16: + fn_name = multi_threaded_eigen + ? runtime::kEigenMatMulF16SymbolName + : runtime::kEigenSingleThreadedMatMulF16SymbolName; + float_type = ir_builder_->getHalfTy(); + break; case F32: fn_name = multi_threaded_eigen ? runtime::kEigenMatMulF32SymbolName @@ -1051,7 +1062,8 @@ static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. - return output_shape.element_type() == F32 && + PrimitiveType output_primitive_type = output_shape.element_type(); + return (output_primitive_type == F32 || output_primitive_type == F16) && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 788217aab6172b4e548452b3f6ffd4197c163ce4..f209a69e3cd0f8d336d61bafd1e22be8bc88ca3f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -34,14 +34,16 @@ bool PotentiallyImplementedAsEigenConvolution( // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); - const Shape& kernel_shape = convolution.operand(0)->shape(); + const Shape& kernel_shape = convolution.operand(1)->shape(); if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // Make sure input and kernel has the same data type. + CHECK( + ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot for complex64 type. - if (ShapeUtil::ElementIsComplex(input_shape) || - ShapeUtil::ElementIsComplex(kernel_shape)) { + if (ShapeUtil::ElementIsComplex(input_shape)) { return false; } if (window_util::HasWindowReversal(convolution.window())) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..215f48c4cc1a1a6b13d98dff76e0d1f0f773f5c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) { + const char* const hlo_string = R"( +HloModule ModuleWithConv + +ENTRY Conv { + input = f32[32,50,28,28]{3,2,1,0} parameter(0) + kernel = f32[0,32,5,5]{3,2,1,0} parameter(1) + ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel), + window={size=5x5}, + dim_labels=b01f_01io->b01f +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloComputation* entry_computation = module->entry_computation(); + + HloInstruction* conv_instr = entry_computation->root_instruction(); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4dffaee87f6b33933b58c8c58478eec918569197..3405277d449f2d9e558f2d3f83277163655af592 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -438,12 +438,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, - length_32, 1); + ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, + acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, - length_32, 1); + ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, + program_buffer_address, + /*SrcAlign=*/1, length_32); } ir_builder_.CreateCall(release_func, @@ -2074,7 +2076,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32})); + /*supported_types=*/{F16, F32})); llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -2441,7 +2443,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { auto* memcpy_instruction = ir_builder_.CreateMemCpy( - target, source, element_count * primitive_type_size, element_alignment); + target, /*DstAlign=*/element_alignment, source, + /*SrcAlign=*/element_alignment, element_count * primitive_type_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. @@ -2905,7 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); + ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index deb21bf4ef5895cfdbec5c2449b6ce7b306a7008..fb28280fade307ac1f193e7dca481bd2afa855fc 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -71,7 +71,7 @@ class DefaultCostModel : public ParallelCostModel { if (flops_to_bytes_ratio <= 1.0) { // Limit max parallelism for I/O bound instructions by assuming a // sub-linear scaling function (fit based on empirical benchmark results). - // TODO(29630486) Develop system bandwidth model. + // TODO(b/29630486) Develop system bandwidth model. max_parallelism = std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); // Use shape size instruction cost and L2 cache size min per-thread cost. @@ -81,7 +81,7 @@ class DefaultCostModel : public ParallelCostModel { // Use max parallelism for compute bound instructions. max_parallelism = max_parallelism_; // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. + // TODO(b/29630486) Improve on this linear cost model. // Consider making 'min_cost_per_thread' be a function of the target // bandwidth limit for instructions with low arithmetic complexity. instruction_cost = @@ -128,24 +128,25 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - instruction->opcode() == HloOpcode::kGetTupleElement || - instruction->opcode() == HloOpcode::kBitcast || - instruction->opcode() == HloOpcode::kFft || - (instruction->opcode() == HloOpcode::kConvolution && + auto opcode = instruction->opcode(); + if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || + opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || + opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || + opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || + opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || + opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && + (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { return 1; } + // Consult 'cost_model_' to compute target parallel task count. return cost_model_->GetParallelTaskCount(instruction); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..13eb75a57213b1a68a5732a4f6061efdf97fa4f4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class ParallelTaskAssignmentTest : public HloVerifiedTestBase { + protected: + const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = + cpu::CpuExecutable::ShapeSizeBytes; + + // Use any value larger than 2 since we only test whether a module is + // parallelized or not + const int max_parallelism_ = 10; +}; + +TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_Dot + ENTRY Dot { + dot_lhs = f32[196614,2]{1,0} parameter(0) + dot_rhs = f32[2,1]{1,0} parameter(1) + ROOT dot = f32[196614,1]{1,0} dot(dot_lhs, dot_rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, + FusedComputationWithDotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_DotNestedInFusedComp + fused_computation.0 { + parameter.0 = f32[196614,2]{1,0} parameter(0) + parameter.0.1 = f32[2,1]{1,0} parameter(1) + parameter.0.2 = f32[196614,1]{1,0} parameter(2) + dot.0 = f32[196614,1]{1,0} dot(parameter.0, parameter.0.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add.0 = f32[196614,1]{1,0} add(dot.0, parameter.0.2) + + } + ENTRY DotNestedInFusedComp { + parameter = f32[196614,2]{1,0} parameter(0) + parameter.1 = f32[2,1]{1,0} parameter(1) + parameter.2 = f32[196614,1]{1,0} parameter(2) + ROOT fusion = f32[196614,1]{1,0} fusion(parameter, parameter.1, + parameter.2), kind=kOutput, calls=fused_computation.0 + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_rng + ENTRY Rng { + src0 = f32[] parameter(0) + src1 = f32[] parameter(1) + ROOT rng0 = f32[1234567,2]{1,0} rng(f32[] src0, f32[] src1), + distribution=rng_uniform + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_infeed_outfeed + ENTRY InfeedOutfeed { + infeed0 = u32[12345678,2]{1,0} infeed() + ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h index 39e20ed45639040110b99ddb52eb6f6dab26dfaa..7337c907f5c83d608641b7382e75902e6f6c05d4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index bff57d33ae23fbba8c664cbd18df77e4c35eb592..39b13183ff093611a42b3931d45f64eadb420622 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -63,30 +63,41 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims); } +template +void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { + if (m == 1 || n == 1) { + // Despite being single threaded, this version of matrix * vector is faster. + xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } else { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); + } +} + } // namespace +void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, + Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - // Despite being single threaded, this version of matrix * vector is faster. - xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - // Despite being single threaded, this version of matrix * vector is faster. - xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h index fdb644651dd5d0fa0345580f52ed0fb051672285..d96fe3d58bd5ffbad347e3ede3534d1d47be697a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { @@ -25,6 +26,12 @@ extern "C" { // order. 'out' is a pointer to a buffer sufficiently large to hold the result // of the operation. Following standard nomenclature: lhs is m x k, // rhs is k x n, and out is m x n. +extern void __xla_cpu_runtime_EigenMatMulF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenMatMulF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc b/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc deleted file mode 100644 index 435820cdd36e2a906d9dfbe2555f4c0df623c729..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" - -using tensorflow::int32; -using tensorflow::int64; - -namespace { - -// Does mat * x or mat^T * x. -template -void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols, - int32 transpose) { - // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to - // be faster (b/30223679). See also: the matmul op kernel in TensorFlow, - // which implements the same optimization. - using Matrix = Eigen::Matrix; - using MatrixMap = Eigen::Map; - - using Vector = Eigen::Matrix; - using VectorMap = Eigen::Map; - - auto x = VectorMap(x_buf, cols); - auto out = VectorMap(out_buf, rows); - - int64 mat_rows = rows; - int64 mat_cols = cols; - - if (transpose) { - std::swap(mat_rows, mat_cols); - } - - auto mat = MatrixMap(mat_buf, mat_rows, mat_cols); - - if (transpose) { - out = mat.transpose() * x; - } else { - out = mat * x; - } -} - -// Converts matmul-style args to matvec. -template -void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, - int32 transpose_lhs, int32 transpose_rhs) { - // If the input is in the form x * A, where x is the vector, then bring A back - // over to the left hand side. We make use of the identity - // - // (x * A)^T = A^T * x^T - // - // We do not need to take the transpose of x or of the result since taking - // the transpose of a vector does not change the memory layout. - const int64 cols = k; - - T* mat; - T* vec; - int64 rows; - bool transpose_mat; - - bool is_mat_vec = (n == 1); - - if (is_mat_vec) { - mat = lhs; - vec = rhs; - rows = m; - transpose_mat = transpose_lhs; - } else { - mat = rhs; - vec = lhs; - rows = n; - transpose_mat = !transpose_rhs; - } - - MatVec(out, mat, vec, rows, cols, transpose_mat); -} - -} // namespace - -namespace xla { - -void EigenMatVecF32(float* out, float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { - assert((m == 1 || n == 1) && "not a matrix-vector multiply"); - DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -void EigenMatVecF64(double* out, double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { - assert((m == 1 || n == 1) && "not a matrix-vector multiply"); - DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h index 1bd8dfb377acc1f7cfbe9a92773f87f0ef25de3a..70eb98c54169824e220d9287753c0849362eade6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h @@ -16,10 +16,86 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#include "third_party/eigen3/Eigen/Core" + #include "tensorflow/core/platform/types.h" namespace xla { +namespace detail { + +using tensorflow::int32; +using tensorflow::int64; + +// Does mat * x or mat^T * x. +template +void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols, + int32 transpose) { + // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to + // be faster (b/30223679). See also: the matmul op kernel in TensorFlow, + // which implements the same optimization. + using Matrix = Eigen::Matrix; + using MatrixMap = Eigen::Map; + + using Vector = Eigen::Matrix; + using VectorMap = Eigen::Map; + + auto x = VectorMap(x_buf, cols); + auto out = VectorMap(out_buf, rows); + + int64 mat_rows = rows; + int64 mat_cols = cols; + + if (transpose) { + std::swap(mat_rows, mat_cols); + } + + auto mat = MatrixMap(mat_buf, mat_rows, mat_cols); + + if (transpose) { + out = mat.transpose() * x; + } else { + out = mat * x; + } +} + +// Converts matmul-style args to matvec. +template +void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, + int32 transpose_lhs, int32 transpose_rhs) { + // If the input is in the form x * A, where x is the vector, then bring A back + // over to the left hand side. We make use of the identity + // + // (x * A)^T = A^T * x^T + // + // We do not need to take the transpose of x or of the result since taking + // the transpose of a vector does not change the memory layout. + const int64 cols = k; + + T* mat; + T* vec; + int64 rows; + bool transpose_mat; + + bool is_mat_vec = (n == 1); + + if (is_mat_vec) { + mat = lhs; + vec = rhs; + rows = m; + transpose_mat = transpose_lhs; + } else { + mat = rhs; + vec = lhs; + rows = n; + transpose_mat = !transpose_rhs; + } + + MatVec(out, mat, vec, rows, cols, transpose_mat); +} + +} // namespace detail + // Performs a matrix-vector multiplication using Eigen. 'lhs' and 'rhs' are // pointers to buffers containing input matrices in column-major order. 'out' is // a pointer to a buffer sufficiently large to hold the result of the @@ -30,15 +106,15 @@ namespace xla { // // TODO(b/64684907): Compare runtime performance of these functions with dot // simplification. -void EigenMatVecF32(float* out, float* lhs, float* rhs, tensorflow::int64 m, - tensorflow::int64 n, tensorflow::int64 k, - tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); - -void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m, - tensorflow::int64 n, tensorflow::int64 k, - tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); +template +void EigenMatVec(T* out, T* lhs, T* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + assert((m == 1 || n == 1) && "not a matrix-vector multiply"); + detail::DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h index f216bd0152aa93b8753d881938c63a9cabea899b..44b201725b2c724f48c1a3f0373c41e76211e0c2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index ee8eb081556d60fcf6537b1036a4a5825c4c7bf6..17303e2f0d34e531a3a56aa147608b949e0f43ae 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -57,26 +57,38 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, C = A.contract(B, dims); } +template +void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + if (m == 1 || n == 1) { + xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } else { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); + } +} + } // namespace +void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h index 029eb9514287d8c69cde2cfb06e0d56e78d6f165..82a1fcce594fa5b04f4fe459870991863c32a91a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { @@ -25,6 +26,12 @@ extern "C" { // 'out' is a pointer to a buffer sufficiently large to hold the result of the // operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and // out is m x n. +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 61b408b8c24dded134218110d4e219c31f1685a8..42fe955f1917e0268dc739e44fbd0a7afb39185c 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -20,12 +20,13 @@ namespace cpu { std::vector ShapePartitionAssigner::Run(int64 target_partition_count) { // Gather outer-most dims where dim_size >= 'target_partition_count'. - // Note: always leave inner-dim static for vectorization/optimizations. + // This may include the inner-dim as LLVM can vectorize loops with dynamic + // bounds. std::vector outer_dims; int64 outer_dim_size = 1; // TODO(b/27458679) Consider reserving enough minor dimensions (based on // target vector register width) to enable vector instructions. - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { const int64 dimension = shape_.layout().minor_to_major(i); outer_dims.push_back(dimension); outer_dim_size *= shape_.dimensions(dimension); diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ee0c53fa6d7c41481a53350e57e5844dea2644c1..ae80a6f4977f85cfd9f872734fd0a69432a1f382 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -30,105 +30,65 @@ class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; - void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + void RunR2Test(const Shape& shape, int64 max_target_partition_count, + const std::vector* expected_partitions) { ShapePartitionAssigner assigner(shape); - // Check all partitions of outer dimension. - for (int64 i = 1; i <= expected_max_partition_count; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), - assigner.Run(/*target_partition_count=*/i))); + // Iterate through 1..max_target_partition_count. + for (int64 i = 1; i <= max_target_partition_count; ++i) { + std::vector actual_partitions = + assigner.Run(/*target_partition_count=*/i); + EXPECT_THAT(actual_partitions, expected_partitions[i - 1]); } - // Check target_partition_count > outer dimension size. - EXPECT_TRUE(ContainersEqual( - Vec({expected_max_partition_count}), - assigner.Run( - /*target_partition_count=*/expected_max_partition_count + 1))); } }; TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); + std::vector expected_partitions[] = {{1} /* 1 */, {1, 2} /* 2 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); + std::vector expected_partitions[] = { + {1} /* 1 */, {1, 2} /* 2 */ + }; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); + std::vector expected_partitions[] = {{1} /* 1 */, {2} /* 2 */, + {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 6, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 4, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 5; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {4, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {4, 3} /* 12 */, + {4, 3} /* 13 */, {4, 3} /* 14 */, {5, 3} /* 15 */, {4, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}), 16, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 3; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */, + {2, 2} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {3, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {3, 4} /* 12 */, + {3, 4} /* 13 */, {3, 4} /* 14 */, {3, 5} /* 15 */, {3, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}), 16, + expected_partitions); } class ShapePartitionIteratorTest : public HloTestBase { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index e8a375d63791cd9a94f77af4ef5e74d2cb7e4361..80c24eaccfc2a83f8f3f311d60860715668d0c08 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -181,10 +181,12 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index c732974995f70d9ba1b46e18aa4cc2c6ab467182..b6a0903b0eeaa04d8bc1488378c148b2016c5d48 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1522,15 +1522,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kBroadcast: return [this, hlo, &operand_to_generator]( const IrArray::Index& target_index) -> StatusOr { + const HloInstruction* operand = hlo->operand(0); // The `dimensions` member of the broadcast instruction maps from // input dimensions to output dimensions. - const HloInstruction* operand = hlo->operand(0); - int64 rank = ShapeUtil::Rank(operand->shape()); - IrArray::Index source_index(rank); - for (int64 i = 0; i < rank; ++i) { - source_index[i] = target_index[hlo->dimensions(i)]; - } - return operand_to_generator.at(operand)(source_index); + return operand_to_generator.at( + operand)(target_index.SourceIndexOfBroadcast( + hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_)); }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( @@ -1722,6 +1719,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(if_data.after_block, ir_builder_); return ir_builder_->CreateLoad(ret_value_addr); }; + case HloOpcode::kBitcast: + CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), + ShapeUtil::ElementsIn(hlo->operand(0)->shape())); + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + const HloInstruction* operand = hlo->operand(0); + return operand_to_generator.at(operand)(index.SourceIndexOfBitcast( + hlo->shape(), operand->shape(), ir_builder_)); + }; case HloOpcode::kReshape: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), ShapeUtil::ElementsIn(hlo->operand(0)->shape())); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 90481c7a88f90edea5399ee44aee2d2c77fc115f..be92b1629a2d8dae57b315751bd4f7f9ccddf171 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..221ff7900f398166c193c495848a2afcfd4edc81 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -0,0 +1,392 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +using tensorflow::gtl::ArraySlice; + +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* gather_indices, int64 index_vector_dim) { + const Shape& gather_indices_shape = gather_indices->shape(); + if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { + return gather_indices; + } + std::vector permutation; + permutation.reserve(gather_indices_shape.dimensions_size()); + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(gather_indices, permutation); +} + +// If the gather_indices holds scalar indices (i.e. gather_indices has rank N +// and index_vector_dim is N) then reshape it to have a trailing degenerate +// dimension. This makes the code for slicing out the index vector more +// uniform. +static StatusOr DeScalarizeGatherIndices( + HloInstruction* gather_indices, int64 index_vector_dim) { + const Shape& gather_indices_shape = gather_indices->shape(); + if (index_vector_dim != gather_indices_shape.dimensions_size()) { + return gather_indices; + } + + DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size()); + + std::vector result_shape_dims; + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(result_shape_dims)); + result_shape_dims.push_back(1); + + return MakeReshapeHlo(result_shape_dims, gather_indices); +} + +// Canonicalizes the gather_indices tensors so that we only have deal with some +// specific cases in the while loop that does the heavy lifting. +// +// See the "High Level Algorithm" section for a broader picture. +static StatusOr CanonicalizeGatherIndices( + HloInstruction* gather_indices, int64 index_vector_dim) { + // If gather_indices holds scalar indices, normalize it to hold index vectors + // of size 1. + TF_ASSIGN_OR_RETURN( + HloInstruction * descalarized_gather_indices, + DeScalarizeGatherIndices(gather_indices, index_vector_dim)); + + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN(HloInstruction * transposed_gather_indices, + TransposeIndexVectorDimToLast(descalarized_gather_indices, + index_vector_dim)); + + // If there is only one index (i.e. gather_indices has rank 1 and this gather + // is really just a dynamic slice) add a leading degenerate dimension for + // uniformity. Otherwise create a "collapsed" leading dimension that subsumes + // all of the non-index-vector dimensions. + const Shape& shape = transposed_gather_indices->shape(); + if (shape.dimensions_size() == 1) { + return ExpandFirstDimIntoNDims(transposed_gather_indices, + {1, shape.dimensions(0)}); + } else { + return CollapseFirstNDims(transposed_gather_indices, + shape.dimensions_size() - 1); + } +} + +// Expands out or contracts away the gather dimensions in the accumulator +// produced by the while loop. +static StatusOr AdjustGatherDimsInAccumulator( + const Shape& gather_indices_shape, HloInstruction* accumulator, + int64 index_vector_dim) { + std::vector output_gather_dim_bounds; + output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + } + } + + if (output_gather_dim_bounds.empty()) { + // If output_gather_dim_bounds is empty we must be lowering a (effectively) + // dynamic-slice. In that case, there is a leading degenerate gather + // dimension that we added to make this special case play well with the + // general while loop which we need to remove now. + CHECK_EQ(accumulator->shape().dimensions(0), 1); + ArraySlice reshaped_dim_sizes = + AsInt64Slice(accumulator->shape().dimensions()); + reshaped_dim_sizes.remove_prefix(1); + return MakeReshapeHlo(reshaped_dim_sizes, accumulator); + } + + return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); +} + +// Expand an index vector from the gather_indices tensor into a vector that can +// be used to dynamic-slice out of the gather operand. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.gather_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// This generates the body of the while that implements the main data movement +// behavior of gather using dynamic-slice and dynamic-update-slice. +static StatusOr> GatherLoopBody( + const HloInstruction& gather, HloInstruction* induction_var, + const std::vector& incoming_loop_state) { + CHECK_EQ(incoming_loop_state.size(), 3); + HloInstruction* const operand = incoming_loop_state[0]; + HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const output_accumulator = incoming_loop_state[2]; + + int64 index_vector_size = gather_indices->shape().dimensions(1); + + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_gather_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + {1, index_vector_size})); + + TF_ASSIGN_OR_RETURN(HloInstruction * index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + + TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_start, + ExpandIndexVectorIntoOperandSpace( + index_vector, gather.gather_dimension_numbers(), + operand->shape().dimensions_size())); + + TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, + MakeDynamicSliceHlo(operand, gathered_slice_start, + gather.gather_window_bounds())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * gathered_slice_for_update, + ExpandFirstDimIntoNDims(gathered_slice, + {1, gathered_slice->shape().dimensions(0)})); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_into_accumulator, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/gathered_slice->shape().dimensions_size())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_accumulator, + MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, + index_vector_into_accumulator)); + + // New loop state -- only the accumulator has changed. The + // WhileUtil::MakeCountedLoop functions takes care of the induction variable + // and the while loop exit condition. + return StatusOr>{ + {operand, gather_indices, updated_accumulator}}; +} + +static StatusOr CreateGatherLoopAccumulatorInitValue( + HloComputation* computation, PrimitiveType element_type, + ArraySlice window_bounds, int64 gather_loop_trip_count) { + std::vector accumulator_state_shape_dims; + accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.push_back(gather_loop_trip_count); + c_copy(window_bounds, std::back_inserter(accumulator_state_shape_dims)); + return BroadcastZeros(computation, element_type, + accumulator_state_shape_dims); +} + +static StatusOr ElideWindowDimsFromAccumulator( + HloInstruction* accumulator, const GatherDimensionNumbers& dim_numbers) { + std::vector dims_to_elide; + dims_to_elide.reserve(dim_numbers.elided_window_dims_size()); + for (int64 elided_window_dim : dim_numbers.elided_window_dims()) { + dims_to_elide.push_back(elided_window_dim + 1); + } + + return ElideDegenerateDims(accumulator, dims_to_elide); +} + +// `accumulator` is almost the tensor the gather operation would have produced, +// except that it has the dimensions in the wrong order -- the gather dimensions +// are the major dimensions and the window dimensions are the minor dimensions. +// Fix this up with a transpose. +static StatusOr PermuteGatherAndWindowDims( + HloInstruction* accumulator, ArraySlice output_window_dims, + int64 output_rank) { + std::vector permutation; + permutation.reserve(output_rank); + + int64 gather_idx_counter = 0; + int64 window_idx_counter = output_rank - output_window_dims.size(); + for (int64 i = 0; i < output_rank; i++) { + bool is_window_dim = c_binary_search(output_window_dims, i); + if (is_window_dim) { + permutation.push_back(window_idx_counter++); + } else { + permutation.push_back(gather_idx_counter++); + } + } + + return MakeTransposeHlo(accumulator, permutation); +} + +// High Level Algorithm +// +// We follow the following steps in sequence: +// +// 1. We canonicalize the gather_indices tensor such that it has rank +// 2 (i.e. is a matrix) where each row is an index vector into the +// operand. +// 2. We iterate over the set of indices in the canonicalized +// gather_indices tensor using a while loop, accumulating slices +// of the operand tensor into an accumulator using +// DynamicUpdateSlice. +// 3. The accumulator result from the while loop from (2) is then +// reshaped to split out all the individual gather dimensions and +// then transposed to give the final result. +// +// As an example, if we started with the following operation: +// +// HloModule TensorFlowGatherMultipleBatchDims +// +// ENTRY main { +// operand = s32[3,3] parameter(0) +// indices = s32[2,2] parameter(1) +// ROOT gather = s32[2,3,2] gather(operand, indices), +// output_window_dims={1}, +// elided_window_dims={1}, +// gather_dims_to_operand_dims={1}, +// index_vector_dim=2, +// window_bounds={3, 1} +// } +// +// We'd first reshape indices to s32[4,1], where each row is an index +// into operand. We'd then run a loop to slice out 4 tensors of shape +// [3,1] out of operand into an accumulator of shape [4,3,1]. We then +// reshape this result to [2,2,3] and finally transpose it to [2,3,2]. + +StatusOr GatherExpander::ExpandGather( + HloInstruction* gather_instr) { + CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + + HloComputation* computation = gather_instr->parent(); + HloInstruction* operand = gather_instr->mutable_operand(0); + HloInstruction* gather_indices = gather_instr->mutable_operand(1); + const Shape& gather_indices_shape = gather_indices->shape(); + const Shape& output_shape = gather_instr->shape(); + int64 output_rank = output_shape.dimensions_size(); + + const GatherDimensionNumbers& dim_numbers = + gather_instr->gather_dimension_numbers(); + + int64 gather_loop_trip_count = 1; + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + gather_loop_trip_count *= gather_indices_shape.dimensions(i); + } + } + + if (!IsInt32(gather_loop_trip_count)) { + return Unimplemented( + "Gather operations with more than 2147483647 gather indices are not " + "supported. This error occurred for %s.", + gather_instr->ToString().c_str()); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, + CanonicalizeGatherIndices( + gather_indices, dim_numbers.index_vector_dim())); + + CHECK_EQ(gather_loop_trip_count, + canonical_gather_indices->shape().dimensions(0)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_init, + CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_window_bounds(), gather_loop_trip_count)); + + StatusOr> gather_loop_result_or_error = + WhileUtil::MakeCountedLoop( + computation, gather_loop_trip_count, + {operand, canonical_gather_indices, accumulator_init}, + [&](HloInstruction* indvar, + const std::vector& loop_state) { + return GatherLoopBody(*gather_instr, indvar, loop_state); + }); + + TF_ASSIGN_OR_RETURN(std::vector gather_loop_result, + gather_loop_result_or_error); + + HloInstruction* accumulator_result = gather_loop_result.back(); + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_with_window_dims_elided, + ElideWindowDimsFromAccumulator(accumulator_result, dim_numbers)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_with_output_gather_dims_decanonicalized, + AdjustGatherDimsInAccumulator(gather_indices->shape(), + accumulator_with_window_dims_elided, + dim_numbers.index_vector_dim())); + + return PermuteGatherAndWindowDims( + accumulator_with_output_gather_dims_decanonicalized, + AsInt64Slice(dim_numbers.output_window_dims()), output_rank); +} + +StatusOr GatherExpander::Run(HloModule* module) { + auto is_nontrivial_gather = [](HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::HasZeroElements(inst->shape()); + }; + + std::vector gather_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), + is_nontrivial_gather); + } + + for (HloInstruction* inst : gather_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandGather(inst)); + TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + } + + return !gather_instrs.empty(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..c1fc8574da99fff223c7dbb570b4533f76905b9a --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass rewrites gather operations into (roughly) while loops of dynamic +// slices. This lets backends that don't support gather directly to +// nevertheless have a minimum level of support. +class GatherExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "gather_expander"; } + StatusOr Run(HloModule* module) override; + + private: + StatusOr ExpandGather(HloInstruction* gather_instr); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba41ee8428cbe7132103df24d552565a8dc2f9f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { +TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2147483647,5] parameter(1) + ROOT gather = s32[2147483647,3,5] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + Status status = GatherExpander{}.Run(module.get()).status(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + + ASSERT_THAT( + status.error_message(), + ::testing::HasSubstr("Gather operations with more than 2147483647 gather " + "indices are not supported.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 78dc0ad4fcd167c93f19d0c2b18ea72d666897ef..a99e2b7794a399047fb5a77a140bd333214e3f23 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -38,14 +38,7 @@ namespace xla { GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size) - : platform_id_(platform_id), pointer_size_(pointer_size) { - // We currently only support kHostPlatformId for CPU, kCudaPlatformId for - // GPU and kInterpreterPlatformId for Interpreter. Before supporting other - // platforms, we need to test this transfer manager on them. - CHECK(platform_id_ == se::host::kHostPlatformId || - platform_id_ == se::interpreter::kInterpreterPlatformId || - platform_id_ == se::cuda::kCudaPlatformId); -} + : platform_id_(platform_id), pointer_size_(pointer_size) {} se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9da4fb97fa27a238fead74985cb481a9be1f4a65..93b2f2a4748932e50ce40e8a2f573af922dea8d1 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -241,6 +241,7 @@ cc_library( "gpu_executable.cc", "infeed_thunk.cc", "kernel_thunk.cc", + "memset_thunk.cc", "sequential_thunk.cc", "thunk_schedule.cc", "tuple_thunk.cc", @@ -257,6 +258,7 @@ cc_library( "gpu_executable.h", "infeed_thunk.h", "kernel_thunk.h", + "memset_thunk.h", "sequential_thunk.h", "thunk.h", "thunk_schedule.h", @@ -273,6 +275,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -293,6 +296,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep + "//tensorflow/stream_executor", ], ) @@ -397,6 +401,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -437,8 +442,10 @@ tf_cc_test( ":fusion_merger", ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -452,6 +459,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", ], @@ -510,9 +518,11 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index c137fbc97e29e24edb3603c611a5c8f093bc62a6..3cd30b754c3242f00c704de1afab2282ed827b41 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -45,6 +45,7 @@ void MaybeResolveTupleElements(HloInstruction* instruction, // Returns the bytes read by fusion parameter 'param', by returning the byte // size of 'param' shape (or the cumulative byte sizes of all leaf tuple // elements if 'param' is tuple-shaped). +// // In the special case where all users of 'param' (or all users of a leaf // tuple element if 'param' is tuple-shaped) are Slice instructions, the size // of each slice instruction is accumulated instead, to give a more accurate @@ -63,11 +64,10 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (std::all_of(instruction->users().begin(), instruction->users().end(), - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -199,6 +199,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++total_visited_; // Skip 'fusion' instruction if there are no users into which we can merge. if (fusion->users().empty()) { + VLOG(3) << "Not merging " << fusion->name() << ": Has no users."; ++num_fail_no_users_; return Status::OK(); } @@ -208,24 +209,27 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Input fusion instructions need to be rooted at a particular HLO (e.g. // kReduce), so they shouldn't be further fused either. if (fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + VLOG(3) << "Not merging " << fusion->name() << ": Is not loop fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); } // Skip multiple output fusion. It's not yet supported. if (fusion->IsMultiOutputFusion()) { + VLOG(3) << "Not merging " << fusion->name() << ": Is multi-output fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); } // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!std::all_of(fusion->users().begin(), fusion->users().end(), - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == - HloInstruction::FusionKind::kLoop; - })) { + if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kFusion && + (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Some of its users are not loop/input fusion kernels."; ++num_fail_merge_all_users_; return Status::OK(); } @@ -233,18 +237,17 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if any of its fused instructions are expensive. // This is done to avoid the duplication of expensive instructions, which // would occur if 'fusion' were merged into multiple users. + // // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (!std::all_of(fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), - [](const HloInstruction* instruction) { - if (instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction)) { - return false; - } - return true; - })) { + if (c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; return Status::OK(); } @@ -253,6 +256,8 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // exceeds the threshold value. if (CalculateFlopsToBytesRatio(fusion) > FusionMerger::GetThresholdFlopsToBytesRatio()) { + VLOG(3) << "Not merging " << fusion->name() + << ": flops-to-bytes ratio is not favorable."; ++num_fail_flops_to_byte_ratio_; return Status::OK(); } @@ -265,6 +270,9 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { const double merged_to_current_bytes_ratio = merged_bytes_transferred / std::max(1.0, current_bytes_transferred); if (merged_to_current_bytes_ratio > 1.10) { + VLOG(3) << "Not merging " << fusion->name() + << ": merged-to-current-bytes-ratio of " + << merged_to_current_bytes_ratio << " is not favorable."; ++num_fail_net_bytes_transferred_ratio_; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index deef5966b80d1b7f16e9982eed9ac5c7131e9d73..2217776c7d5a5f92c520d56222988f80401be9e4 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -16,257 +16,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { namespace { -class FusionMergerTest : public HloTestBase { - protected: - FusionMergerTest() : module_(CreateNewModule()) {} - - // Builds the following computation: - // - // Param - // / | \ - // / | \ - // OnesVec GTE(0) GTE(1) GTE(2) - // \ / \ / - // Add Add OnesVec - // \ / \ / - // \ Add Mul OnesVec - // \ | | / - // \ Mul Add - // \ | / - // \ | / - // Tuple - // - HloComputation* BuildComputation0() { - auto builder = HloComputation::Builder(TestName() + ".Computation0"); - // Create param instruction to access computation state. - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape3_, "param")); - - // Create GetTupleElement instructions for each tuple element. - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 1)); - auto gte2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 2)); - - // Create const vector of ones to be used in element-wise computations. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // Create simple fusable computation for tuple element 0 (wont get merged). - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, one_vec, gte0)); - - // Create fusable computation which is dependent on second and third tuple - // elements (will initially be fused on its own). - auto add1 = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte1, gte2)); - - // Create two sub-computations, both of which are users of 'add1'. - - // First sub-computation: out1 = Mul(Add(add1, one_vec), one_vec) - auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, add1, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add2, one_vec)); - - // Second sub-computation: out2 = Add(Mul(add1, one_vec), one_vec) - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add1, one_vec)); - auto out2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1, out2})); - return module_->AddEntryComputation(builder.Build()); - } - - // Builds the following computation: - // - // Param - // / \ - // GTE(0) GTE(1) - // | | \ / - // | | Mul - // \ \ | - // \ Mul - // \ | - // OnesVec Mul OnesVec - // \ / \ / - // OnesVec Add Mul OnesVec - // \ | | / - // Mul Add - // \ / - // \ / - // Tuple - // - HloComputation* BuildComputation1() { - auto builder = HloComputation::Builder(TestName() + ".Computation1"); - Shape tuple_shape2_ = ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); - // Create param instruction to access computation state. - auto state = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape2_, "state")); - - // Create shared sub-computation (will initially be fused on its own). - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 2)); - // Calculate the flops we need to generate for this shared computation - // to exceed the threshold flops_to_bytes_ratio. - // Note that bytes transferred is multiplied by 3 because there are two - // operands and one output of size 'data_shape_'. - const int64 flops_needed = FusionMerger::GetThresholdFlopsToBytesRatio() * - ShapeUtil::ByteSizeOf(data_shape_) * 3; - const int64 vec_elements = ShapeUtil::ElementsIn(data_shape_); - const int64 iters = (flops_needed + vec_elements - 1) / vec_elements; - - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, gte0, gte1)); - for (int i = 0; i < iters; ++i) { - mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, gte0, mul0)); - } - - // Create two sub-computations, both of which are users of 'mul0'. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // First sub-computation: out0 = Mul(Add(mul0, one_vec), one_vec) - auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add0, one_vec)); - - // Second sub-computation: out1 = Add(Mul(mul0, one_vec), one_vec) - auto mul1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, mul0, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul1, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_->AddEntryComputation(builder.Build()); - } - - // Builds the following computation: - // - // Param - // / | | \ - // / | | \ - // / | | \ - // GTE(0) GTE(1) GTE(2) GTE(3) - // \ / / / - // Add / / - // \ / / - // Add / - // \ / - // \ / - // OnesVec Add OnesVec - // \ / \ / - // OnesVec Add Mul OnesVec - // \ | | / - // Mul Add - // \ / - // \ / - // Tuple - // - HloComputation* BuildComputation2(bool add_extra_input) { - auto builder = HloComputation::Builder(TestName() + ".Computation2"); - Shape state_shape = add_extra_input ? tuple_shape4_ : tuple_shape3_; - // Create param instruction to access computation state. - auto state = builder.AddInstruction( - HloInstruction::CreateParameter(0, state_shape, "state")); - - // Create GetTupleElement instructions for each tuple element. - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 1)); - auto gte2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 2)); - - // Create shared fusable computation that reduces its operands. - auto reduce0 = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1)); - auto reduce_out = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce0, gte2)); - if (add_extra_input) { - auto gte3 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 3)); - reduce_out = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce_out, gte3)); - } +namespace op = xla::testing::opcode_matchers; - // Create two fusable sub-computations which are dependent on shared - // computation 'reduce_out'. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // First sub-computation: out0 = Mul(Add(reduce_out, one_vec), one_vec) - auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce_out, one_vec)); - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add2, one_vec)); - - // Second sub-computation: out1 = Add(Mul(reduce_out, one_vec), one_vec) - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, reduce_out, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_->AddEntryComputation(builder.Build()); - } - - Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); - Shape tuple_shape2_ = ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); - Shape tuple_shape3_ = - ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); - Shape tuple_shape4_ = ShapeUtil::MakeTupleShape( - {data_shape_, data_shape_, data_shape_, data_shape_}); - - std::unique_ptr module_; -}; +class FusionMergerTest : public HloTestBase {}; // Tests that we can merge a fusion instruction that is below threshold. // -// Original computation: -// -// Param -// / | \ -// / | \ -// OnesVec GTE(0) GTE(1) GTE(2) -// \ / \ / -// Add Add OnesVec -// \ / \ / -// \ Add Mul OnesVec -// \ | | / -// \ Mul Add -// \ | / -// \ | / -// Tuple -// -// Computation after fusion passes: -// -// Param -// / \ -// Fusion3 Fusion2 -// | / \ -// \ Fusion0 Fusion1 -// \ | / -// \ | / -// Tuple -// // Computation after fusion merger pass (Fusion2 is merged into Fusion0 and // Fusion1): // Param @@ -276,19 +40,50 @@ class FusionMergerTest : public HloTestBase { // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto computation = BuildComputation0(); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); - // Run fusion merger pass, which should merge the shared fusion instruction - // into its two users. - EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); - - auto* root = computation->root_instruction(); + auto module = tools::Parse(R"( +HloModule MergeSharedFusionInstruction + +comp.3 { + constant.param_0 = f32[4]{0} parameter(0) + param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1) + get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0 + ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6) +} + +comp.2 { + param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1 + get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2 + ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5) +} + +comp.1 { + add.1.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3) +} + +comp { + add.1.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1) + ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY MergeSharedFusionInstruction.Computation0 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3 + fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2 + fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1 + fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6) +})") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + + auto* root = module->entry_computation()->root_instruction(); EXPECT_EQ(HloOpcode::kTuple, root->opcode()); // Check operand 0 (not merged). Should have 4 instructions. auto* operand0 = root->operand(0); @@ -307,156 +102,188 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { // Tests that we do not merge a fusion instruction that above flops to bytes // threshold. // -// Original computation: -// -// Param -// / \ -// GTE(0) GTE(1) -// | | \ / -// | | Mul -// \ \ | -// \ Mul -// \ | -// OnesVec Mul OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ | | / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes and fusion merger pass (Fusion2 is not -// merged because it exceeds the threshold flops to bytes ratio). -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - BuildComputation1(); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = tools::Parse(R"( +HloModule FlopsToBytesRatioThresholdExceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.3 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.4 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + multiply.29 = f32[4]{0} multiply(get-tuple-element.3, get-tuple-element.4) + multiply.30 = f32[4]{0} multiply(get-tuple-element.3, multiply.29) + multiply.31 = f32[4]{0} multiply(get-tuple-element.3, multiply.30) + multiply.32 = f32[4]{0} multiply(get-tuple-element.3, multiply.31) + multiply.33 = f32[4]{0} multiply(get-tuple-element.3, multiply.32) + multiply.34 = f32[4]{0} multiply(get-tuple-element.3, multiply.33) + multiply.35 = f32[4]{0} multiply(get-tuple-element.3, multiply.34) + multiply.36 = f32[4]{0} multiply(get-tuple-element.3, multiply.35) + multiply.37 = f32[4]{0} multiply(get-tuple-element.3, multiply.36) + multiply.38 = f32[4]{0} multiply(get-tuple-element.3, multiply.37) + multiply.39 = f32[4]{0} multiply(get-tuple-element.3, multiply.38) + multiply.40 = f32[4]{0} multiply(get-tuple-element.3, multiply.39) + ROOT multiply.41 = f32[4]{0} multiply(get-tuple-element.3, multiply.40) +} + +comp.1 { + multiply.12.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.3 = f32[4]{0} add(multiply.12.param_1.1, constant.param_1.3) + ROOT multiply.16 = f32[4]{0} multiply(add.3, constant.param_1.3) +} + +comp { + multiply.12.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.15 = f32[4]{0} multiply(multiply.12.param_1, constant.param_1.1) + ROOT add.2 = f32[4]{0} add(multiply.15, constant.param_1.1) +} + +ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the flops/bytes of the // shared fusion instruction exceeds the threshold ratio, and therefore // cannot be merged with other fusion instructions. - EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is exceeded. // -// Original computation: -// -// Param -// / | | \ -// / | | \ -// / | | \ -// GTE(0) GTE(1) GTE(2) GTE(3) -// \ / / / -// Add / / -// \ / / -// Add / -// \ / -// \ / -// OnesVec Add OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ | | / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes and fusion merger pass. Fusion2 is not -// merged because it exceeds the threshold bytes transferred. This is because -// the bytes read by Fusion2 (when replicated if the instruction is merged -// into Fusion0 and Fusion1) would exceed the bytes transferred threshold. -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is not merged because it exceeds the threshold bytes transferred. +// This is because the bytes read by Fusion2 (when replicated if the instruction +// is merged into Fusion0 and Fusion1) would exceed the bytes transferred +// threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - BuildComputation2(/*add_extra_input=*/true); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = tools::Parse(R"( +HloModule BytesTransferredThresholdExeceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.8 = f32[4]{0} get-tuple-element(state.param_1.1), index=1 + add.9 = f32[4]{0} add(get-tuple-element.7, get-tuple-element.8) + get-tuple-element.9 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + add.10 = f32[4]{0} add(add.9, get-tuple-element.9) + get-tuple-element.10 = f32[4]{0} get-tuple-element(state.param_1.1), index=3 + ROOT add.11 = f32[4]{0} add(add.10, get-tuple-element.10) +} + +comp.1 { + add.2.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.6 = f32[4]{0} add(add.2.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.6, constant.param_1.3) +} + +comp { + add.2.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.2.param_1, constant.param_1.1) + ROOT add.5 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY BytesTransferredThresholdExeceeded.Computation2 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would increase. - EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is not exceeded. // -// Original computation: -// -// Param -// / | \ -// / | \ -// / | \ -// GTE(0) GTE(1) GTE(2) -// \ / / -// Add / -// \ / -// OnesVec Add OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ / \ / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes: -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// -// Computation after fusion merger pass (Fusion2 is merged into Fusion0 and -// Fusion1, because bytes read from Param by Fusion2 is reduced for this test -// which makes the merge operation into its operand below the bytes -// transferred threshold. -// -// Param -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is merged into Fusion0 and Fusion1, because bytes read from Param by +// Fusion2 is reduced for this test which makes the merge operation into its +// operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - BuildComputation2(/*add_extra_input=*/false); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = tools::Parse(R"( +HloModule BytesTransferredThresholdNotExeceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1 + add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6) + get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7) +} + +comp.1 { + add.1.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3) +} + +comp { + add.1.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1) + ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would not increase. - EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); +} + +// Check that we're willing to merge f1_computation into f2_computation, even +// though f2 is an input fusion node. +TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { + auto module = tools::Parse(R"( + HloModule m + + f1_computation { + f1_p0 = f32[10]{0} parameter(0) + ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0) + } + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + f2_computation { + f2_p0 = f32[10]{0} parameter(0) + f2_mul = f32[10]{0} multiply(f2_p0, f2_p0) + f2_zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10]{0} parameter(0) + f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation + ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Fusion(op::Parameter())); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index ba482793e7632f0f423cc9da0dd9620bdf29c642..38668ff455a44c7ef99b57b750f1a3b18a90bd2c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -49,7 +49,7 @@ struct MatrixDescriptor { // rhs_matrix, and stores the result to output_matrix. template bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); se::DeviceMemory lhs_data(lhs_matrix.data); @@ -65,7 +65,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, return stream ->ThenBlasGemm( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, &output_data, /*leading dim of output=*/output_matrix.num_rows) @@ -89,7 +89,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, template bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, + MatrixDescriptor output_matrix, double alpha, se::blas::ComputationType computation_type, se::blas::AlgorithmType algorithm, se::Stream* stream, se::blas::ProfileResult* output_profile_result) { @@ -108,11 +108,13 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, return stream ->ThenBlasGemmWithAlgorithm( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows, - computation_type, algorithm, output_profile_result) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/static_cast(alpha), lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, + /*beta=*/static_cast(0.0f), &output_data, + /*leading dim of output=*/output_matrix.num_rows, computation_type, + algorithm, output_profile_result) .ok(); } @@ -125,8 +127,8 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, template StatusOr DoGemmAutotune( MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::blas::ComputationType computation_type, - se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, + se::blas::ComputationType computation_type, se::Stream* stream) { std::vector algorithms; CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); @@ -138,8 +140,8 @@ StatusOr DoGemmAutotune( // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - computation_type, algorithm, stream, - &profile_result)); + alpha, computation_type, algorithm, + stream, &profile_result)); if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { @@ -161,6 +163,8 @@ StatusOr DoGemmAutotune( // DoGemm/DoGemmWithAlgorithm/DoGemmAutotune. auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { switch (type) { + case F16: + return &DoGemm; case F32: return &DoGemm; case F64: @@ -172,6 +176,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { auto GetGemmWithAlgorithmFn(PrimitiveType type) -> decltype(&DoGemmWithAlgorithm) { switch (type) { + case F16: + return &DoGemmWithAlgorithm; case F32: return &DoGemmWithAlgorithm; case F64: @@ -182,6 +188,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) } auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { switch (type) { + case F16: + return &DoGemmAutotune; case F32: return &DoGemmAutotune; case F64: @@ -196,6 +204,10 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { // separately from the precision of the inputs and result. se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { switch (type) { + case F16: + // Use F32 as computation type for F16 as we currently only implement the + // cuDNN pseudo half configuration for half precision. + return se::blas::ComputationType::kF32; case F32: return se::blas::ComputationType::kF32; case F64: @@ -212,7 +224,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, const HloInstruction* hlo_instruction) + bool transpose_rhs, double alpha, + const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), @@ -221,7 +234,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, rhs_shape_(rhs_shape), output_shape_(output_shape), transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs) {} + transpose_rhs_(transpose_rhs), + alpha_(alpha) {} tensorflow::Status GemmThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { @@ -290,7 +304,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (autotune_it == autotune_results_.end()) { StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - computation_type, stream); + alpha_, computation_type, stream); autotune_it = autotune_results_.insert({device_name, best_algorithm}).first; @@ -311,12 +325,15 @@ tensorflow::Status GemmThunk::ExecuteOnStream( VLOG(2) << "Using algorithm " << algorithm << " chosen by autotuning on GemmThunk " << this; return GetGemmWithAlgorithmFn(element_type)( - lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm, - stream, + lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, + algorithm, stream, /*output_profile_result=*/nullptr); } + + // Autotune will fail when CUDA 8 and GPU sm_50 or older are used. + // Use the older Gemm API in this case. return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - stream); + alpha_, stream); }; bool launch_ok; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 8c6a1f51a8a09ef78950dfe7e89994a3fe247f49..df3edcefef898d465cd5ddc53e5d06a966a31f88 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -34,15 +34,16 @@ namespace gpu { // This is thread-compatible. class GemmThunk : public Thunk { public: - // Constructs a thunk that computes "output = lhs rhs" using BLAS gemm. - // transpose_lhs and transpose_rhs indicate whether gemm should transpose the - // lhs and rhs operand. hlo_instruction is as in Thunk. + // Constructs a thunk that computes "output = (lhs rhs) * alpha" using + // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should + // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is + // a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, - const HloInstruction* hlo_instruction); + double alpha, const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -72,6 +73,7 @@ class GemmThunk : public Thunk { const bool transpose_lhs_; const bool transpose_rhs_; + const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune // results. The map's value is the best algorithm we've found for this thunk diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 28ebd034ee0c89137f4e6eb417d8a37f4a00af7a..07be2a0cf90c326af6e41764e79950db546e43e4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -33,8 +33,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -164,6 +166,9 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, /*rewrite_grad_op=*/true, /*use_fusion=*/false); + // Rewrite gather ops into smaller ones. + pass.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); @@ -176,6 +181,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); } pipeline.AddPass( @@ -241,6 +247,22 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout()); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); + pipeline.AddPass(/*is_layout_sensitive=*/true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { HloPassFix fusion("fusion"); fusion.AddInvariantChecker(); @@ -277,15 +299,6 @@ tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); - pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }); - pipeline.AddPass(/*is_layout_sensitive=*/true); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which // materializes the value) or missing a necessary copy (later pass removes an @@ -658,6 +671,8 @@ StatusOr> GpuCompiler::RunBackend( if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + cost_analysis.set_bytes_per_second( + stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 623d6714de501000e38b7698620925f66425f157..04b37d913e0bc8f8226057f107da05fd1e675010 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -46,12 +46,14 @@ namespace { class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, - se::Stream* stream, - const HloComputation* computation) + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) : do_profile_(do_profile), profile_(profile), stream_(stream), + sub_streams_(sub_streams), computation_(computation) { if (do_profile_) { clock_rate_ghz_ = @@ -70,6 +72,7 @@ class HloExecutionProfiler { CHECK(!finished_execution_) << "Call FinishExecution only once!"; finished_execution_ = true; if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(execution_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( @@ -88,6 +91,7 @@ class HloExecutionProfiler { // that the hlo_instruction took to execute in the profile. void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(per_op_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( @@ -100,6 +104,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; @@ -147,13 +152,9 @@ Status GpuExecutable::ExecuteThunks( LOG(WARNING) << "PROFILING: profiling is enabled"; } - HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, - hlo_module_->entry_computation()); - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // Stream 0 indicates `main_stream` and substreams start from stream 1. std::vector::SmartPtr> sub_streams; + sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); TF_ASSIGN_OR_RETURN( @@ -161,6 +162,10 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, + sub_streams, hlo_module_->entry_computation()); + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + // The next event enqueued on stream N must not run until the thunk at // last_blocking_thunk_for_stream[N] completes. std::map last_blocking_thunk_for_stream; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b5962f069bf499c913bd5479f263a7cb77c00555..85ecbe8fdb34700ca738b99ddd9ea615afc35da3 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -25,13 +25,19 @@ namespace gpu { namespace { bool IsFusile(const HloInstruction& hlo) { + // Don't fuse get-tuple-element on GPU: We can, but it's slower than not + // fusing. We never generate kernels for unfused GTEs. Instead, if an + // unfused GTE is an input to a kernel (including a fusion kernel), we + // compute the address of the GTE at the top of the kernel. Often we know the + // address of the GTE result statically, so we can do this without chasing any + // pointers. return (hlo.IsElementwise() && hlo.operand_count() > 0) || + hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || - hlo.opcode() == HloOpcode::kGetTupleElement || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || @@ -46,6 +52,34 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Check if we can use output fusion for (A @ B) * alpha + if (producer->opcode() == HloOpcode::kDot) { + if (consumer->opcode() == HloOpcode::kMultiply) { + CHECK_EQ(consumer->operand_count(), 2); + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + if (alpha->opcode() == HloOpcode::kConstant && + ShapeUtil::IsScalar(alpha->shape())) { + return true; + } + } + } + + // Only allow to fuse transpose into an output fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { + if (producer->opcode() != HloOpcode::kTranspose) { + return false; + } + // Check that the transpose is the operand of a dot. + auto producer_operand_index = consumer->operand_index(producer); + auto fused_parameter = consumer->fused_parameter(producer_operand_index); + const std::vector& fused_parameter_users = + fused_parameter->users(); + return (fused_parameter_users.size() == 1 && + fused_parameter_users[0]->opcode() == HloOpcode::kDot); + } + // Output fusion is not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; @@ -70,17 +104,6 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // We may need to know original operand layout to emit input fusion, and so - // far, we merely use the layout of an operand of the fusion node, which means - // we must fuse only elementwise operations. This restriction should be lifted - // later if we need to fuse other operations, e.g. transpose, for performance. - if ((IsReductionToVector(*consumer) || - (HloOpcode::kFusion == consumer->opcode() && - HloInstruction::FusionKind::kInput == consumer->fusion_kind())) && - !producer->IsElementwise()) { - return false; - } - // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && @@ -98,6 +121,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } + if (producer->opcode() == HloOpcode::kDot) { + return HloInstruction::FusionKind::kOutput; + } if (HloOpcode::kFusion == consumer->opcode()) { return consumer->fusion_kind(); } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 2d6dad27a59978da6e4719afc50ebee5e641dde0..4b231c449f8f101127b4d30bfff20c69d8cef5c1 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; @@ -137,30 +138,119 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, GetTupleElementFused) { - HloComputation::Builder builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "param")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, param, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); +// Tests that broadcasts fused into a fusion with a reduce root. +TEST_F(InstructionFusionTest, BroadcastIntoReduce) { + auto module = tools::Parse(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY BroadcastIntoReduce { + constant = f32[] constant(1) + broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={} + constant.1 = f32[] constant(0) + ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3}, + to_apply=add + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); +} + +TEST_F(InstructionFusionTest, BitcastIntoAdd) { + auto module = tools::Parse(R"( + HloModule test_module + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + bitcast = f32[4,1]{1,0} bitcast(p0) + ROOT add = f32[4,1] add(bitcast, p1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Bitcast(op::Parameter()), op::Parameter())); +} + +TEST_F(InstructionFusionTest, AddIntoBitcast) { + auto module = tools::Parse(R"( + HloModule test_module + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + add = f32[4,1] add(p0, p1) + ROOT bitcast = f32[4,1,1] bitcast(add) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Bitcast(op::Add(op::Parameter(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, DontFuseGTE) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY DontFuseGTE { + p0 = (f32[10], f32[10]) parameter(0) + gte0 = f32[10] get-tuple-element(p0), index=0 + gte1 = f32[10] get-tuple-element(p0), index=1 + ROOT add = f32[10] add(gte0, gte1) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DotOutputFusion) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + constant = f32[] constant(3) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} + dot = f32[4,4]{1,0} dot(p0, transpose) + ROOT mul = f32[4,4] multiply(constant, dot) + })") + .ValueOrDie(); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - HloInstruction* fused_root = root->fused_expression_root(); - EXPECT_EQ(HloOpcode::kAdd, fused_root->opcode()); - // Check that operands of 'fused_root' are GTE. - EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(0)->opcode()); - EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT( + root->fused_expression_root(), + op::Multiply(op::Parameter(), + op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 2f65edffea81db7dba1f8545f92b27ea622044e7..32413f975a40c1abc334b16e81097bb44f56a44a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -49,8 +49,10 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. - bool type_is_allowed = (output_shape.element_type() == F32 || - output_shape.element_type() == F64); + PrimitiveType output_primitive_type = output_shape.element_type(); + bool type_is_allowed = + (output_primitive_type == F16 || output_primitive_type == F32 || + output_primitive_type == F64); return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape) && @@ -87,6 +89,19 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return true; } + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && + hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) { + // Try to find the dot inside the output fusion node. + const HloInstruction* dot = hlo.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo.fused_expression_root()->operand(1); + } + if (dot->opcode() == HloOpcode::kDot) { + return ImplementedAsGemm(*dot); + } + } + return false; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a3df67a87344d6ece2ea9047321ad9542c13f8cf..1e0db2821a2c212d0f212ae94ab69231bc6053ea 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -438,6 +439,32 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { return IrEmitter::DefaultAction(select); } +namespace { +llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { + return ir_builder->CreateExtractValue(x, {0}); +} + +llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { + return ir_builder->CreateExtractValue(x, {1}); +} + +std::pair MultiplyComplex( + llvm::Value* lhs_value, llvm::Value* rhs_value, + llvm::IRBuilder<>* ir_builder) { + llvm::Value* lhs_real = Real(lhs_value, ir_builder); + llvm::Value* lhs_imag = Imag(lhs_value, ir_builder); + llvm::Value* rhs_real = Real(rhs_value, ir_builder); + llvm::Value* rhs_imag = Imag(rhs_value, ir_builder); + llvm::Value* real_result1 = ir_builder->CreateFMul(lhs_real, rhs_real); + llvm::Value* real_result2 = ir_builder->CreateFMul(lhs_imag, rhs_imag); + llvm::Value* real_result = ir_builder->CreateFSub(real_result1, real_result2); + llvm::Value* imag_result1 = ir_builder->CreateFMul(lhs_real, rhs_imag); + llvm::Value* imag_result2 = ir_builder->CreateFMul(lhs_imag, rhs_real); + llvm::Value* imag_result = ir_builder->CreateFAdd(imag_result1, imag_result2); + return {real_result, imag_result}; +} +} // namespace + Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); @@ -456,21 +483,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto real = [&](llvm::Value* x) { - return ir_builder_.CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_.CreateExtractValue(x, {1}); - }; - llvm::Value* real_result = ir_builder_.CreateFSub( - ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value))); - llvm::Value* imag_result = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value))); + auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = ir_builder_.CreateInsertValue(result, real_result, {0}); - result = ir_builder_.CreateInsertValue(result, imag_result, {1}); + result = ir_builder_.CreateInsertValue(result, value.first, {0}); + result = ir_builder_.CreateInsertValue(result, value.second, {1}); } else { result = ir_builder_.CreateFMul(lhs_value, rhs_value); } @@ -548,20 +564,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum = ir_builder_.CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { -#define REAL(x) ir_builder_.CreateExtractValue(x, {0}) -#define IMAG(x) ir_builder_.CreateExtractValue(x, {1}) - llvm::Value* product_real = ir_builder_.CreateFSub( - ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)), - ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element))); - llvm::Value* product_imag = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)), - ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element))); - updated_accum = ir_builder_.CreateInsertValue( - accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0}); - updated_accum = ir_builder_.CreateInsertValue( - updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1}); -#undef IMAG -#undef REAL + auto value = MultiplyComplex(lhs_element, rhs_element, &ir_builder_); + llvm::Value* accum_real = Real(accum, &ir_builder_); + llvm::Value* real_sum = ir_builder_.CreateFAdd(accum_real, value.first); + updated_accum = ir_builder_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* accum_imag = Imag(accum, &ir_builder_); + llvm::Value* imag_sum = ir_builder_.CreateFAdd(accum_imag, value.second); + updated_accum = ir_builder_.CreateInsertValue(updated_accum, imag_sum, {1}); } else { llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); updated_accum = ir_builder_.CreateFAdd(accum, product); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 30c88c0a5d38f6ea3f94d3b47b7b69c7122bf6ac..199e6b787413c5e0fb1435c62f1fc3b83fc6eba3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -44,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -498,12 +501,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { switch (root->opcode()) { case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(fusion)); - TF_RETURN_IF_ERROR(EmitInitializer( - fusion, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(fusion)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; @@ -517,39 +519,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); Shape input_shape = root->operand(0)->shape(); - // EmitReductionToVector requires the input shape to have a layout, but - // fused instructions don't have one. So we determine its layout from - // the fusion's operands. The choice of the layout only affects - // performance but not correctness. - auto choose_input_layout = []( - tensorflow::gtl::ArraySlice operands, - Shape* input_shape) -> Status { - // Prefer the layout of an operand whose shape is compatible with - // input_shape. - for (const HloInstruction* operand : operands) { - if (ShapeUtil::Compatible(*input_shape, operand->shape())) { - return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(), - input_shape); - } - } - // If no operand has a compatible shape, prefer an operand that has - // the same rank at least. - for (const HloInstruction* operand : operands) { - if (ShapeUtil::Rank(*input_shape) == - ShapeUtil::Rank(operand->shape())) { - // Do not use CopyLayoutBetweenShapes because input_shape and - // operand->shape() may be incompatible. - *input_shape->mutable_layout() = operand->shape().layout(); - return Status::OK(); - } - } - // When all the above fails, which is rare, set the default layout. - LayoutUtil::SetToDefaultLayout(input_shape); - return Status::OK(); - }; - TF_RETURN_IF_ERROR( - choose_input_layout(fusion->operands(), &input_shape)); - return EmitReductionToVector( root, input_shape, fused_emitter.GetGenerator(root->operand(0)), fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), @@ -1668,14 +1637,14 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { if (IsReductionToVector(*reduce) && // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(reduce)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(reduce)); - TF_RETURN_IF_ERROR(EmitInitializer( - reduce, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(reduce)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(reduce)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), reduce)); + return EmitReductionToVector( reduce, input->shape(), [&](const llvm_ir::IrArray::Index& index) { @@ -1739,16 +1708,13 @@ Status IrEmitterUnnested::HandleSelectAndScatter( CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); CHECK_EQ(rank, window.dimensions_size()); - { - std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - TF_RETURN_IF_ERROR(EmitInitializer( - select_and_scatter, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(select_and_scatter)); + std::vector> thunks; + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -2013,11 +1979,22 @@ GetHloBufferSlices(const HloInstruction* hlo, } } - // If *that* didn't work, check whether instr is a GTE instruction. If it - // is, see if we can get a buffer for its parent, and continue walking up - // parents until we find a defined buffer or we hit something that's not a - // GTE. + // If *that* didn't work, walk up any bitcasts that we might see. These + // must appear before any GTE instructions, because it's illegal to bitcast + // to a tuple type. const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kBitcast) { + parent = parent->operand(0); + + auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + // Finally, check whether instr is a GTE instruction. If it is, see if we + // can get a buffer for its parent, and continue walking up parents until we + // find a defined buffer or we hit something that's not a GTE. while (parent->opcode() == HloOpcode::kGetTupleElement) { gte_indices.push_front(parent->tuple_index()); parent = parent->operand(0); @@ -2069,7 +2046,7 @@ Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { return Unimplemented("Gather is not implemented on GPUs."); } -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2221,31 +2198,63 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->shape(), // The shape of the output. false, // Do not transpose LHS. false, // Do not transpose RHS. + 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - const HloInstruction* dot = inst->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Trasnpose RHS. - inst); + if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) { + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); + } + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*mul), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + dot->operand(0)->IsRank2Transpose(), // Transpose LHS. + dot->operand(1)->IsRank2Transpose(), // Transpose RHS. + alpha->literal().Get({0}), // alpha. + inst); + } else { + const HloInstruction* dot = inst->fused_expression_root(); + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + dot->operand(0)->IsRank2Transpose(), // Transpose LHS. + dot->operand(1)->IsRank2Transpose(), // Transpose RHS. + 1.0, // Alpha. + inst); + } } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2261,37 +2270,87 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } -Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, - KernelThunk* thunk) { +StatusOr> IrEmitterUnnested::BuildInitializerThunk( + const HloInstruction* hlo) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - CHECK(inst->opcode() == HloOpcode::kSelectAndScatter || - inst->opcode() == HloOpcode::kReduce); - const HloInstruction* init_value = nullptr; - switch (inst->opcode()) { - case HloOpcode::kSelectAndScatter: - init_value = inst->operand(2); - break; - case HloOpcode::kReduce: - init_value = inst->operand(1); - break; - default: - LOG(FATAL) << "Opcode " << inst->opcode() - << " should not need an initializer."; - } + const HloInstruction* init_value = [&] { + switch (inst->opcode()) { + case HloOpcode::kSelectAndScatter: + return inst->operand(2); + case HloOpcode::kReduce: + return inst->operand(1); + default: + LOG(FATAL) << "Opcode " << inst->opcode() + << " should not need an initializer."; + } + }(); if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } - return EmitTargetElementLoopInThunk( + // In the common case, the initializer is a constant. In this case, emit a + // device-memset call if we can. Currently StreamExecutor only supports + // zeroing and 32-bit memsets. + if (init_value->IsConstant()) { + CHECK(ShapeUtil::IsScalar(init_value->shape())); + int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape()); + const auto& literal = init_value->literal(); + + // Are all the bytes of this scalar equal to 0? If so, we can create a + // MemzeroThunk. + ArraySlice literal_bytes( + reinterpret_cast(literal.untyped_data()), num_bytes); + if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + uint16 pattern16; + if (num_bytes == 1) { + uint8 b = literal_bytes.front(); + pattern16 = uint16{b} | (uint16{b} << 8); + } else { + pattern16 = literal_bytes.front(); + } + uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); + return {MakeUnique(pattern32, + GetAllocationSlice(*hlo), hlo)}; + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(literal_bytes.data(), literal_bytes.data() + 4, + literal_bytes.size() - 4) == 0) { + uint32 word; + memcpy(&word, literal_bytes.data(), sizeof(word)); + return {MakeUnique(word, GetAllocationSlice(*hlo), + hlo)}; + } + } + + // Otherwise fall back to our slow initializer code. + std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, - thunk); + kernel_thunk.get())); + + // Clean up state left behind by emitting the loop above. (This is normally + // done in IrEmitterUnnested::Postprocess().) + bindings_.UnbindAllLocalIrValues(); + + // Convert unique_ptr to StatusOr>. + return {std::move(kernel_thunk)}; } namespace { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b83a2337e2decd9d4fba3d40fcf33f131fca8a3c..66c62e2d2de3ed1668271a21943dc73ed3d77651 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -148,13 +148,10 @@ class IrEmitterUnnested : public IrEmitter { tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reducer); - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + std::unique_ptr BuildKernelThunk(const HloInstruction* inst); // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); @@ -163,6 +160,11 @@ class IrEmitterUnnested : public IrEmitter { // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + // Returns a thunk that, given a reduce or select-and-scatter op, initializes + // its memory to the appropriate initial value. + StatusOr> BuildInitializerThunk( + const HloInstruction* hlo); + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..18e673542c5b47cb90d31a8eff62a5e4adb78d1d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +namespace se = ::perftools::gputools; + +Status MemzeroThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemZero(&dest_data, dest_data.size()); + return Status::OK(); +} + +Status Memset32BitValueThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemset32(&dest_data, value_, dest_data.size()); + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..b4bb74d1dd6dc9d09c5e4d439d57dfe8b57c2ed9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/stream_executor/stream_executor.h" + +// This file contains thunks that set a buffer's elements to a particular value. +// This can be faster than emitting a kernel to set the elements. + +namespace xla { +namespace gpu { + +// Thunk that zeroes out a given chunk of memory. +class MemzeroThunk : public Thunk { + public: + explicit MemzeroThunk(const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemzero, hlo), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const BufferAllocation::Slice dest_; +}; + +// Thunk that sets a given chunk of memory to a particular 32-bit value. The +// destination chunk must have size divisible by 32 bits. +class Memset32BitValueThunk : public Thunk { + public: + explicit Memset32BitValueThunk(uint32 value, + const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + uint32 value_; + const BufferAllocation::Slice dest_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 25846dc6cd4633c7becb6e62d6bc9585348a6eac..7bda4e2fcd469bd430e5ef1846251c8504225383 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -68,13 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - input = computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape( - /*operand_shape=*/input->shape(), - /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), - padding_config) - .ConsumeValueOrDie(), - input, padding, padding_config)); + input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } if (window_util::HasNegativePadding(conv_window)) { @@ -97,11 +92,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::max(0LL, -conv_window.dimensions(i).padding_high()); } - input = computation->AddInstruction(HloInstruction::CreateSlice( - ShapeInference::InferSliceShape(input->shape(), start_indices, - limit_indices, strides) - .ConsumeValueOrDie(), - input, start_indices, limit_indices, strides)); + input = + MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie(); } return input; @@ -134,13 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - return computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape( - /*operand_shape=*/kernel->shape(), - /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), - padding_config) - .ConsumeValueOrDie(), - kernel, padding, padding_config)); + return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -252,11 +238,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = - computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape(input->shape(), padding->shape(), - input_padding_config) - .ConsumeValueOrDie(), - input, padding, input_padding_config)); + MakePadHlo(input, padding, input_padding_config).ValueOrDie(); // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 2c3032d79be221e8cacb178ffb1817459b603cc0..9eea958d1214b131d49cb4e28f1944860408d3a8 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -51,6 +51,8 @@ class Thunk { kGemm, kInfeed, kKernel, + kMemset32BitValue, + kMemzero, kSequential, kTuple, kWhile, diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index a2d13c013c56059148ccd04dba2137a5b2badc42..3dd4c4a0794e5c41b877078c4e69c6c9584ce6c0 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -27,38 +27,6 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -namespace { - -// Returns the set of buffers that may be sources of all operands of the given -// instruction. The returned buffers are guaranteed to have no duplicates, and -// to be sorted in a deterministic order. -std::vector UniqueOperandSourceBuffers( - const HloInstruction* instruction, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector buffers; - for (const HloInstruction* operand : instruction->operands()) { - points_to_analysis.GetPointsToSet(operand).ForEachElement( - [&](const ShapeIndex& /*index*/, - const PointsToSet::BufferList& points_to) { - buffers.insert(buffers.end(), points_to.begin(), points_to.end()); - }); - } - - // Sort and then remove duplicates from buffers. - std::sort(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - buffers.erase(std::unique(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() == b->id(); - }), - buffers.end()); - return buffers; -} - -} // namespace - /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, @@ -93,6 +61,7 @@ Status HeapSimulator::RunComputation( const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { + VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -101,7 +70,51 @@ Status HeapSimulator::RunComputation( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. + // 'used_buffers' is the reverse map - it tracks which buffers were used by an + // instruction, so that we can remove the instructions from a buffer's live + // set after they are visited. FlatMap> live_buffers; + FlatMap> used_buffers; + auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( + const HloInstruction* user, + const LogicalBuffer* buffer) { + if (!IgnoreBuffer(buffer)) { + VLOG(4) << " Adding user " << user->name() << " to buffer " + << buffer->ToString(); + live_buffers[buffer].insert(user); + used_buffers[user].insert(buffer); + } + }; + + // Initialize live_buffers for each buffer that we're going to assign. The + // set of instructions that need to be visited contains all users of all + // aliases, that is, all users of all instructions that have the buffer + // contained in their points-to set. + for (const HloInstruction* instruction : instruction_sequence) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction); + const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); + for (const HloInstruction* user : instruction->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + for (const LogicalBuffer* buffer : buffer_set) { + add_user_to_buffer(user, buffer); + } + } else { + // A GetTupleElement doesn't need to keep all of its operand's buffers + // alive. It only needs the buffers that relate to the element its + // extracting, and the tuple it's extracting from, but not the buffers + // for the other elements. + for (const LogicalBuffer* buffer : points_to.element({})) { + add_user_to_buffer(user, buffer); + } + const PointsToSet& gte_points_to = + points_to_analysis.GetPointsToSet(user); + for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + add_user_to_buffer(user, buffer); + } + } + } + } const HloInstruction* root = computation.root_instruction(); auto output_source_buffers = @@ -114,34 +127,17 @@ Status HeapSimulator::RunComputation( buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); - // Initialize live_buffers for each buffer that we're going to assign. The - // set of instructions that need to be visited contains all users of all - // aliases. The alias itself is not necessary; if it has users, the users - // are necessarily scheduled after the alias. And if it has no users, it is - // either a dead value or an output, both of which are handled below. - // - // We ignore control dependencies here. The reasoning is that the control - // dependencies have already been accounted for in the ordering of the given - // 'instruction_sequence', and should not otherwise artificially extend the - // lifetime of buffers that aren't already connected by a data dependency. + VLOG(3) << "Instruction: " << instruction->ToString(); + for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + VLOG(4) << " Defines: " << buffer->ToString() + << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); + } + dead_buffers_to_free.clear(); for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } - FlatSet* live_set = nullptr; - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - const std::vector& users = - alias.instruction()->users(); - if (!users.empty()) { - if (live_set == nullptr) { - live_set = &live_buffers[buffer]; - } - live_set->insert(users.begin(), users.end()); - } - } - // Add a nullptr sentry to ensure entry parameters and output source // buffers are not freed until the very end. const bool entry_parameter = @@ -165,11 +161,12 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : - UniqueOperandSourceBuffers(instruction, points_to_analysis)) { + for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } + VLOG(4) << " Removing user " << instruction->name() << " from buffer " + << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); FlatSet* live_set = &it->second; live_set->erase(instruction); @@ -178,6 +175,11 @@ Status HeapSimulator::RunComputation( operand_buffers_to_free.push_back(operand_buffer); } } + // Sort to get a deterministic iteration order. + std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), + [](const LogicalBuffer* x, const LogicalBuffer* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. This must @@ -203,6 +205,8 @@ Status HeapSimulator::RunComputation( CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index(), points_to_analysis)) { + VLOG(3) << " Sharing: " << buffer->ToString() << " with " + << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; @@ -211,6 +215,7 @@ Status HeapSimulator::RunComputation( } if (!shared) { + VLOG(3) << " Allocating: " << buffer->ToString(); Alloc(buffer, instruction); } } @@ -244,20 +249,34 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. for (const LogicalBuffer* buffer : dead_buffers_to_free) { + VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } for (const LogicalBuffer* buffer : operand_buffers_to_free) { + VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } } // Any remaining live buffers must be entry parameters or output source - // buffers, which had a nullptr sentry added. Free them now. + // buffers, which had a nullptr sentry added. Free them now, in a + // deterministic order. + std::vector to_free; + to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const LogicalBuffer* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; + to_free.push_back(buffer); + } + + std::sort(to_free.begin(), to_free.end(), + [](const LogicalBuffer* x, const LogicalBuffer* y) { + return x->id() < y->id(); + }); + for (const LogicalBuffer* buffer : to_free) { + VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 387b649a731ebcbfd8307807469f39f22d192b06..688a271712ac243666ba4ff02932aa4f7f7ed21c 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -410,6 +410,56 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, IndependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32scalar_, "paramA")); + auto paramB = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32scalar_, "paramB")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kMultiply, paramA, paramB)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kAdd, paramA, paramB)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + auto element0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0)); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec4_, element0, {0})); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kSubtract, paramA, paramB)); + auto element1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1)); + auto output = builder.AddInstruction( + HloInstruction::CreateTuple({broadcast, sub, element1})); + + HeapSimulatorTracker tracker(TestName(), builder.Build(), + {paramA, paramB, mul, add, tuple, element0, + broadcast, sub, element1, output}); + tracker.ExpectCallSequence({ + {kAlloc, tracker.BufferAt(paramA, {})}, + {kAlloc, tracker.BufferAt(paramB, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(add, {})}, + {kAlloc, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(broadcast, {})}, + // The mul can be freed right after the broadcast happens, even though + // The other GetTupleElement is still alive. + {kFree, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(sub, {})}, + // The temporary tuple is now dead. + {kFree, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(output, {})}, + // All params and outputs are freed at the end. + {kFree, tracker.BufferAt(paramA, {})}, + {kFree, tracker.BufferAt(paramB, {})}, + {kFree, tracker.BufferAt(add, {})}, + {kFree, tracker.BufferAt(broadcast, {})}, + {kFree, tracker.BufferAt(sub, {})}, + {kFree, tracker.BufferAt(output, {})}, + {kFinish, nullptr}, + }); +} + TEST_F(HeapSimulatorTest, WholeModule) { HeapSimulatorTracker tracker(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index a43785b4a9701369ae315f67d4d64d03dc6c081d..0b446c654779db410ebbd91ef9a5bab14d08a278 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING. -// -// Don't use these protos in the real compilation or execution codepaths. The -// data format is meant for debugging only, and may change without notice. +// This proto file defines messages which represent the HLO module. This is a +// full fidelity serialization of the c++ HLO constructs. // // Many of the protos below are simple 1-to-1 serializations of the -// corresponding C++ classes. +// corresponding C++ classes, e.g., HloModule, HloComputation, and +// HloInstruction. // // FIELD NAMES ARE IMPORTANT // @@ -38,16 +37,19 @@ option cc_enable_arenas = true; message HloInstructionProto { reserved 10; reserved "parameter_name"; + reserved 12; + reserved "fused_instructions_computation"; + reserved 4; + reserved "operand_names"; + reserved 5; + reserved "control_predecessor_names"; + reserved 6; + reserved "called_computation_names"; string name = 1; string opcode = 2; xla.Shape shape = 3; - // TODO(b/67782397): Replace instruction names with HloInstruction ids. - repeated string operand_names = 4; - repeated string control_predecessor_names = 5; - repeated string called_computation_names = 6; - xla.OpMetadata metadata = 7; // Literal, only present for kConstant. @@ -58,7 +60,6 @@ message HloInstructionProto { // Fusion state, only present for kFusion. string fusion_kind = 11; - HloComputationProto fused_instructions_computation = 12; // Index for kGetTupleElement. int64 tuple_index = 13; @@ -133,28 +134,53 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_window_bounds = 34; + + // The id of this instruction. + int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 control_predecessor_ids = 37; + repeated int64 called_computation_ids = 38; + + xla.OpSharding sharding = 40; } // Serialization of HloComputation. message HloComputationProto { + reserved 3; + reserved "root_name"; + string name = 1; // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; - // The name of the root of the computation. - string root_name = 3; + // The program shape (with layout) of this computation. + xla.ProgramShape program_shape = 4; + + // The id of this computation. + int64 id = 5; + + // The id of the root of the computation. + int64 root_id = 6; } // Serialization of HloModule. message HloModuleProto { string name = 1; string entry_computation_name = 2; + int64 entry_computation_id = 6; // The array of computations is always in a valid dependency order, where // callees appear before their callers. repeated HloComputationProto computations = 3; + + // The program shape (with layout) of the entry computation. + xla.ProgramShape program_shape = 4; + + // The id of this module. + int64 id = 5; } // Serialization of HloOrdering. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 30e32a46d7dd0923f738939c33407ac7484b5bbe..a88283ed9a6459b4fa9310e160b59c77d51f1027 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -171,24 +171,21 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } - // Compute and return a vector of buffers that the given value must be - // contained in due to HLO aliasing rules. - std::vector ComputeAliasedBuffers(const HloValue& value) { + void ComputeWhileAliasedBuffers(const HloValue& value, + std::vector* aliased_buffers) { + VLOG(3) << "Compute kWhile aliases"; // Value is init of a while (use is while). - std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = dataflow_.GetUniqueValueAt(use.instruction, use.operand_index); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); VLOG(3) << " value is init value to a while; must share buffer with " "while value " << while_value.ToShortString(); } } - // Value is a parameter of a while body/condition. if (value.defining_instruction()->opcode() == HloOpcode::kParameter) { const HloComputation* computation = @@ -205,11 +202,10 @@ class BufferValueMap { VLOG(3) << " value is parameter value of the body or condition of a " "while; must share buffer with while value " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } - // Value is the root of a while body. for (const HloPosition& position : value.positions()) { const HloComputation* computation = position.instruction->parent(); @@ -224,27 +220,71 @@ class BufferValueMap { const HloValue& while_value = dataflow_.GetUniqueValueAt( callsite.instruction(), position.index); - VLOG(3) << " value is root the body computation of a while; must " - "share buffer with while value " + VLOG(3) << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; body root and while value root must share buffer " + "among them : " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } } - // Value is the output of the while instruction itself. if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { VLOG(3) << " value is output of a while instruction"; - aliased_buffers.push_back(GetBufferForValue(value)); + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + + void ComputeConditionalAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + VLOG(3) << "Compute kConditional aliases"; + // Aliases the buffers of the true/false computations roots, with the one of + // the conditional. + for (const HloPosition& position : value.positions()) { + const HloComputation* computation = position.instruction->parent(); + const CallGraphNode& call_graph_node = + dataflow_.call_graph().GetNode(computation); + if (position.instruction == computation->root_instruction()) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + // Call graph must have been flattened. + CHECK_EQ(call_graph_node.caller_callsites().size(), 1); + + const HloValue& cond_value = dataflow_.GetUniqueValueAt( + callsite.instruction(), position.index); + VLOG(3) + << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; true/false branch roots must share buffer among them : " + << cond_value.ToShortString(); + aliased_buffers->push_back(GetBufferForValue(cond_value)); + } + } + } + } + // Value is the output of the conditional instruction itself. + if (value.defining_instruction()->opcode() == HloOpcode::kConditional) { + VLOG(3) << " value is output of a conditional instruction"; + aliased_buffers->push_back(GetBufferForValue(value)); } + } + // Compute and return a vector of buffers that the given value must be + // contained in due to HLO aliasing rules. + std::vector ComputeAliasedBuffers(const HloValue& value) { + for (const HloUse& use : value.uses()) { + VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; + } + std::vector aliased_buffers; + ComputeWhileAliasedBuffers(value, &aliased_buffers); + ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. std::sort(aliased_buffers.begin(), aliased_buffers.end()); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); - return aliased_buffers; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 21e6b2ca730f6347af902097e6496826b861e8a3..6f983d0b950435d43fe3a1e0fe84902b51bfe249 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -65,6 +65,7 @@ HloComputation::HloComputation( std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), + unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); @@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal( instruction->UniquifyName(&parent()->instruction_name_uniquer()); instruction->SetUniqueId(parent()->NewUniqueInstructionId()); } - Reparent(instruction.get()); + instruction->set_parent(this); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); @@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } -void HloComputation::Reparent(HloInstruction* instruction) { - instruction->set_parent(this); -} - bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -393,43 +390,46 @@ string HloComputation::ToString(const HloPrintOptions& options) const { HloComputationProto HloComputation::ToProto() const { HloComputationProto proto; + CHECK(unique_id_ != -1) + << "This computation does not have a valid id. Please make sure the " + "computation is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); for (const HloInstruction* instruction : MakeInstructionPostOrder()) { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } - proto.set_root_name(root_instruction()->name()); + proto.set_root_id(root_instruction()->unique_id()); + *proto.mutable_program_shape() = ComputeProgramShape(); return proto; } /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction) { + const tensorflow::gtl::FlatMap& computation_map) { std::vector> instructions; - tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, - HloInstruction::CreateFromProto( - module, instruction_proto, instruction_map, - computation_map, add_fused_computation)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } - TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); - instruction_map[instruction->name()] = instruction.get(); + TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); + instruction_map[instruction_proto.id()] = instruction.get(); instructions.push_back(std::move(instruction)); } - TF_RET_CHECK(!proto.root_name().empty()); - TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); - HloInstruction* root = instruction_map.at(proto.root_name()); - return WrapUnique(new HloComputation( - proto.name(), parameter_count, &instructions, root, fusion_instruction)); + TF_RET_CHECK(proto.root_id() != -1); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); + HloInstruction* root = instruction_map.at(proto.root_id()); + return WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( @@ -532,7 +532,6 @@ ProgramShape HloComputation::ComputeProgramShape() const { } *program_shape.mutable_result() = root_instruction_->shape(); - LayoutUtil::ClearLayout(&program_shape); return program_shape; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 39d864efcb70382b6f8e631d7e6e452ea6410104..9d3f6e9a2c2efd97681a22b6b0f6d929afc553de 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -160,20 +160,12 @@ class HloComputation { // module: the module which will contain the computation. The newly created // computation is *not* added to the module, however. // proto: the proto to convert from. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used only when the instruction is a fusion instruction. - // fusion_instruction: if non-null then the newly created computation will - // be constructed as a fused computation with this instruction as its - // fusion parent. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction = nullptr); + const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. // @@ -248,7 +240,7 @@ class HloComputation { ShapeTree* copies_added = nullptr); // Computes and returns the ProgramShape of this computation (shape of - // parameters and result without layout). + // parameters and result with layout). ProgramShape ComputeProgramShape() const; // Return whether `*this` and `other` are functionally equivalent. @@ -342,6 +334,15 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // The id of this computation should be unique within the module. + void SetUniqueId(int64 id) { + CHECK_EQ(unique_id_, -1); + CHECK_GE(id, 0); + unique_id_ = id; + } + + int64 unique_id() const { return unique_id_; } + private: explicit HloComputation( const string& name, int parameter_count, @@ -352,10 +353,6 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); - // Helper for setting the parent of instructions that are added to this - // computation. - void Reparent(HloInstruction* instruction); - // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. @@ -373,6 +370,7 @@ class HloComputation { std::vector CollectUnreachableRoots() const; string name_; + int64 unique_id_; HloInstruction* root_instruction_; // If this computation is a fusion computation, this field points to the diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 53450991b6fad5b9651d9d23b55c908e6b68e5dd..35ecd4428d0dfde2de445ea34472d2c78148c6c9 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,7 +35,10 @@ limitations under the License. namespace xla { StatusOr HloConstantFolding::Run(HloModule* module) { - auto evaluator = MakeUnique(); + // Limit the constant folding to 0 iterations to skip folding loops. This + // retains the behavior from before while loop support in HloEvaluator and may + // be revised. + auto evaluator = MakeUnique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..b186767ce792cd89ae77fe9a03b3a2ecf296b804 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -0,0 +1,277 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +using tensorflow::gtl::ArraySlice; +using tensorflow::strings::StrCat; + +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN(Shape binary_op_shape, + ShapeInference::InferBinaryOpShape(opcode, lhs, rhs)); + return computation->AddInstruction( + HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); +} + +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, padding_value->parent()); + TF_ASSIGN_OR_RETURN( + Shape pad_shape, + ShapeInference::InferPadShape(operand->shape(), padding_value->shape(), + padding_config)); + return computation->AddInstruction(HloInstruction::CreatePad( + pad_shape, operand, padding_value, padding_config)); +} + +StatusOr MakeSliceHlo(HloInstruction* operand, + ArraySlice start_indices, + ArraySlice limit_indices, + ArraySlice strides) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( + operand->shape(), start_indices, + limit_indices, strides)); + return computation->AddInstruction(HloInstruction::CreateSlice( + slice_shape, operand, start_indices, limit_indices, strides)); +} + +StatusOr MakeConvolveHlo( + HloInstruction* lhs, HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), + window, dimension_numbers)); + return computation->AddInstruction(HloInstruction::CreateConvolve( + convolve_shape, lhs, rhs, window, dimension_numbers)); +} + +StatusOr MakeTransposeHlo(HloInstruction* operand, + ArraySlice dimensions) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN( + Shape transpose_shape, + ShapeInference::InferTransposeShape(operand->shape(), dimensions)); + return computation->AddInstruction( + HloInstruction::CreateTranspose(transpose_shape, operand, dimensions)); +} + +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand) { + HloComputation* computation = operand->parent(); + return computation->AddInstruction( + HloInstruction::CreateReshape(result_shape, operand)); +} + +StatusOr MakeReshapeHlo( + ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + result_shape_dim_bounds); + return MakeReshapeHlo(new_shape, operand); +} + +StatusOr MakeDynamicSliceHlo(HloInstruction* operand, + HloInstruction* start_indices, + ArraySlice slice_sizes) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, start_indices->parent()); + TF_ASSIGN_OR_RETURN( + Shape dynamic_slice_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), slice_sizes)); + return computation->AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, operand, start_indices, slice_sizes)); +} + +StatusOr MakeDynamicUpdateSliceHlo( + HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, update->parent()); + CHECK_EQ(computation, start_indices->parent()); + TF_ASSIGN_OR_RETURN( + Shape dynamic_update_slice_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + dynamic_update_slice_shape, operand, update, start_indices)); +} + +StatusOr MakeBroadcastHlo( + HloInstruction* operand, ArraySlice broadcast_dimensions, + ArraySlice result_shape_bounds) { + HloComputation* computation = operand->parent(); + Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + result_shape_bounds); + + return computation->AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, operand, broadcast_dimensions)); +} + +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index) { + HloComputation* computation = operand->parent(); + + TF_ASSIGN_OR_RETURN( + Shape gte_shape, + ShapeInference::InferGetTupleElementShape(operand->shape(), index)); + return computation->AddInstruction( + HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); +} + +StatusOr MakeConcatHlo(ArraySlice operands, + int64 dimension) { + CHECK_GT(operands.size(), 0); + + HloComputation* computation = operands[0]->parent(); + CHECK(c_all_of(operands, [&](HloInstruction* instr) { + return instr->parent() == computation; + })); + + std::vector operand_shapes; + c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); + + TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( + operand_shapes, dimension)); + return computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); +} + +StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { + const Shape& operand_shape = operand->shape(); + CHECK_GE(operand_shape.dimensions_size(), n); + int64 new_shape_leading_bound = 1; + for (int64 i = 0; i < n; i++) { + new_shape_leading_bound *= operand_shape.dimensions(i); + } + + std::vector new_shape_dims; + new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1); + new_shape_dims.push_back(new_shape_leading_bound); + + std::copy(operand_shape.dimensions().begin() + n, + operand_shape.dimensions().end(), + std::back_inserter(new_shape_dims)); + + Shape output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims); + + return MakeReshapeHlo(output_shape, operand); +} + +StatusOr ExpandFirstDimIntoNDims( + HloInstruction* operand, ArraySlice expanded_dims) { + CHECK_GT(operand->shape().dimensions_size(), 0); + CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); + + std::vector expanded_shape_dim_bounds; + expanded_shape_dim_bounds.reserve(expanded_dims.size() + + operand->shape().dimensions_size() - 1); + c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + std::copy(operand->shape().dimensions().begin() + 1, + operand->shape().dimensions().end(), + std::back_inserter(expanded_shape_dim_bounds)); + Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + expanded_shape_dim_bounds); + return MakeReshapeHlo(new_shape, operand); +} + +StatusOr ElideDegenerateDims(HloInstruction* operand, + ArraySlice dims_to_elide) { + CHECK(c_is_sorted(dims_to_elide)); + + const Shape& input_shape = operand->shape(); + // First accumulate in reverse + std::vector new_shape_dim_bounds; + new_shape_dim_bounds.reserve(input_shape.dimensions_size() - + dims_to_elide.size()); + int64 dims_to_elide_idx = dims_to_elide.size() - 1; + for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) { + if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) { + CHECK_EQ(input_shape.dimensions(i), 1); + dims_to_elide_idx--; + } else { + new_shape_dim_bounds.push_back(input_shape.dimensions(i)); + } + } + + c_reverse(new_shape_dim_bounds); + Shape output_shape = + ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + +StatusOr PadVectorWithZeros(HloInstruction* operand, + int64 zeros_to_prepend, + int64 zeros_to_append) { + HloComputation* computation = operand->parent(); + CHECK_EQ(operand->shape().dimensions_size(), 1); + PaddingConfig padding_config; + PaddingConfig::PaddingConfigDimension padding_config_dim; + padding_config_dim.set_edge_padding_low(zeros_to_prepend); + padding_config_dim.set_edge_padding_high(zeros_to_append); + *padding_config.add_dimensions() = padding_config_dim; + + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(operand->shape().element_type())))); + return MakePadHlo(operand, zero, padding_config); +} + +StatusOr BroadcastZeros( + HloComputation* computation, PrimitiveType element_type, + ArraySlice broadcast_dimensions) { + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(element_type)))); + return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/broadcast_dimensions); +} + +StatusOr> CreateComputationWithSignature( + ArraySlice domain, const Shape& range, + tensorflow::StringPiece name) { + HloComputation::Builder b(name.ToString()); + int64 param_idx = 0; + for (const Shape* param_shape : domain) { + b.AddInstruction(HloInstruction::CreateParameter( + param_idx, *param_shape, StrCat("param.", param_idx))); + param_idx++; + } + + // We can't change the root type of a computation once it is created so create + // a dummy root instruction to give the computation the right root shape. In + // the future we may want to use a (recursive) broadcast here to avoid + // creating large constants. + b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(range))); + + return b.Build(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d99e32a737e6aaa2ff746cf6c00d4300cf62f4e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Some lightweight utilities intended to make HLO instruction creation more +// ergonomic. We don't have a complete set of helpers yet -- I expect we'll +// expand this interface as needed on an ad-hoc basis. + +// Creates a binary HLO instruction and adds it to the computation containing +// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs); + +// Creates a pad HLO instruction and adds it to the computation containing +// `operand` and `padding_value` (`operand` and `padding_value` must be in the +// same computation). +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + +// Creates a slice HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeSliceHlo( + HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + +// Creates a convolution HLO instruction and adds it to the computation +// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeConvolveHlo( + HloInstruction* lhs, HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Creates a transpose HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeTransposeHlo( + HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + +// Creates a reshape HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand); + +StatusOr MakeReshapeHlo( + tensorflow::gtl::ArraySlice result_shape_dim_bounds, + HloInstruction* operand); + +// Creates a dynamic-slice HLO instruction and adds it to the computation +// containing `operand` and `start_indices` (`operand` and `start_indices` must +// be in the same computation). +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + +// Creates a dynamic-update-slice HLO instruction and adds it to the computation +// containing `operand`, `update` and `start_indices` (`operand`, `update` and +// `start_indices` must be in the same computation). +StatusOr MakeDynamicUpdateSliceHlo( + HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices); + +// Creates a broadcast HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeBroadcastHlo( + HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimensions, + tensorflow::gtl::ArraySlice result_shape_bounds); + +// Creates a GetTupleElement HLO instruction and adds it to the computation +// containing `operand`. +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index); + +// Creates a Concatenate HLO instruction and adds it to the computation +// containing `operands` (`operands` must be non-empty and every element must be +// contained in the same computation). +StatusOr MakeConcatHlo( + tensorflow::gtl::ArraySlice operands, int64 dimension); + +// ----------------------------------------------------------------------------- +// Some other miscellaneous helpers to generate common HLO patterns. All of +// these add all the instructions they generate into the computation containing +// their operand(s). + +// Collapses (via reshape) the first N (logical) dimensions of `operand` into a +// single leading dimension. `operand` must have rank > n. +// +// For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is +// the `operand` reshaped to [56,9]. +StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n); + +// Expands (via reshape) the first (logical) dimension of `operand` into a +// sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 +// and the number of elements in its first dimension must be equal to the +// product of `expanded_dims`. +// +// For instance if `operand` has shape f32[200,9,7] and expanded_dims is +// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. +StatusOr ExpandFirstDimIntoNDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + +// Elides (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_elide` from `operand`. Every dimension in +// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be +// sorted and not contain duplicates. +// +// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide +// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. +StatusOr ElideDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + +// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the +// front and `zeros_to_append` zeros in the back. +StatusOr PadVectorWithZeros(HloInstruction* operand, + int64 zeros_to_prepend, + int64 zeros_to_append); + +// Broadcasts a zero value of type `element_type` into a tensor with element +// type `element_type` and dimension bounds `broadcast_dimensions`. The +// broadcast instruction is emitted into `computation`. +StatusOr BroadcastZeros( + HloComputation* computation, PrimitiveType element_type, + tensorflow::gtl::ArraySlice broadcast_dimensions); + +// Creates a HLO computation that takes arguments of type `domain` and produces +// a value of type `range`. +StatusOr> CreateComputationWithSignature( + tensorflow::gtl::ArraySlice domain, const Shape& range, + tensorflow::StringPiece name); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 279edd4ba8772a9c576f76f554de8ec68631b953..cd7cbbdd71706fddb64855f631eb09de35da52e8 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -109,6 +109,11 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { + continue; + } + // An instruction is considered to be equivalent to another only if they // share the exact same set of operands. So to find equivalent // instructions, we just search among instructions which share operand(0) @@ -118,7 +123,7 @@ StatusOr HloCSE::Run(HloModule* module) { tensorflow::gtl::InlinedVector equivalent_instructions; for (HloInstruction* user : operand->users()) { - if (user != instruction && + if (user != instruction && !user->HasSideEffect() && user->Identical(*instruction, eq_instructions, eq_computations, is_layout_sensitive_)) { equivalent_instructions.push_back(user); diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 3601a790c4428ee39c264b217a4b9a991ad8456c..df8853f34f6a72c52d1cde7332ada3809d2f3d96 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -414,8 +414,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { EXPECT_THAT(root, op::Add(rng1, rng2)); } -// TODO(b/28245743): Handle impure functions correctly in CSE. -TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { +TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. @@ -458,14 +457,16 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Add(op::Map(), op::Map())); + VLOG(3) << "before: " << module->ToString(); + HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + VLOG(3) << "after: " << module->ToString(); EXPECT_EQ(4, computation->instruction_count()); root = computation->root_instruction(); - auto operand = root->operand(0)->operand(0); - EXPECT_THAT(operand, op::Map()); - EXPECT_THAT(root, op::Add(operand, operand)); + EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 934e43ba4879628362009267c671ec4cb0d79c52..0c37a8d75f38dabaad886cc9d4adce8ab29ddf18 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -368,11 +368,11 @@ bool HloDataflowAnalysis::UpdateConditionalValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( conditional->false_computation()->root_instruction())}; - // A phi-node is not defined for a kConditional instruction even though it - // represents a join point. This is because the current approach is to define - // a phi-node only for kWhile to account for the dataflow through back-edges - // and deal with the ambiguity in other cases. - return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + if (ssa_form_) { + return Phi(conditional, inputs); + } else { + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + } } bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 7bf3a1a06045c79621d75b653bf42220705a69d4..07f69b8e1339fed636e4eb54791941b85e09fd17 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1602,11 +1602,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), ElementsAre(HloUse{conditional, 2, {}})); - EXPECT_EQ(analysis.values().size(), 3); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), - analysis.GetValueDefinedAt(constant2))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 4); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + } } TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { @@ -1713,11 +1719,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); - EXPECT_EQ(analysis.values().size(), 6); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(add), - analysis.GetValueDefinedAt(sub))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 7); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); + } } TEST_P(HloDataflowAnalysisTest, NestedConditionals) { @@ -1834,20 +1846,27 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), analysis.GetValueDefinedAt(constant2)); - EXPECT_EQ(analysis.values().size(), 9); - EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT( - HloValuesAt(inner_conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()))); - EXPECT_THAT( - HloValuesAt(conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()), - analysis.GetValueDefinedAt(computation3->root_instruction()))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 11); + EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); + } } INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 1e5f0f797a13fd7e7ce1cc934387a274a74153bc..fcd723af146e2227b8661b1a4993f1338f7de389 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -40,7 +40,7 @@ StatusOr HloDCE::Run(HloModule* module) { VLOG(2) << "Before dce:"; XLA_VLOG_LINES(2, module->ToString()); - for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* computation : module->MakeComputationPostOrder()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 15ae53128aa5dfe706daa6d47dc6d842fd78e26c..693004d364114b1a25ce6b6791092665c861d13f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -51,12 +51,22 @@ namespace xla { namespace { +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::FlatSet; +using tensorflow::gtl::optional; + template struct is_complex_t : public std::false_type {}; template <> struct is_complex_t : public std::true_type {}; +template +struct is_complex64_t : public std::false_type {}; + +template <> +struct is_complex64_t : public std::true_type {}; + template StatusOr> Compare(const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, @@ -99,11 +109,10 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, } auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -130,11 +139,10 @@ StatusOr> Compare( } auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -159,8 +167,8 @@ StatusOr> ElementWiseUnaryOpImpl( auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); @@ -172,7 +180,7 @@ StatusOr> ElementWiseUnaryOpImpl( // with the base index. void IterateThroughWindow( const Shape& window_shape, const Window& window, const Shape& base_shape, - const tensorflow::gtl::ArraySlice& window_count_index, + const ArraySlice& window_count_index, const std::function&)>& f) { const int64 rank = ShapeUtil::Rank(base_shape); DimensionVector window_index(rank); @@ -248,17 +256,37 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. + if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } return HandleAbs(abs); } @@ -306,13 +334,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { operand_to_broadcast.shape().dimensions(i)); } - return output->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); + return output->Populate([&](ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + }); } template < @@ -586,14 +613,25 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::fmax(lhs, rhs); + return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; })); return Status::OK(); } @@ -609,18 +647,30 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleMaximum(maximum); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleMinimum(HloInstruction* minimum) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return std::fmin(lhs_el, rhs_el); + return std::min(lhs_el, rhs_el); })); return Status::OK(); } + template ::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -825,7 +875,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmax(low, std::fmin(value, high)); + return std::fmin(high, std::fmax(value, low)); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -846,6 +896,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); CHECK(!ShapeUtil::IsTuple(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { @@ -876,8 +927,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -952,7 +1003,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector rhs_index(rhs_rank); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size()); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](ArraySlice out_index) { ElementwiseT result_val = static_cast(0); std::fill(lhs_index.begin(), lhs_index.end(), 0); @@ -1074,9 +1125,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); + FlatSet batch_dims_set(dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { rhs_non_batch_non_contracting_dims.push_back(i); @@ -1088,8 +1138,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector lhs_index(lhs_rank); DimensionVector rhs_index(rhs_rank); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice result_index) { ElementwiseT result_val = static_cast(0); // Find the corresponding non-contracting indices for lhs and rhs. @@ -1183,9 +1233,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); auto result = Literal::CreateFromShape(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); + [&scalar](ArraySlice multi_index) { return scalar; })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); @@ -1198,7 +1246,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // corresponding index of the resulting padded literal. const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](const std::vector& input_index) { + auto func = [&](ArraySlice input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1348,9 +1396,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(map->shape()); - HloEvaluator embedded_evaluator; - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { std::vector> arg_literals; arg_literals.reserve(operands.size()); @@ -1440,7 +1488,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == ShapeUtil::Rank(arg->shape()) - dimensions.size()); @@ -1483,10 +1531,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } } - HloEvaluator embedded_evaluator; + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { ReturnT result_val = init_scalar; std::vector base(arg_dimensions.size()); @@ -1494,7 +1542,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { base[result_to_arg_index[i]] = multi_index[i]; } - auto func = [&](const std::vector& input_index) { + auto func = [&](ArraySlice input_index) { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. @@ -1540,9 +1588,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); + [&](ArraySlice output_index) { return init_scalar; })); std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1559,7 +1605,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { int64 rank = ShapeUtil::Rank(operand_literal.shape()); - HloEvaluator embedded_evaluator; + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); DimensionVector source_index(rank); std::fill(source_index.begin(), source_index.end(), 0); @@ -1575,8 +1621,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // 2. Using the selected index, scatter value from `source` to result. We // do this by iterating through the window, and compare each index with // the selected index. - tensorflow::gtl::optional selected_val; - tensorflow::gtl::optional> selected_index; + optional selected_val; + optional> selected_index; IterateThroughWindow( window_shape, window, operand_literal.shape(), source_index, @@ -1591,11 +1637,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Literal::CreateR0(*selected_val); const std::vector args = { - curr_val_literal.get(), selected_val_literal.get()}; + selected_val_literal.get(), curr_val_literal.get()}; std::unique_ptr computed_result = embedded_evaluator.Evaluate(*select, args) .ConsumeValueOrDie(); - bool selected = computed_result->Get({}); + bool selected = !computed_result->Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1670,10 +1716,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector window_index(window.dimensions_size()); DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - HloEvaluator embedded_evaluator; + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1723,7 +1769,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const int64 rank = ShapeUtil::Rank(operand->shape()); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](ArraySlice out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { operand_index[i] = @@ -1904,8 +1950,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::vector operand_indices(start.size()); auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); // Mod is only used here to be consistent with the existing @@ -1925,17 +1971,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> DynamicUpdateSlice( const Literal& operand_literal, const Literal& update_literal, const Literal& start_indices_literal) { - auto start_indices_typed = start_indices_literal.data(); - const std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - auto result = operand_literal.CloneToUnique(); - std::vector result_index(ShapeUtil::Rank(result->shape()), 0); + auto start_indices_typed = start_indices_literal.data(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector start(rank, 0); + for (int64 i = 0; i < rank; ++i) { + // All other implementations currently wrap-around the index, so this + // should do so as well. + start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); + start[i] += (start[i] < 0) * result->shape().dimensions(i); + } + std::vector result_index(rank, 0); - auto func = [&](const std::vector& update_index) { + auto func = [&](ArraySlice update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - + // Same as above, wrap-around only to match other implementations' + // semantics. + std::transform(result_index.begin(), result_index.end(), + result->shape().dimensions().begin(), result_index.begin(), + std::modulus()); result->Set(result_index, update_literal.Get(update_index)); return true; @@ -1988,8 +2043,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2026,8 +2081,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); @@ -2047,17 +2102,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { HloEvaluator* parent_; }; // class HloEvaluator::TypedVisitor -HloEvaluator::HloEvaluator() { +HloEvaluator::HloEvaluator(int64 max_loop_iterations) + : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = MakeUnique>(this); typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: U16."); + return Unimplemented( + "HloEvaluator::TypedVisitor: unhandled primitive type: U16."); }); typed_visitors_[U32] = MakeUnique>(this); typed_visitors_[U64] = MakeUnique>(this); typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: S16."); + return Unimplemented( + "HloEvaluator::TypedVisitor: unhandled primitive type: S16."); }); typed_visitors_[S32] = MakeUnique>(this); typed_visitors_[S64] = MakeUnique>(this); @@ -2071,18 +2129,20 @@ HloEvaluator::HloEvaluator() { // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. typed_visitors_[BF16] = MakeUnique>(this); + typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); + return Unimplemented( + "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE."); + return Unimplemented( + "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE."); }); } template StatusOr> HloEvaluator::Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals) { + const HloModule& module, ArraySlice arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -2099,8 +2159,8 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals) { + const HloComputation& computation, ArraySlice arg_literals) { + CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ -2116,8 +2176,7 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals) { + HloInstruction* instruction, ArraySlice arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); @@ -2242,8 +2301,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + ArraySlice operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -2415,6 +2473,349 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output +// gather dimensions while keeping the rest of the output dimensions clamped to +// 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( + const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { + int64 output_rank = output_shape.dimensions_size(); + std::vector index_base(output_rank, 0); + std::vector index_count; + index_count.reserve(output_rank); + for (int64 i = 0; i < output_rank; i++) { + bool is_output_gather_dim = + !c_binary_search(dim_numbers.output_window_dims(), i); + index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) + : 1); + } + + return {std::move(index_base), std::move(index_count), + std::vector(output_rank, 1)}; +} + +// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// dimensions while keeping the rest of the output dimensions clamped to 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( + int64 output_rank, ArraySlice window_bounds, + const GatherDimensionNumbers& dim_numbers) { + std::vector index_base(output_rank, 0); + std::vector index_count(output_rank, 1); + int64 window_bounds_idx = 0; + for (int64 i = 0; i < output_rank; i++) { + bool is_output_window_dim = + c_binary_search(dim_numbers.output_window_dims(), i); + if (is_output_window_dim) { + while (c_binary_search(dim_numbers.elided_window_dims(), + window_bounds_idx)) { + window_bounds_idx++; + } + index_count[i] = window_bounds[window_bounds_idx++]; + } + } + + return {std::move(index_base), std::move(index_count), + std::vector(output_rank, 1)}; +} + +// This functor computes the contribution of gather_indices to an input index +// corresponding to an output index. That is, given an output index I, it picks +// out the gather output indices in I and uses them to look up a gather index, +// G, from the gather indices tensor, and expands G into the input space +// according to gather_dims_to_operand_dims. +class OutputGatherIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit OutputGatherIndexToInputIndex( + const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& output_shape, const Literal* gather_indices) + : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + for (int64 i = 0; i < output_shape.dimensions_size(); i++) { + output_dim_is_gather_dims_.push_back( + !c_binary_search(dim_numbers_.output_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), + c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + if (index_of_input_dim_in_index_vector == + dim_numbers_.gather_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of gather_indices to the input index corresponding + // to output_index. See gather_inner_loop_body. + // + // This is conceptually a stateless transformation from output_index to the + // gather input index, but: + // + // - Instead of allocating memory to represent the gather input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()(ArraySlice output_index) { + PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return ArraySlice(input_index_); + } + + private: + // Propagates the gather index dimensions from the output index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the dimension + // we iterate over in FetchIndexVector. + void PropagateOutputIndexGatherDimsToIndexVectorIndex( + ArraySlice output_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = output_index.size(); i < e; i++) { + if (!output_dim_is_gather_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = output_index[i]; + } + } + + // Populates index_vector_ by iterating over gather_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of + // the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // dimension. + std::vector output_dim_is_gather_dims_; + + // The buffer into which we construct an index into gather_indices_ to fetch + // the index vector. + std::vector index_vector_index_; + + // The index vector fetched from gather_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice into + // this vector. + std::vector input_index_; + + const GatherDimensionNumbers& dim_numbers_; + const Literal& gather_indices_; +}; + +// This functor computes the contribution of the window indices in an output +// index to an input index. That is, given an output index I it picks out the +// output window indices in I and expands it into a window index into the input +// shape. +class OutputWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit OutputWindowIndexToInputIndex( + const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& output_shape) { + std::vector window_index_to_output_index; + int64 output_index_count = 0; + for (int64 i = 0; i < output_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.output_window_dims(), i)) { + window_index_to_output_index.push_back(output_index_count++); + } else { + output_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + input_dim_value_to_output_index_.push_back(-1); + } else { + input_dim_value_to_output_index_.push_back( + window_index_to_output_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to output_index. See gather_inner_loop_body. + // + // This is conceptually a stateless transformation from output_index to the + // window input index, but instead of allocating memory to represent the + // gather input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()(ArraySlice output_index) { + PropagateOutputIndexWindowDimsToInputIndex(output_index); + return ArraySlice(input_index_); + } + + private: + // Propagates window dimensions from the output index to input_index_ by + // mutating input_index_ in place. + void PropagateOutputIndexWindowDimsToInputIndex( + ArraySlice output_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_output_index_[i] != -1) { + input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of + // the input index from the output index. See + // PropagateOutputIndexToInputIndex. + std::vector input_dim_value_to_output_index_; + + // The result computed by this functor. operator() returns an ArraySlice into + // this vector. + std::vector input_index_; +}; + +// Rehapes the gather indices input to have a trailing degenerate `1` dimension +// if necessary. Hands over the ownership of the newly created literal (if +// there is one) to `reshaped_gather_indices`. +static StatusOr> ReshapedGatherIndices( + int64 index_vector_dim, const Literal& gather_indices, + std::unique_ptr* reshaped_gather_indices) { + if (gather_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(gather_indices); + } + + std::vector new_shape(gather_indices.shape().dimensions().begin(), + gather_indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, + gather_indices.Reshape(new_shape)); + return std::cref(**reshaped_gather_indices); +} + +Status HloEvaluator::HandleGather(HloInstruction* gather) { + std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + const Shape& shape = gather->shape(); + const GatherDimensionNumbers& dim_numbers = + gather->gather_dimension_numbers(); + const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); + std::unique_ptr reshaped_gather_indices; + TF_ASSIGN_OR_RETURN( + const Literal& gather_indices, + ReshapedGatherIndices(dim_numbers.index_vector_dim(), + GetEvaluatedLiteralFor(gather->operand(1)), + &reshaped_gather_indices)); + + // We iterate over the gather dimensions in the output shape in an outer loop + // nest, and iterate over the window dimensions in the output shape in an + // inner loop nest. + + ShapeUtil::IndexIterationSpace gather_indices_iteration_space = + IterationSpaceForOutputGatherIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForOutputWindowIndices( + shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + + // Scratch buffers that hold an index in the output shape and the + // corresponding index in the input shape. + std::vector input_index(operand.shape().dimensions_size()); + std::vector output_index(gather->shape().dimensions_size()); + + OutputGatherIndexToInputIndex output_gather_index_to_input_index( + &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), + /*output_shape=*/shape, &gather_indices); + OutputWindowIndexToInputIndex output_window_index_to_input_index( + gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), + /*output_shape=*/shape); + + const Shape& operand_shape = operand.shape(); + + auto gather_inner_loop_body = + [&](ArraySlice output_window_index, + ArraySlice input_gather_index, + ArraySlice output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + ArraySlice input_window_index, + output_window_index_to_input_index(output_window_index)); + for (int i = 0, e = output_index.size(); i < e; i++) { + output_index[i] = output_gather_index[i] + output_window_index[i]; + DCHECK_LT(output_index[i], shape.dimensions(i)); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + // TODO(b/74360564): We should implement whatever out of bounds behavior + // we decide for dynamic-slice here as well. + input_index[i] = (input_gather_index[i] + input_window_index[i]) % + operand_shape.dimensions(i); + if (input_index[i] < 0) { + input_index[i] += operand_shape.dimensions(i); + } + } + TF_RETURN_IF_ERROR( + result->CopyElementFrom(operand, input_index, output_index)); + return true; + }; + + auto gather_outer_loop_body = + [&](ArraySlice output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + ArraySlice input_gather_index, + output_gather_index_to_input_index(output_gather_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + shape, window_indices_iteration_space, + std::bind(gather_inner_loop_body, std::placeholders::_1, + input_gather_index, output_gather_index))); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + shape, gather_indices_iteration_space, gather_outer_loop_body)); + evaluated_[gather] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -2445,6 +2846,135 @@ Status HloEvaluator::HandleCopy(HloInstruction* copy) { return Status::OK(); } +Status HloEvaluator::HandleCall(HloInstruction* call) { + auto* computation = call->to_apply(); + auto operands = call->operands(); + + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[call] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleFusion(HloInstruction* fusion) { + // Attach cloned computation to an empty HLO module so the existing ones are + // not modified. + HloModule empty_hlo_module("EmptyModuleForFusion"); + auto cloned_fused_computation = + fusion->fused_instructions_computation()->Clone( + /*suffix=*/"clone_with_layout", &empty_hlo_module); + for (auto* instruction : cloned_fused_computation->instructions()) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } + auto readded_computation = + empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); + + auto operands = fusion->operands(); + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator + .Evaluate(*readded_computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[fusion] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleConditional(HloInstruction* conditional) { + const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); + const auto& true_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(1)); + const auto& false_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(2)); + + auto* true_computation = conditional->true_computation(); + auto* false_computation = conditional->false_computation(); + + auto result = Literal::CreateFromShape(conditional->shape()); + HloEvaluator embedded_evaluator; + if (pred.Get({})) { + result = embedded_evaluator + .Evaluate(*true_computation, + {&true_computation_arg}) + .ConsumeValueOrDie(); + } else { + result = embedded_evaluator + .Evaluate(*false_computation, + {&false_computation_arg}) + .ConsumeValueOrDie(); + } + + evaluated_[conditional] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleSelect(HloInstruction* select) { + const auto& pred = GetEvaluatedLiteralFor(select->operand(0)); + const auto& on_true = GetEvaluatedLiteralFor(select->operand(1)); + const auto& on_false = GetEvaluatedLiteralFor(select->operand(2)); + + // If predicate is of scalar type, no element-wise selection would be needed. + // This would also handle output array of tuple types as the DefaultAction + // would go through the TypedVisitor which doesn't handle tuples. + if (ShapeUtil::IsScalar(pred.shape())) { + if (pred.Get({})) { + evaluated_[select] = on_true.CloneToUnique(); + } else { + evaluated_[select] = on_false.CloneToUnique(); + } + return Status::OK(); + } + + return DefaultAction(select); +} + +Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { + HloComputation* cond_comp = while_hlo->while_condition(); + HloComputation* body_comp = while_hlo->while_body(); + // Initialize the loop carried valued with the input to the While instruction. + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + bool keep_going = true; + int64 iteration_count = 0; + HloEvaluator cond_evaluator(max_loop_iterations_); + HloEvaluator loop_body_evaluator(max_loop_iterations_); + while (keep_going) { + if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { + return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", + while_hlo->name().c_str(), max_loop_iterations_); + } + TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( + *cond_comp, {lcv.get()})); + keep_going = cond_val->GetFirstElement(); + if (keep_going) { + TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( + *body_comp, {lcv.get()})); + VLOG(3) << "Loop iteration result: " << body_val->ToString(); + lcv = std::move(body_val); + cond_evaluator.ResetVisitStates(); + loop_body_evaluator.ResetVisitStates(); + } + } + evaluated_[while_hlo] = std::move(lcv); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); @@ -2458,28 +2988,27 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(const HloModule& module, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( - const HloModule& module, - tensorflow::gtl::ArraySlice> arg_literals); + const HloModule& module, ArraySlice> arg_literals); -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(const HloComputation& computation, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( const HloComputation& computation, - tensorflow::gtl::ArraySlice> arg_literals); + ArraySlice> arg_literals); -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(HloInstruction* instruction, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( HloInstruction* instruction, - tensorflow::gtl::ArraySlice> arg_literals); + ArraySlice> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 3b2b697e492a78a06a4e5ae6bf056ff8676f2ff5..c0dcee0c3e382f74de72a2b89f39e06f042e2b80 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -36,7 +36,10 @@ namespace xla { // This class is not thread-safe. class HloEvaluator : public DfsHloVisitorWithDefault { public: - HloEvaluator(); + // Only evaluate up to max_loop_iterations per while-loop execution if + // specified. + explicit HloEvaluator(int64 max_loop_iterations = -1); + // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. // Precondition: The indices of arg_literals correspond to the parameter @@ -149,10 +152,22 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; + Status HandleGather(HloInstruction* gather) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + + Status HandleCall(HloInstruction* call) override; + + Status HandleFusion(HloInstruction* fusion) override; + + Status HandleWhile(HloInstruction* while_hlo) override; + + Status HandleSelect(HloInstruction* select) override; + private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be @@ -190,6 +205,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Must be cleared for each evaluation. std::vector arg_literals_; + // Max loop iterations to execute with no maximum if negative. + int64 max_loop_iterations_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 97765d65909cee192f65069777f8f195081603b2..685cacd7f74c00789296dee16f0a6a94c35a4393 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1729,6 +1729,207 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { *result.ValueOrDie()); } +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR3( + {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { + const char* hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, + EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { + const char* hlo_text = R"( +HloModule DynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,0] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 0} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index f0df93b61d29c1535d8a89fbd65e669de5b43729..c3ccbf0f0c75b569b49652807dea52faebdccc31 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -111,8 +111,8 @@ HloExecutionProfile::HloExecutionProfile( : hlo_profile_printer_data_(*hlo_profile_printer_data), hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( - /*count*/ hlo_profile_index_map_.total_count(), - /*value*/ 0) {} + /*count=*/hlo_profile_index_map_.total_count(), + /*value=*/0) {} void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 2861fec39ef0c92fdfbcee04584f9bd36d3cb4d8..1dc72355cf179e996caab4d6b52068dc99d02244 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -157,52 +157,60 @@ enum ColorScheme { kDashedBorder, }; +// Graphviz attributes/colors that make up a color scheme. +struct NodeColors { + const char* style; + const char* fill_color; + const char* stroke_color; + const char* font_color; +}; + +NodeColors NodeColorsForScheme(ColorScheme color) { + switch (color) { + case kBlue: + return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"}; + case kBrown: + return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"}; + case kDarkBlue: + return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; + case kDarkGreen: + return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkRed: + return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; + case kGray: + return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"}; + case kGreen: + return NodeColors{"filled", "#c8e6c9", "#97b498", "black"}; + case kOrange: + return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"}; + case kPurple: + return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"}; + case kRed: + return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"}; + case kWhite: + return NodeColors{"filled", "white", "black", "black"}; + case kYellow: + return NodeColors{"filled", "#fff9c4", "#cbc693", "black"}; + case kDashedBorder: + // "filled,dashed" looks the same as "dashed", since we have a white + // background. But we use "filled,dashed" so that when you hover over + // any part of the node (not just the text inside the node), our css + // :hover rule is triggered. + return NodeColors{"filled,dashed", "white", "#757575", "#757575"}; + } +} + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's style and fill/stroke/text colors. // // Colors are from https://material.io/color. string NodeColorAttributes(ColorScheme color) { - using std::make_tuple; - - const char *style, *fill_color, *stroke_color, *font_color; - std::tie(style, fill_color, stroke_color, font_color) = [color] { - switch (color) { - case kBlue: - return make_tuple("filled", "#bbdefb", "#8aacc8", "black"); - case kBrown: - return make_tuple("filled", "#bcaaa4", "#8c7b75", "black"); - case kDarkBlue: - return make_tuple("filled", "#1565c0", "#003c8f", "white"); - case kDarkGreen: - return make_tuple("filled", "#2e7d32", "#005005", "white"); - case kDarkRed: - return make_tuple("filled", "#b71c1c", "#7f0000", "white"); - case kGray: - return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black"); - case kGreen: - return make_tuple("filled", "#c8e6c9", "#97b498", "black"); - case kOrange: - return make_tuple("filled", "#ffe0b2", "#cbae82", "black"); - case kPurple: - return make_tuple("filled", "#e1bee7", "#af8eb5", "black"); - case kRed: - return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black"); - case kWhite: - return make_tuple("filled", "white", "black", "black"); - case kYellow: - return make_tuple("filled", "#fff9c4", "#cbc693", "black"); - case kDashedBorder: - // "filled,dashed" looks the same as "dashed", since we have a white - // background. But we use "filled,dashed" so that when you hover over - // any part of the node (not just the text inside the node), our css - // :hover rule is triggered. - return make_tuple("filled,dashed", "white", "#757575", "#757575"); - } - }(); + NodeColors node_colors = NodeColorsForScheme(color); return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style, - font_color, stroke_color, fill_color); + R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, node_colors.stroke_color, + node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a @@ -604,11 +612,21 @@ tooltip = " "; StrAppend(&subcomp_label, "
", extra_info); } - // Subcomputation's fill/stroke color is light/dark red/gray, depending on - // whether or not the subcomputation's fusion node is highlighted. bool highlight = filter_.Highlight(parent_instr); - const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; - const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + const char* fillcolor; + const char* strokecolor; + if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { + // Use the sharding color, if the node isn't highlighted. + NodeColors node_colors = + NodeColorsForScheme(GetInstructionColor(parent_instr)); + fillcolor = node_colors.fill_color; + strokecolor = node_colors.stroke_color; + } else { + // Subcomputation's fill/stroke color is light/dark red/gray, depending on + // whether or not the subcomputation's fusion node is highlighted. + fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; + strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + } style = Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", fillcolor, strokecolor); @@ -782,6 +800,14 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( auto stringify_constant = [](const HloInstruction* constant) { const auto& shape = constant->shape(); + // If the shape has a dimension of size zero, print it as e.g. + // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), + // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which + // is just noise. + if (ShapeUtil::HasZeroElements(shape)) { + return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + } + // Print the literal value of constants with <= K elements. optional elem_count; if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b7dd055d7cd78eb759a2b24bcbbbc948159f9425..a2a2c1e615a7f2b226c712a75b1240b980fc8d3c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -51,24 +52,22 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation) { + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const string& operand_name : proto.operand_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) - << "No instruction named " << operand_name; - instruction->AppendOperand(instruction_map.at(operand_name)); - } - for (const string& predecessor_name : proto.control_predecessor_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) - << "No instruction named " << predecessor_name; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) ->AddControlDependencyTo(instruction.get())); } @@ -76,23 +75,26 @@ StatusOr> HloInstruction::CreateFromProto( // HloInstructionProto and do not appear as an HloComputationProto within the // HloModuleProto. if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(proto.has_fused_instructions_computation()); TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + fused_computation->SetFusionInstruction(instruction.get()); + instruction->called_computations_.push_back(fused_computation); } else { - for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_name)) - << "No computation named " << computation_name; + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; instruction->called_computations_.push_back( - computation_map.at(computation_name)); + computation_map.at(computation_id)); } } @@ -182,6 +184,7 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); auto instruction = WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); instruction->tuple_index_ = index; @@ -1172,7 +1175,8 @@ bool HloInstruction::HasSideEffect() const { /* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims) { + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; for (int64 output_window_dim : output_window_dims) { gather_dim_numbers.add_output_window_dims(output_window_dim); @@ -1184,6 +1188,7 @@ bool HloInstruction::HasSideEffect() const { gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); } + gather_dim_numbers.set_index_vector_dim(index_vector_dim); return gather_dim_numbers; } @@ -2310,14 +2315,18 @@ string HloInstruction::ToShortString() const { HloInstructionProto HloInstruction::ToProto() const { HloInstructionProto proto; + CHECK(unique_id_ != -1) + << "This instruction does not have a valid id. Please make sure the " + "instruction is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); *proto.mutable_shape() = shape_; for (const HloInstruction* operand : operands_) { - *proto.add_operand_names() = operand->name(); + proto.add_operand_ids(operand->unique_id()); } for (const HloInstruction* control : control_predecessors_) { - *proto.add_control_predecessor_names() = control->name(); + proto.add_control_predecessor_ids(control->unique_id()); } *proto.mutable_metadata() = metadata_; @@ -2327,11 +2336,11 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.mutable_fused_instructions_computation() = - fused_instructions_computation()->ToProto(); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); } else { for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); + proto.add_called_computation_ids(computation->unique_id()); } } @@ -2680,8 +2689,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { case HloOpcode::kTrace: break; } - return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s", - HloOpcodeString(opcode_).c_str()); + return InternalError( + "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " + "please file a bug for XLA.", + HloOpcodeString(opcode_).c_str()); } // Explicit instantiations. @@ -3369,9 +3380,12 @@ string HloInstruction::GatherDimensionNumbersToString() const { string gather_dims_to_operand_dims = StrCat( "gather_dims_to_operand_dims={", Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims}, + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, ", "); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e4d22e5703811dc9b5f3ea3ee1ca85fd848f88b2..a94ba145df792ade9bb7ce3e9a31b56b2f460cd2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -179,20 +179,15 @@ class HloInstruction { // module: the module which will contain the instruction. The newly created // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. - // instruction_map: a map from instruction name to HloInstruction*. This map + // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation); + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -502,7 +497,8 @@ class HloInstruction { static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims); + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 32d3ed272bd6b239918076999ecae6c1b3ded2fd..f2980d309d01fdf3b3e601bc260a0ad0895b3064 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1271,7 +1271,7 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } -TEST_F(HloInstructionTest, StringifyGather) { +TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); Shape gather_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); @@ -1291,7 +1291,8 @@ TEST_F(HloInstructionTest, StringifyGather) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); HloModule module(TestName()); @@ -1303,7 +1304,43 @@ TEST_F(HloInstructionTest, StringifyGather) { "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " "gather_dims_to_operand_dims={0,1,2,3,4}, " - "window_bounds={30,29,28,27,26}"); + "index_vector_dim=4, window_bounds={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, StringifyGather_1) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cb2fe9f874012a51e1e6cbd1dd086dbb26994bde..595c531ccff728f836cfaca2fafaa8a08e715b74 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -83,6 +83,11 @@ HloComputation* HloModule::AddComputationInternal( for (auto* instruction : computation->instructions()) { instruction->SetUniqueId(NewUniqueInstructionId()); } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); + computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -204,90 +209,36 @@ string HloModule::ToString(const HloPrintOptions& options) const { HloModuleProto HloModule::ToProto() const { HloModuleProto proto; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); + proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are added when the fusion instructions are created by - // HloInstruction::CreateFromProto. - if (computation->IsFusionComputation()) { - continue; - } HloComputationProto computation_proto = computation->ToProto(); + if (computation->name() == entry_computation_->name()) { + *proto.mutable_program_shape() = computation_proto.program_shape(); + } proto.add_computations()->Swap(&computation_proto); } return proto; } -namespace { - -// Construct a ProgramShape matching the shape of the parameters and root of the -// given module's entry computation. -StatusOr ProgramShapeFromProto(const HloModuleProto& module) { - const HloComputationProto* entry_computation = nullptr; - for (const HloComputationProto& computation : module.computations()) { - if (computation.name() == module.entry_computation_name()) { - entry_computation = &computation; - break; - } - } - TF_RET_CHECK(entry_computation != nullptr) - << "No computation with entry computation name" - << module.entry_computation_name(); - - tensorflow::gtl::FlatMap> parameters; - const HloInstructionProto* root = nullptr; - for (const HloInstructionProto& instruction : - entry_computation->instructions()) { - if (instruction.name() == entry_computation->root_name()) { - TF_RET_CHECK(root == nullptr) << "Entry computation has more than " - "one instruction with (root) name " - << instruction.name(); - root = &instruction; - } - if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { - TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number())) - << "Entry computation has more than one parameter instruction " - "with parameter number " - << instruction.parameter_number(); - parameters[instruction.parameter_number()] = {instruction.name(), - &instruction.shape()}; - } - } - TF_RET_CHECK(root != nullptr) - << "Entry computation is missing root instruction named " - << entry_computation->root_name(); - - ProgramShape program_shape; - *program_shape.mutable_result() = root->shape(); - for (int64 i = 0; i < parameters.size(); ++i) { - TF_RET_CHECK(ContainsKey(parameters, i)) - << "Entry computation missing parameter number " << i; - const string& name = parameters.at(i).first; - const Shape& shape = *parameters.at(i).second; - *program_shape.add_parameters() = shape; - program_shape.add_parameter_names(name); - } - - return std::move(program_shape); -} - -} // namespace - /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, const VersionedComputationHandle& entry_computation_handle) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. - TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape, - ProgramShapeFromProto(proto)); + TF_RET_CHECK(proto.has_program_shape()) + << "No program shape found in the proto"; + const auto& expected_program_shape = proto.program_shape(); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = module_config.entry_computation_layout().parameter_layout(i).shape(); - TF_RET_CHECK( - ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), + parameter_shape)) << "HloModuleConfig has different shape for parameter " << i << " than the HLO module. Expected: " << ShapeUtil::HumanStringWithLayout( @@ -296,7 +247,8 @@ StatusOr> HloModule::CreateFromProto( } const Shape& result_shape = module_config.entry_computation_layout().result_layout().shape(); - TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + TF_RET_CHECK( + ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " "Expected: " << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) @@ -305,26 +257,20 @@ StatusOr> HloModule::CreateFromProto( auto module = MakeUnique(proto.name(), entry_computation_handle, module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map, - /*add_fused_computation=*/ - [&module](std::unique_ptr fused_computation) { - module->AddComputationInternal(std::move(fused_computation), - /*is_entry=*/false, - /*uniquify_names=*/false); - })); + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); - TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); - string computation_name = computation->name(); + int64 computation_id = computation_proto.id(); + TF_RET_CHECK(computation_id != -1); + TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_name] = module->AddComputationInternal( + computation_map[computation_id] = module->AddComputationInternal( std::move(computation), - /*is_entry=*/proto.entry_computation_name() == computation_name, + /*is_entry=*/proto.entry_computation_id() == computation_id, /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -334,10 +280,6 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; for (HloComputation* computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); @@ -354,8 +296,9 @@ StatusOr> HloModule::CreateFromProto( /* static */ StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - ProgramShapeFromProto(module)); + TF_RET_CHECK(module.has_program_shape()) + << "No program shape found in the proto"; + const auto& program_shape = module.program_shape(); HloModuleConfig module_config(program_shape); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 06d92f94fd6f62162b22575e9cc341f2906cd0db..755bbd359f7b95e7f3f3cbee1b46df85908202c6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -103,7 +103,7 @@ class HloModule { return config_.mutable_entry_computation_layout(); } - ComputationLayout entry_computation_layout() const { + const ComputationLayout& entry_computation_layout() const { return config_.entry_computation_layout(); } @@ -187,11 +187,6 @@ class HloModule { // Returns a randomly generated uint64. uint64 RandomNew64() const; - // Returns the unique name for a computation in this module. - string GetUniqueCompuationName(const string& prefix) { - return computation_name_uniquer_.GetUniqueName(prefix); - } - // Returns the NameUniquer for uniquing instruction names in this module. NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 822e2f1f53e5ee460b88c2241ecf7f6b91ef608b..4205b0402cb8b2c31141d65be652cd84c22e7262 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -40,7 +40,7 @@ void HloModuleConfig::SetDefaultComputationLayout( string HloModuleConfig::compilation_cache_key() const { string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); + tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index d3c1fae592bb465609ffbde2d0262e2600912e63..586a03d412681cacdd780f48e77baf4cd4c51415 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -63,9 +63,10 @@ class HloModuleConfig { return &(*entry_computation_layout_); } - // Sets/returns whether to enable HLO-level profiling. - bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } - void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } + // Returns whether to enable HLO-level profiling. + bool hlo_profiling_enabled() const { + return debug_options_.xla_hlo_profile(); + } // Sets/returns whether this is a "host module". Host modules are used to // record the data- and control-flow dependencies of host side computation @@ -110,9 +111,6 @@ class HloModuleConfig { tensorflow::gtl::optional entry_computation_layout_; - // Whether to enable HLO-level profiling. - bool hlo_profiling_enabled_ = false; - // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa5dcb0b369d17c70c64c67b9f11640c93fb4278 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -0,0 +1,350 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" + +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +string HloModuleGroupMetadata::TrackedInstruction::ToString() const { + string repr = + (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL"; + switch (kind_) { + case ComputationKind::kInvalid: + repr += ":INVALID"; + break; + case ComputationKind::kWhileCondition: + repr += ":WHILE_CONDITION"; + break; + case ComputationKind::kWhileBody: + repr += ":WHILE_BODY"; + break; + case ComputationKind::kConditionalTrue: + repr += ":CONDITIONAL_TRUE"; + break; + case ComputationKind::kConditionalFalse: + repr += ":CONDITIONAL_FALSE"; + break; + } + return repr; +} + +/* static */ StatusOr> +HloModuleGroupMetadata::Build(const std::vector& modules) { + auto metadata = absl::make_unique(modules); + TF_RETURN_IF_ERROR(metadata->Build()); + return std::move(metadata); +} + +Status HloModuleGroupMetadata::Build() { + TF_RETURN_IF_ERROR(RecordInstructions()); + TF_RETURN_IF_ERROR(VerifyChannelInstructions()); + + // Record all companion while instructions. + const auto visitor = [this](HloInstruction* hlo) -> Status { + // We only need to process if the instruction is within the computation + // of a companion instruction, like in the condition or body computation + // of a While. + const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent()); + if (tracked == nullptr) { + return Status::OK(); + } + // Add the parent computation of this channel instruction and its peer + // computation (both must be while computations) as companions. + if (IsChannelInstruction(hlo)) { + HloComputation* peer_computation = PeerComputation(hlo); + const TrackedInstruction* peer_tracked = + GetTrackedInstruction(peer_computation); + TF_RET_CHECK(peer_tracked != nullptr) + << "Peer instruction is not a possible companion"; + TF_RET_CHECK(*tracked == *peer_tracked) + << "Peer instruction does not match the computation kind"; + TF_RETURN_IF_ERROR( + AddCompanion(tracked->instruction(), peer_tracked->instruction())); + } + + // Add the parents of companion instructions (they must be all of the same + // kind of instructions, opcode wise) as companions. + if (IsCompanionInstruction(hlo)) { + for (HloInstruction* companion : Companions(hlo)) { + const TrackedInstruction* companion_tracked = + GetTrackedInstruction(companion->parent()); + TF_RET_CHECK(companion_tracked != nullptr); + TF_RET_CHECK(*tracked == *companion_tracked); + TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(), + companion_tracked->instruction())); + } + } + return Status::OK(); + }; + + // Visit the computations in postorder so that the companion information grows + // from inner computations to outer ones. + for (HloModule* module : modules_) { + for (HloComputation* computation : module->MakeComputationPostOrder()) { + TF_RETURN_IF_ERROR(computation->Accept(visitor)); + } + } + return Status::OK(); +} + +bool HloModuleGroupMetadata::IsChannelInstruction( + const HloInstruction* instruction) const { + switch (instruction->opcode()) { + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kSendDone: + case HloOpcode::kRecvDone: + return true; + default: + return false; + } +} + +bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { + return companion_set_index_.count(hlo) > 0; +} + +bool HloModuleGroupMetadata::InstructionCommunicates( + HloInstruction* hlo) const { + return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo); +} + +const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( + int64 channel_id) const { + CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end()); + return channels_[channel_id_map_.at(channel_id)]; +} + +HloComputation* HloModuleGroupMetadata::PeerComputation( + const HloInstruction* instruction) const { + CHECK(IsChannelInstruction(instruction)); + const Channel& channel = GetChannel(instruction->channel_id()); + switch (instruction->opcode()) { + case HloOpcode::kSend: + case HloOpcode::kSendDone: + return channel.recv->parent(); + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + return channel.send->parent(); + default: + LOG(FATAL) << "opcode not supported"; + } +} + +std::vector +HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const { + std::vector path; + const HloComputation* parent = hlo->parent(); + const TrackedInstruction* companion; + while ((companion = GetTrackedInstruction(parent)) != nullptr) { + parent = companion->instruction()->parent(); + path.push_back(*companion); + } + return path; +} + +bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility( + const std::vector& path0, + const std::vector& path1) const { + if (path0.size() != path1.size()) { + VLOG(5) << "Companion path size do not match: " << path0.size() + << " != " << path1.size(); + return false; + } + for (int64 i = 0; i < path0.size(); ++i) { + if (path0[i] != path1[i]) { + VLOG(5) << "Companion instructions at path index " << i + << " do not have the same opcode: " << path0[i].ToString() + << " vs " << path1[i].ToString(); + return false; + } + } + return true; +} + +int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { + for (int64 i = 0; i < modules_.size(); ++i) { + if (modules_[i] == module) { + return i; + } + } + LOG(FATAL) << "unknown module"; +} + +Status HloModuleGroupMetadata::RecordInstructions() { + const auto visitor = [this](HloInstruction* hlo) -> Status { + if (hlo->opcode() == HloOpcode::kWhile) { + tracked_instructions_[hlo->while_condition()] = + TrackedInstruction(hlo, ComputationKind::kWhileCondition); + tracked_instructions_[hlo->while_body()] = + TrackedInstruction(hlo, ComputationKind::kWhileBody); + } else if (hlo->opcode() == HloOpcode::kConditional) { + tracked_instructions_[hlo->true_computation()] = + TrackedInstruction(hlo, ComputationKind::kConditionalTrue); + tracked_instructions_[hlo->false_computation()] = + TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } + if (!IsChannelInstruction(hlo)) { + return Status::OK(); + } + + // Add a new channel if needed. + if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { + channels_.emplace_back(); + channels_.back().id = hlo->channel_id(); + channel_id_map_[hlo->channel_id()] = channels_.size() - 1; + max_channel_id_ = std::max(max_channel_id_, hlo->channel_id()); + } + Channel& channel = channels_[channel_id_map_[hlo->channel_id()]]; + + if (hlo->opcode() == HloOpcode::kSend) { + TF_RET_CHECK(channel.send == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple send instructions"; + channel.send = hlo; + } + if (hlo->opcode() == HloOpcode::kRecv) { + TF_RET_CHECK(channel.recv == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple recv instructions"; + channel.recv = hlo; + } + if (hlo->opcode() == HloOpcode::kSendDone) { + TF_RET_CHECK(channel.send_done == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple send-done instructions"; + channel.send_done = hlo; + } + if (hlo->opcode() == HloOpcode::kRecvDone) { + TF_RET_CHECK(channel.recv_done == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple recv-done instructions"; + channel.recv_done = hlo; + } + return Status::OK(); + }; + + for (HloModule* module : modules_) { + for (auto* computation : module->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(visitor)); + } + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, + HloInstruction* instruction2) { + TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || + instruction1->opcode() == HloOpcode::kConditional); + VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " + << instruction2->ToString(); + + if (!ContainsKey(companion_set_index_, instruction1) && + !ContainsKey(companion_set_index_, instruction2)) { + companion_sets_.push_back( + absl::make_unique>()); + auto companion_set = companion_sets_.back().get(); + companion_set->insert(instruction1); + companion_set->insert(instruction2); + companion_set_index_[instruction1] = companion_sets_.size() - 1; + companion_set_index_[instruction2] = companion_sets_.size() - 1; + } else if (!ContainsKey(companion_set_index_, instruction1)) { + companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_set_index_[instruction1] = companion_set_index_[instruction2]; + } else if (!ContainsKey(companion_set_index_, instruction2)) { + companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_set_index_[instruction2] = companion_set_index_[instruction1]; + } else if (companion_set_index_[instruction1] != + companion_set_index_[instruction2]) { + companion_sets_[companion_set_index_[instruction1]]->insert( + Companions(instruction2).begin(), Companions(instruction2).end()); + int64 index_to_remove = companion_set_index_[instruction2]; + for (HloInstruction* hlo : Companions(instruction2)) { + companion_set_index_[hlo] = companion_set_index_[instruction1]; + } + companion_sets_.erase(companion_sets_.begin() + index_to_remove); + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyChannelInstructions() { + for (const Channel& channel : channels_) { + if (channel.send == nullptr) { + return FailedPrecondition("missing send for id : %lld", channel.id); + } + if (channel.recv == nullptr) { + return FailedPrecondition("missing recv for id : %lld", channel.id); + } + if (channel.send_done == nullptr) { + return FailedPrecondition("missing send-done for id : %lld", channel.id); + } + if (channel.recv_done == nullptr) { + return FailedPrecondition("missing recv-done for id : %lld", channel.id); + } + } + + // Check if the shapes match for each channel. + for (const Channel& channel : channels_) { + const Shape& send_shape = channel.send->operand(0)->shape(); + const Shape& recv_shape = channel.recv_done->shape(); + if (!ShapeUtil::Compatible(send_shape, recv_shape)) { + return FailedPrecondition("send/recv shapes do not match"); + } + } + + // Check if channel instructions are used only in allowed computations. + const auto allowed = [this](HloInstruction* hlo) { + HloComputation* computation = hlo->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return true; + } + return false; + }; + for (const Channel& channel : channels_) { + if (!allowed(channel.send) || !allowed(channel.send_done) || + !allowed(channel.recv) || !allowed(channel.recv_done)) { + return FailedPrecondition("channel is used in disallowed computation"); + } + } + // Check if the nest levels match for each channel. + for (const Channel& channel : channels_) { + std::vector path = GetCompanionsPath(channel.send); + if (!CheckCompanionPathsCompatibility( + path, GetCompanionsPath(channel.send_done)) || + !CheckCompanionPathsCompatibility(path, + GetCompanionsPath(channel.recv)) || + !CheckCompanionPathsCompatibility( + path, GetCompanionsPath(channel.recv_done))) { + return FailedPrecondition( + "Nest companion paths do not match for channel %lld", channel.id); + } + } + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..c48a7ab0b59269474f7406ef24a249355528e085 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -0,0 +1,239 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class for bookkeeping the information on the given modules, in particular on +// the interaction between computations. +// +// Companion instructions are one of the information collected as we build the +// metadata. For example, for each While instruction, companion instructions +// refer to a set of While instructions in other computations that communicate +// with each other. +// In the example below with 3 modules, {While_0, While_2, While_5}, {While_1, +// While_4}, {While_3, While_6} are companion sets. +// +// +// While_0() { While_2() { While_5() { +// While_1() { Send(0) } While_3() { Send(1) } While_6() { Recv(1) } +// } While_4() { Recv(0) } +// } +// +// Companion instructions are used to detect cycles in the graph and also for +// global scheduling. +class HloModuleGroupMetadata { + public: + // The kind of companion computation a given instruction can be within. + enum class ComputationKind { + kInvalid, + kWhileCondition, + kWhileBody, + kConditionalTrue, + kConditionalFalse, + }; + + // Tracks the instruction mapped to a given computation, and the computation + // kind. + // For example, a body computation of a while instruction, will generate a + // TrackedInstruction with instruction being the while instruction, and + // kind being ComputationKind::kWhileBody. + class TrackedInstruction { + public: + TrackedInstruction() = default; + TrackedInstruction(HloInstruction* instruction, ComputationKind kind) + : instruction_(instruction), kind_(kind) {} + + bool operator==(const TrackedInstruction& rhs) const { + return instruction_->opcode() == rhs.instruction_->opcode() && + kind_ == rhs.kind_; + } + bool operator!=(const TrackedInstruction& rhs) const { + return !operator==(rhs); + } + + HloInstruction* instruction() const { return instruction_; } + + string ToString() const; + + private: + HloInstruction* instruction_ = nullptr; + ComputationKind kind_ = ComputationKind::kInvalid; + }; + + // Represents a channel and the 4 instructions that form the channel. + struct Channel { + int64 id = -1; + HloInstruction* send = nullptr; + HloInstruction* recv = nullptr; + HloInstruction* send_done = nullptr; + HloInstruction* recv_done = nullptr; + }; + + explicit HloModuleGroupMetadata(const std::vector& modules) + : modules_(modules) {} + + ~HloModuleGroupMetadata() = default; + + // Build and return the metadata for the given modules. + static StatusOr> Build( + const std::vector& modules); + + // Returns true if the instruction is one of the 4 channel instructions (Send, + // Recv, SendDone, RecvDone). + bool IsChannelInstruction(const HloInstruction* instruction) const; + + // Returns true if the instruction is a companion instruction. See the class + // comment above on companion instructions. + bool IsCompanionInstruction(HloInstruction* hlo) const; + + // Returns true if the instruction is either a channel instruction or a + // companion instruction. + bool InstructionCommunicates(HloInstruction* hlo) const; + + // Returns the Channel instance for the given channel id. + const Channel& GetChannel(int64 channel_id) const; + + // Returns the computation that contains the peer channel instructions for + // the given instruction. + // + // Precondition: IsChannelInstruction(instruction) is true. + HloComputation* PeerComputation(const HloInstruction* instruction) const; + + // Returns the path of the nested companion instructions, in terms of HLO + // instructions. The path goes from inner to outer companions. + // The returned path does not include the input hlo instruction, in case it + // is a companion instruction. + std::vector GetCompanionsPath( + const HloInstruction* hlo) const; + + // Checks whether two companion paths (as returned by the GetCompanionsPath() + // API) are compatible. The two paths are compatible if the sequence of + // opcodes, and the companion kinds, of the two paths matches. + bool CheckCompanionPathsCompatibility( + const std::vector& path0, + const std::vector& path1) const; + + // Returns the unique integer for each module. The returned id is the index of + // the module in the module vector. + int64 GetModuleId(const HloModule* module) const; + + // Returns the companion instructions for the given instruction. + // + // Precondition: IsCompanionWhile(instruction) is true. + const std::unordered_set& Companions( + HloInstruction* instruction) const { + CHECK_EQ(companion_set_index_.count(instruction), 1); + return companion_set(companion_set_index_.at(instruction)); + } + + // Returns the companion set at the given index. + const std::unordered_set& companion_set(int64 index) const { + CHECK_LT(index, companion_sets_.size()); + return *companion_sets_[index]; + } + + // Returns the companion set index of the given instruction. + int64 companion_set_index(HloInstruction* instruction) const { + return companion_set_index_.at(instruction); + } + + // Returns the list of all companion sets in the HLO module group. + const std::vector>>& + companion_sets() const { + return companion_sets_; + } + + // Returns all channels in the module group. + const std::vector& channels() const { return channels_; } + + // Returns the maximum channel id used in the module group. + int64 max_channel_id() const { return max_channel_id_; } + + private: + Status Build(); + + // Record all channel instructions and While instructions. + Status RecordInstructions(); + + // Verifies the given HloModules are well-formed and follow the specification, + // in particular with respect to using channel instructions. + // + // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). + // * The shape of channel instructions match. + // * The nest level of channel instructions match. + // * Channel instructions are used in allowed computations; i.e., in the + // entry computation of the module or condition/body of While computations. + // + // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a + // cycle in the graph, but it would be good to verify here. + Status VerifyChannelInstructions(); + + // Adds metadata that the given two instructions are companions. + Status AddCompanion(HloInstruction* instruction1, + HloInstruction* instruction2); + + // Retrieves a pointer to the stored TrackedInstruction associated with a + // tracked computation, or nullptr in case such computation is not tracked. + const TrackedInstruction* GetTrackedInstruction( + const HloComputation* computation) const { + auto it = tracked_instructions_.find(computation); + return it != tracked_instructions_.end() ? &it->second : nullptr; + } + + // List of all companion instructions sets in the module. + std::vector>> + companion_sets_; + + // Map from each companion while instruction to the index into companion_set_. + tensorflow::gtl::FlatMap companion_set_index_; + + // Map from computation to the instruction using it (a kWhile, kConditional). + tensorflow::gtl::FlatMap + tracked_instructions_; + + // All channels in the module. + std::vector channels_; + + // Map from channel ids to the index in channels_. + tensorflow::gtl::FlatMap channel_id_map_; + + // The maximum channel id used in the module group. + int64 max_channel_id_ = -1; + + // The modules that this metadata was built from. + const std::vector& modules_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..289c96b0a7b90c5f8a122cd3fc327a5762099106 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -0,0 +1,316 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group_util.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +std::vector HloModuleGroupUtil::GlobalPredecessors( + HloInstruction* instruction) { + std::vector predecessors; + + // Adds to the unique predecessors list and also add companion instructions + // if the given predecessor has those. + auto add_unique_predecessor = [&](HloInstruction* predecessor) { + if (std::find(predecessors.begin(), predecessors.end(), predecessor) != + predecessors.end()) { + return; + } + if (!metadata_.IsCompanionInstruction(predecessor)) { + predecessors.push_back(predecessor); + return; + } + for (HloInstruction* companion : metadata_.Companions(predecessor)) { + predecessors.push_back(companion); + } + }; + + // If the given instruction is a companion instruction, we need to find the + // predecessors of all of its companion instructions. + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(instruction)) { + for (HloInstruction* companion : metadata_.Companions(instruction)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(instruction); + } + + for (HloInstruction* hlo : instruction_group) { + for (HloInstruction* operand : hlo->operands()) { + add_unique_predecessor(operand); + } + for (HloInstruction* control_predecessor : hlo->control_predecessors()) { + add_unique_predecessor(control_predecessor); + } + } + if (instruction->opcode() == HloOpcode::kRecvDone) { + // Send is a remote predecessor of RecvDone. + HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + add_unique_predecessor(send); + } + if (instruction->opcode() == HloOpcode::kSend) { + // Recv is a remote predecessor of Send. + HloInstruction* recv_done = + metadata_.GetChannel(instruction->channel_id()).recv_done; + CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + CHECK_EQ(recv_done->operand_count(), 1); + HloInstruction* recv = recv_done->mutable_operand(0); + add_unique_predecessor(recv); + } + return predecessors; +} + +std::vector HloModuleGroupUtil::GlobalSuccessors( + HloInstruction* instruction) { + std::vector successors; + + // Adds to the unique successors list and also add companion instructions + // if the given successor has those. + auto add_unique_successor = [&](HloInstruction* successor) { + if (std::find(successors.begin(), successors.end(), successor) != + successors.end()) { + return; + } + if (!metadata_.IsCompanionInstruction(successor)) { + successors.push_back(successor); + return; + } + for (HloInstruction* companion : metadata_.Companions(successor)) { + successors.push_back(companion); + } + }; + + // If the given instruction is a companion instruction, we need to find the + // successors of all of its companion instructions. + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(instruction)) { + for (HloInstruction* companion : metadata_.Companions(instruction)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(instruction); + } + + for (HloInstruction* hlo : instruction_group) { + for (HloInstruction* user : hlo->users()) { + add_unique_successor(user); + } + for (HloInstruction* control_successor : hlo->control_successors()) { + add_unique_successor(control_successor); + } + } + if (instruction->opcode() == HloOpcode::kRecv) { + // Send is a remote successor of Recv. + const HloInstruction* recv_done = instruction->users().front(); + CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + add_unique_successor(send); + } + if (instruction->opcode() == HloOpcode::kSend) { + // RecvDone is a remote successor of Send. + HloInstruction* recv_done = + metadata_.GetChannel(instruction->channel_id()).recv_done; + add_unique_successor(recv_done); + } + return successors; +} + +std::vector HloModuleGroupUtil::RootInstructions( + tensorflow::gtl::ArraySlice computations) { + std::vector roots; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + if (GlobalSuccessors(instruction).empty()) { + roots.push_back(instruction); + } + } + } + return roots; +} + +Status HloModuleGroupUtil::VisitTopologicalOrder( + VisitStates* visit_state, const VisitFunction& visit_function, + HloInstruction* root) { + // Stack of HLO instructions visited in DFS order. + std::stack stack; + stack.push(root); + + while (!stack.empty()) { + HloInstruction* hlo = stack.top(); + + // Find the instruction group of the currently visited instruction. The + // instruction group represents all companion instructions of the + // current instruction, and are considered to be a single entity for the + // purpose of the traversal (i.e., they must always be in the same visit + // state). + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(hlo)) { + for (HloInstruction* companion : metadata_.Companions(hlo)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(hlo); + } + + if ((*visit_state)[hlo] == VisitState::kVisited) { + // All instructions in the group must be in the same state. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited); + } + stack.pop(); + continue; + } + + if ((*visit_state)[hlo] == VisitState::kVisiting) { + TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group)); + + // Set the visit state of all instructions in the group to kVisited. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting); + (*visit_state)[instruction] = VisitState::kVisited; + } + stack.pop(); + continue; + } + + // Set the visit state of all instructions in the group to kVisiting. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited) + << instruction->ToString(); + (*visit_state)[instruction] = VisitState::kVisiting; + } + + // For each instruction in the group, visit its predecessors (operands, + // control predecessors and remote predecessors). + for (HloInstruction* instruction : instruction_group) { + for (HloInstruction* predecessor : GlobalPredecessors(instruction)) { + // Visiting a node that is already being visited implies that there is + // a cycle. Generate an error with the list of instructions in the + // cycle. + if ((*visit_state)[predecessor] == VisitState::kVisiting) { + string cyclic_instructions; + for (const auto& state : *visit_state) { + if (state.second == VisitState::kVisiting) { + tensorflow::strings::StrAppend(&cyclic_instructions, + state.first->ToString(), "\n"); + } + } + // TODO(b/64305524): Improve the error message to print out the + // instructions in a deterministic order that forms the cycle. + return FailedPrecondition( + "Cross-computation cycle detected via communicating nodes. The " + "cycle contains the node %s. The cycle is found among the " + "following nodes. Note that the order of the nodes is arbitrary " + "and that the list may include nodes that are not part of the " + "cycle.\n%s", + predecessor->ToString().c_str(), cyclic_instructions.c_str()); + } + stack.push(predecessor); + } + } + } + + return Status::OK(); +} + +Status HloModuleGroupUtil::VerifyComputations( + tensorflow::gtl::ArraySlice computations) { + auto visit_function = + [&](HloInstruction* instruction, + const std::vector& instruction_group) { + return Status::OK(); + }; + int64 instructions_count = 0; + VisitStates visit_states; + for (HloComputation* computation : computations) { + // Visit all instructions, and not just from the root instruction of the + // computation. This allows us to detect dead cycles (i.e., cycles that + // are not reachable from the root) or to enforce an order for the + // communication instructions that are not reachable from any roots. + for (HloInstruction* instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + VisitTopologicalOrder(&visit_states, visit_function, instruction)); + } + instructions_count += computation->instruction_count(); + } + + // Check if all instructions are visited and are in the visited state. + TF_RET_CHECK(visit_states.size() == instructions_count); + for (auto& state : visit_states) { + TF_RET_CHECK(state.second == VisitState::kVisited); + } + + return Status::OK(); +} + +StatusOr> +HloModuleGroupUtil::ComputeReachability( + tensorflow::gtl::ArraySlice computations) { + std::list post_order; + auto visit_function = + [&](HloInstruction* instruction, + const std::vector& instruction_group) { + post_order.insert(post_order.end(), instruction_group.begin(), + instruction_group.end()); + return Status::OK(); + }; + HloModuleGroupUtil::VisitStates visit_states; + for (HloInstruction* root : RootInstructions(computations)) { + TF_RETURN_IF_ERROR( + VisitTopologicalOrder(&visit_states, visit_function, root)); + } + auto reachability = absl::make_unique(post_order); + for (HloInstruction* hlo : post_order) { + reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); + } + return std::move(reachability); +} + +void HloModuleGroupUtil::UpdateReachabilityThroughInstruction( + HloInstruction* instruction, HloReachabilityMap* reachability_map) { + std::queue worklist; + worklist.push(instruction); + + while (!worklist.empty()) { + HloInstruction* item = worklist.front(); + worklist.pop(); + if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item), + item)) { + for (HloInstruction* successor : GlobalSuccessors(item)) { + worklist.push(successor); + } + } + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h new file mode 100644 index 0000000000000000000000000000000000000000..c25ca1aff50b288f3ac3885cbed53e7ba9768430 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +// Collection of utilities for handling HloModuleGroups. +class HloModuleGroupUtil { + public: + explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata) + : metadata_(metadata) {} + + // Returns all unique predecessors of the instruction. This includes: + // * predecessors in the same computation: operands and control predecessors + // * Recv is a predecessor of Send + // * Send is a predecessor of RecvDone + // * predecessors of companions (if the instruction is a companion while) + // * predecessors' companions (for any predecessor that is a companion while) + std::vector GlobalPredecessors(HloInstruction* instruction); + + // Returns all unique successors of the instruction. This includes: + // * successors in the same computation: users and control successors + // * Send is a successor of Recv + // * RecvDone is a predecessor of Send + // * successors of companions (if the instruction is a companion while) + // * successors' companions (for any successor that is a companion while) + std::vector GlobalSuccessors(HloInstruction* instruction); + + // Returns the root instructions of the computations. + std::vector RootInstructions( + tensorflow::gtl::ArraySlice computations); + + // Visit state of each instruction during DFS traversal. + enum VisitState { + kNotVisited = 0, + kVisiting, + kVisited, + }; + + // Function called on each instruction group during the DFS traversal. See the + // comment for VisitTopologicalOrder()). + using VisitFunction = std::function& instruction_group)>; + + // Given the hlo instruction as the root, recursively visits all its + // predecessor instructions in DFS order to visit nodes in topological order. + // + // Note that the DFS traversal does not only visit nodes in the same + // computation (parent of the root instruction), but also visits nodes in + // different computations connected via communication instructions. During the + // traversal, companion While instructions (see the class comment in + // HloModuleGroupMetadata) are treated as a single instruction (called + // instruction group, which contains only a single instruction if the visiting + // node is not a companion while) -- visiting one of the instructions in the + // group effectively visits all other instructions in the group, and then all + // predecessor instructions of the group are visited. + // + // * visit_state: map from each instruction to its visit state. + // * visit_function: function called when each instruction group. + // * root: the root instruction of the traversal. + using VisitStates = tensorflow::gtl::FlatMap; + Status VisitTopologicalOrder(VisitStates* visit_state, + const VisitFunction& visit_function, + HloInstruction* root); + + // Verifies that the computations are well-formed (e.g., no cycles). + Status VerifyComputations( + tensorflow::gtl::ArraySlice computations); + + // Below Reachability utils resemble those in HloComputation, except that + // they can handle instructions across multiple computations. + // + // Creates the reachability map for the instructions in the computations. + StatusOr> ComputeReachability( + tensorflow::gtl::ArraySlice computations); + + // Updates the reachability of the given instruction, taking the global + // predeccessorss and successors into account. + void UpdateReachabilityThroughInstruction( + HloInstruction* instruction, HloReachabilityMap* reachability_map); + + private: + const HloModuleGroupMetadata& metadata_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 1b24d8da9e832e6847cb6f405e15af3c455f695a..e89d94bede6c437ca1131a1b1b0098390d58c0d9 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -66,6 +66,28 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } } + // If the common ancestor is a conditional instruction, even though the true + // and false computations are not really ordered per-se, we define the true + // computation to be ordered before the false one. + // This ensures that buffers can still be shared among the two computations + // as they will forcibly have disjoint liveness. + if (a_ancestor == b_ancestor && + a_ancestor->opcode() == HloOpcode::kConditional) { + const HloComputation* true_computation = a_ancestor->true_computation(); + const HloComputation* false_computation = a_ancestor->false_computation(); + if (call_graph_->InstructionIsNestedIn(a, true_computation) && + call_graph_->InstructionIsNestedIn(b, false_computation)) { + return true; + } + // If 'b' is the conditional ancestor, and 'a' is within the true or false + // computations, 'a' executes before 'b'. + if (b == a_ancestor && + (call_graph_->InstructionIsNestedIn(a, true_computation) || + call_graph_->InstructionIsNestedIn(a, false_computation))) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } @@ -118,7 +140,18 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { b.defining_instruction()->while_condition()))) { return true; } - + // If 'b' is a conditional phi and 'a' is in the true or false computation, + // then 'a' executes before 'b'. + if (b.is_phi() && + b.defining_instruction()->opcode() == HloOpcode::kConditional && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->true_computation()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->false_computation()))) { + return true; + } return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); } @@ -212,18 +245,17 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() << ", b = " << b.ToShortString() << ")"; if (!IsDefinedBefore(a, b)) { - VLOG(4) << "a not defined before b"; + VLOG(4) << a << " not defined before " << b; return false; } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { - VLOG(4) << "use of a (" << use << ") not before b is defined"; + VLOG(4) << "use of " << a << " (" << use << ") not before " << b + << " is defined"; return false; } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index a989fce63234cb860d08c48b02462e96bec879bc..37a7fbad97cea2f34798efecc2489e57d1374f35 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -34,53 +34,6 @@ namespace { class HloOrderingTest : public HloTestBase {}; -TEST_F(HloOrderingTest, LastUseScheduledFirst) { - // Tests scheduling of the following HLO code: - // - // %ab = abs(%param) - // %exp = exp(%param) - // %add = add(%ab, %exp) - // %negate = negate(%exp) - // %sub = subtract(%add, %negate) - // - // %add should be scheduled before %negate because %add is the last (and only) - // use of %ab. Scheduling %add first then frees up %ab's buffer. - const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); - auto builder = HloComputation::Builder(TestName()); - auto param = - builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); - auto ab = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); - auto sub = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - - // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - - SequentialHloOrdering ordering(module.get(), sequence); - EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); -} - TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // Tests the ordering of instructions in different computations using the // following HLO code: @@ -362,5 +315,66 @@ ENTRY while.v11 { ordering.ToString(); // Shouldn't crash. } +TEST_F(HloOrderingTest, ConditionalInstructionOrdering) { + const char* module_str = R"( +HloModule test_conditional_module + +true_branch { + param.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1) +} + +false_branch { + param.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1 + add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4) + ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4) +} + +ENTRY root { + param.3 = (pred[], (s32[], s32[])) parameter(0) + pred.1 = pred[] get-tuple-element(param.3), index=0 + cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1 + conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch + cond_res.1 = s32[] get-tuple-element(conditional), index=0 + cond_res.2 = s32[] get-tuple-element(conditional), index=1 + add.3 = s32[] add(cond_res.1, cond_res.2) + ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + + // Even though the true and false branches has no ordering, since they do not + // interfere (as they are mutually exclusive), we define the true computation + // to be before the false one. + // Similarly, any instruction in the true or false branches are considered + // before the conditional instruction. The roots are effectively "at the same + // time" WRT the conditional, but they are Phi-ed anyway. + HloInstruction* add_1 = FindInstruction(module.get(), "add.1"); + HloInstruction* add_2 = FindInstruction(module.get(), "add.2"); + HloInstruction* add_3 = FindInstruction(module.get(), "add.3"); + HloInstruction* conditional = FindInstruction(module.get(), "conditional"); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_2))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_3))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(add_3))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 78e6a101c10a1e812e3e2631d520139fd0bc425c..3460679558d185d1e022660d9a1d23176d0d96bf 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include + +#include "tensorflow/compiler/xla/util.h" + namespace xla { HloProto MakeHloProto(const HloModule& module, @@ -35,4 +39,35 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr> EntryComputationParameterShapes( + const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + + std::vector parameter_shapes; + const auto& program_shape = hlo_proto.hlo_module().program_shape(); + for (const Shape& shape : program_shape.parameters()) { + parameter_shapes.push_back(&shape); + } + return parameter_shapes; +} + +StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + if (!hlo_proto.hlo_module().program_shape().has_result()) { + return NotFound("HloProto missing result in its program shape"); + } + + return &hlo_proto.hlo_module().program_shape().result(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 320288fdb9aa0810b306b1d78bd1ff4cfc366ed2..3d9c375cd5d26f92cf8316f78789daf4fc08c927 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,15 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Returns the shapes of the parameters of the entry computation. Shape pointers +// refer to shapes inside of the given HloProto. +StatusOr> EntryComputationParameterShapes( + const HloProto& hlo_proto); + +// Returns the shape of the output of the entry computation. The shape pointer +// refers to the output shape inside of the given HloProto. +StatusOr EntryComputationOutputShape(const HloProto& hlo_proto); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9cca138703c8fa61aadf69dd7304a215a9f4be2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class HloProtoUtilTest : public ::testing::Test {}; + +TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) { + HloProto hlo_proto; + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing HloModuleProto")); +} + +TEST_F(HloProtoUtilTest, MissingProgramShape) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_name("entry"); + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing program shape")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 98b8d34be1f331aaeac94e952deeae1e76379861..b0632448933df4b7681a0704c58d697b5ec68a1f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1320,7 +1320,7 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SchedulerAlgorithm scheduler_algorithm, + MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { HloRematerialization remat(scheduler_algorithm, size_function); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 52553439033a3bcfa4b472f13f9cd4b1ecf5ed96..2ee2dd0571ae8c6604e4ca722351fd48a913bda5 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -66,12 +66,12 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, const ShapeSizeFunction& size_function) : scheduler_algorithm_(scheduler_algorithm), size_function_(size_function) {} @@ -108,7 +108,7 @@ class HloRematerialization { const HloInstruction* instruction) const; // Selects an algorithm to use for HLO scheduling. - SchedulerAlgorithm scheduler_algorithm_; + MemorySchedulerAlgorithm scheduler_algorithm_; // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 1b7d26dde501a6a0955d62ea0938e0683a32d49d..83de54f3fa56ee660b79d8c366dbc0b52f9fde87 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -162,7 +162,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/14 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -195,7 +195,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/20 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -236,7 +236,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/17 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -272,7 +272,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/15 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -314,7 +314,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/13 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. @@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), SchedulerAlgorithm::kAuto, &sequence)); + module.get(), DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -480,7 +480,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -577,7 +577,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 41b079eb799d06321a31f7d7ae0630dc8d58c46b..e5b1c2efa3fc25d23531df298e125521c002dba1 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -110,7 +110,7 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::ExecuteInternal( +StatusOr> HloRunner::Execute( std::unique_ptr module, const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes) { @@ -158,8 +158,8 @@ StatusOr> HloRunner::ExecuteInternal( TF_ASSIGN_OR_RETURN( std::unique_ptr result, - executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, - /*hlo_execution_profile=*/nullptr)); + executable->ExecuteOnStreamWrapper( + &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); // Create a ScopedShapedBuffer of the result to manage deallocation. This will // deallocate all the device memory when it goes out of scope. diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cbaebc68bee708090b8ccb2eae19b556c4d6d453..06ce22a5b9fc7b3d6c10857c84196094c0eed303 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -64,17 +65,27 @@ class HloRunner { const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the - // result as a Literal. The LiteralPtr type accepts Literal* or - // std::unique_ptr. + // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes = true); + StatusOr> Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes = true) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + c_transform( + arguments, std::back_inserter(argument_pointers), + [](const std::unique_ptr& literal) { return literal.get(); }); + return Execute(std::move(module), argument_pointers, run_hlo_passes); + } + // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. // @@ -83,11 +94,6 @@ class HloRunner { Backend& backend(); private: - StatusOr> ExecuteInternal( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); - struct EigenThreadPoolWrapper; std::unique_ptr thread_pool_wrapper_; @@ -95,19 +101,6 @@ class HloRunner { std::unique_ptr backend_; }; -template -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - for (const auto& argument : arguments) { - argument_pointers.push_back(&*argument); - } - return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index f6e33403f538bd8492b04c34d46a458f7f06cc06..1a767628f6e2d33df353366974fb866e89f0df5a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -103,10 +103,11 @@ class ListScheduler { for (auto* instruction : computation.instructions()) { tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - instr_uses.insert(buffer); - } + points_to_analysis.GetPointsToSet(operand).ForEachElement( + [&](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + instr_uses.insert(buffers.begin(), buffers.end()); + }); } buffer_uses_[instruction] = std::vector( instr_uses.begin(), instr_uses.end()); @@ -339,7 +340,33 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> RunDFSMemoryScheduler( +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function); +} + +} // namespace + +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { @@ -348,6 +375,7 @@ StatusOr> RunDFSMemoryScheduler( // simply users-1 for each instruction. By subtracting 1, we're saying that // instructions with no users or a single user don't count; instructions with // lots of fan-out will be visited earlier. + int64 cumulative_total_size = 0; tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -357,14 +385,17 @@ StatusOr> RunDFSMemoryScheduler( continue; } extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; - total_sizes[hlo] = SumLogicalBufferSizes( + int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + total_sizes[hlo] = logical_buffer_size; + cumulative_total_size += logical_buffer_size; tensorflow::gtl::FlatSet unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -392,32 +423,17 @@ StatusOr> RunDFSMemoryScheduler( return sequence; } -StatusOr MinimumMemoryForComputation( +StatusOr> ListMemoryScheduler( const HloComputation& computation, - const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; + return ListScheduler::Run(computation, points_to_analysis, size_function); } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm == SchedulerAlgorithm::kListSchedule) { - return ListScheduler::Run(computation, points_to_analysis, size_function); - } - if (algorithm == SchedulerAlgorithm::kDfsSchedule) { - return RunDFSMemoryScheduler(computation, points_to_analysis, - size_function); - } - + const LogicalBuffer::SizeFunction& size_function) { // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -427,7 +443,7 @@ StatusOr> CreateMemoryMinimizingSequence( // within the caller's context. But it's good enough for now. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, @@ -436,7 +452,7 @@ StatusOr> CreateMemoryMinimizingSequence( TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + DFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, @@ -454,12 +470,10 @@ StatusOr> CreateMemoryMinimizingSequence( } } -} // namespace - StatusOr CreateMemoryMinimizingSequence(const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -475,7 +489,7 @@ CreateMemoryMinimizingSequence(const HloModule& module, StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 1d1eb1e064f75c2220b39e84b010e720a0c37880..068e68383deb170ded1c9b09a8b7ceb8c4c0ab4b 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -33,28 +34,48 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); -enum class SchedulerAlgorithm { - kListSchedule, - kDfsSchedule, +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. +typedef std::function>( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&)> + MemorySchedulerAlgorithm; - // Selects the available scheduler algorithm that had the minimum memory in - // the resulting sequence (a la MinimumMemoryForSequence). - kAuto, -}; +// List scheduler +StatusOr> ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// DFS-order scheduler +StatusOr> DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// The default scheduling algorithm. Runs both the list scheduler +// and the DFS scheduler, and chooses whichever returns a lower min-memory, +// not accounting for fragmentation. +StatusOr> DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); + const MemorySchedulerAlgorithm& algorithm = {}); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 7fb338e7042ce19ac9647e23719e738f3ef42c7c..74544c4a67a819d341056aba4cf6b321a5a86c0a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -89,5 +90,105 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); } +class HloSchedulingTest : public HloTestBase {}; + +TEST_F(HloSchedulingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. + const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); +} + +TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { + const char* module_str = R"( +HloModule test_aliasing_module + +ENTRY root { + param = s32[1000] parameter(0) + p0 = s32[1000] copy(param) + p1 = s32[1000] copy(param) + t = (s32[1000], s32[1000]) tuple(p0, p1) + a = s32[1000] get-tuple-element(t), index=0 + b = s32[1000] get-tuple-element(t), index=1 + c = s32[1000] add(a, b) + d = s32[1000] add(c, b) + e = s32[1000] add(c, c) + f = s32[1000] add(e, e) + ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : + sequence.at(module->entry_computation())) { + instructions_by_name[instruction->name()] = instruction; + } + + // The first instruction should be the parameter and the last the root. + EXPECT_EQ(instructions_by_name.at("param"), + sequence.at(module->entry_computation()).front()); + EXPECT_EQ(instructions_by_name.at("result"), + sequence.at(module->entry_computation()).back()); + + // Instructions "d" and "e" will both be schedulable at the same time, but + // instruction "d" allows us to free the buffer of "p1", so the list scheduler + // should prefer it. + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), + instructions_by_name.at("e"))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index afe79c9f17befdcb2812c0a08b205f21b0715b19..e8e45f1ee968992901988e8b85d4e9ae28f2abe9 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -20,6 +20,7 @@ limitations under the License. namespace xla { +using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrCat; HloSharding HloSharding::AssignDevice(int64 device_id) { @@ -57,8 +58,9 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", - "devices=", VectorString(tile_assignment_), "}"); + return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", + Join(tile_assignment_.dimensions(), ","), "]", + Join(tile_assignment_, ","), "}"); } } @@ -348,4 +350,35 @@ OpSharding HloSharding::ToProto() const { return result; } +HloSharding HloSharding::TransformShardedTileShape( + const Shape& new_shape, + const std::function& transform) const { + CHECK(!IsTuple()); + if (IsTileMaximal()) { + return *this; + } + CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); + Shape new_tile_shape; + new_tile_shape.set_element_type(tile_shape().element_type()); + for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { + int64 dim; + if (tile_assignment().dim(i) == 1) { + dim = new_shape.dimensions(i); + } else if (transform) { + dim = transform(i, tile_shape().dimensions(i)); + } else { + dim = tile_shape().dimensions(i); + } + new_tile_shape.add_dimensions(dim); + } + TF_CHECK_OK( + LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); + return HloSharding::Tile(new_tile_shape, tile_assignment()); +} + +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { + out << sharding.ToString(); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 7263198385cf0c84b1dac1e15177dcac99adaafb..18d406f3700da6dfdfcd16fb76bf9c1d2bc63141 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -173,7 +173,7 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && + ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -207,6 +207,26 @@ class HloSharding { // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } + // Returns the flattened list of all the leaf shardings in a tuple shape, by + // pre-order walk (ShapeTree iterator order). + // REQUIRES: IsTuple(). + const std::vector& tuple_elements() const { + return tuple_elements_; + } + + // Return a new sharding that can apply to the given new shape. + // If this sharding is tile-maximal, the returned sharding will be the same as + // this sharding. If this sharding is not tile-maximal, the returned + // sharding's tile size will differ: + // - Non-sharded dimensions will be adapted to be the same as `new_shape`; + // tile_dimension(i) = new_shape.dimensions(i); + // - Sharded dimensions will be kept the same unless `transform` is supplied + // in which case tile_dimension(i) = transform(i, tile_dimension(i)); + // REQUIRES: !IsTuple(). + HloSharding TransformShardedTileShape( + const Shape& new_shape, + const std::function& transform = nullptr) const; + private: HloSharding() : replicated_(true), @@ -249,6 +269,8 @@ class HloSharding { std::vector tuple_elements_; }; +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 0c7487b3ac77ff181d44dd55ebcf2608feaf02ea..69ea4233e45c2e59c8d1541a0517a007f4bbf42f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -269,5 +269,57 @@ TEST_F(HloShardingTest, Hash) { } } +TEST_F(HloShardingTest, TransformShardedTileShapeTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding result = sharding.TransformShardedTileShape( + ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), + [](int dim, int value) { return dim * 111; }); + HloSharding expected = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), + Array4D({{{{0, 1}, {2, 3}}}})); + EXPECT_EQ(result, expected); +} + +TEST_F(HloShardingTest, ToStringReplicatedTest) { + HloSharding sharding = HloSharding::Replicate(); + EXPECT_EQ(sharding.ToString(), "{replicated}"); +} + +TEST_F(HloShardingTest, ToStringAssignDeviceTest) { + HloSharding sharding = HloSharding::AssignDevice(7); + EXPECT_EQ(sharding.ToString(), "{maximal device=7}"); +} + +TEST_F(HloShardingTest, ToStringTiledTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), + Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); +} + +TEST_F(HloShardingTest, ToStringTupleTest) { + HloSharding sharding = HloSharding::Tuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), + ShapeUtil::MakeShape(U32, {7, 25}), + ShapeUtil::MakeShape(S32, {9, 11})}), + {HloSharding::Replicate(), + HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), + Array2D({{3, 5}})), + HloSharding::AssignDevice(3)}); + EXPECT_EQ(sharding.ToString(), + "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); +} + +TEST_F(HloShardingTest, OstreamTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + std::ostringstream oss; + oss << sharding; + EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index b1fd068115e1d104a11d880675ef84e07d6d5602..8c875698eb1992719d504d272ca338b05b60e36b 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -762,11 +762,14 @@ StatusOr HloVerifier::Run(HloModule* module) { } else if (instruction->opcode() == HloOpcode::kBroadcast) { // If you see this failure then someone has confused the difference // between the HLO broadcast op, and the UserComputation broadcast - // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO has invalid number of dimensions."; + << "Broadcast HLO (" << instruction->ToShortString() + << ") has invalid number of dimensions: " + << instruction->dimensions().size() + << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { auto* while_cond = instruction->while_condition(); auto* while_body = instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f494748e17fc2d0de74dec67f7414d4791f76a07..d69ad80bdb4d2eab2d34228be026d7bc0b76efc0 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -302,7 +302,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { // Consider each operand of this instruction for fusion into this // instruction. We want to consider the operands in a particular order to - // avoid created duplicate instruction clones in the fusion instruction. + // avoid creating duplicate instruction clones in the fusion instruction. // For example, consider the following expression: // // A = ... @@ -377,7 +377,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting it's + // Operand is now dead. Remove from post order by setting its // location to nullptr. post_order[FindOrDie(post_order_index, operand)] = nullptr; post_order_index.erase(operand); diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 0819ab3b90b2360c6b0b2afaa89f322afe566eb3..0db3863f2428cf0c9a66a928d54f774e39a18539 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -63,10 +63,7 @@ cc_library( name = "platform_id", srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], - deps = [ - "@nsync//:nsync_headers", - "//tensorflow/core:stream_executor_headers_lib", - ] + if_static( + deps = ["//tensorflow/core:stream_executor_headers_lib"] + if_static( ["@protobuf_archive//:protobuf"], ["@protobuf_archive//:protobuf_headers"], ), diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 0cb9b5d8107cd8bf468b07d5fe2a22930d9e8b8c..883063d0f075f5b0d79edc01bcd27a7c579272f4 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -93,7 +93,7 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( TF_ASSIGN_OR_RETURN(std::unique_ptr result, transfer_manager->AllocateShapedBuffer( result_literal->shape(), run_options->allocator(), - run_options->device_ordinal())); + executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( executor, *result_literal, *result)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 0668f66051ce96292c3c85bac7e649d89914106c..39f9120e552f014dd2759bff2892157402d9c47a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -192,17 +192,34 @@ LayoutConstraints::LayoutConstraints( } } +PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( + const HloInstruction* instruction) const { + auto it = buffer_sets_cache_.find(instruction); + if (it != buffer_sets_cache_.end()) { + return it->second.get(); + } + auto& buffer_set = + buffer_sets_cache_ + .emplace(instruction, MakeUnique()) + .first->second; + const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); + points_to_set.ForEachElement( + [&buffer_set](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + buffer_set->insert(buffers.begin(), buffers.end()); + }); + return buffer_set.get(); +} + bool LayoutConstraints::OperandBufferForwarded( const HloInstruction* instruction, int64 operand_no) const { // The operand is potentially forwarded if the intersection of points-to sets // of the operand and the instruction is non-empty. - auto output_buffers = - points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet(); - auto operand_buffers = - points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) - .CreateFlattenedSet(); - for (const LogicalBuffer* output_buffer : output_buffers) { - if (operand_buffers.count(output_buffer) > 0) { + PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); + PointsToSet::BufferSet* operand_buffers = + GetBufferSet(instruction->operand(operand_no)); + for (const LogicalBuffer* output_buffer : *output_buffers) { + if (operand_buffers->count(output_buffer) > 0) { return true; } } @@ -1544,6 +1561,13 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // infeeds. Clearing the layouts here avoids hiding potential bugs in the // layout assignment pass that may accidently use the existing layout. for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } if (instruction->opcode() != HloOpcode::kInfeed) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 29018584487cabfd740d7914625c2a50f552d6ff..680f88048a1f0cd5ede7991640003ef407d4facf 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -199,6 +200,11 @@ class LayoutConstraints { string ToString() const; private: + // Find a bufferset in the bufferset cache. This is useful since we can + // currently create the flattened buffer set for the same instruction many + // times, which is often slow. + PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; + // The set of BufferLayoutConstraints applied to the computation. std::unordered_map buffer_constraints_; @@ -221,6 +227,10 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; + mutable tensorflow::gtl::FlatMap> + buffer_sets_cache_; + HloComputation* computation_; }; @@ -393,7 +403,6 @@ class LayoutAssignment : public HloPassInterface { Status CheckLayouts(HloModule* module); ComputationLayout* entry_computation_layout_; - ChannelLayoutConstraints* channel_layout_constraints_; protected: // Map containing the layouts of all computations assigned so @@ -401,6 +410,7 @@ class LayoutAssignment : public HloPassInterface { // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + ChannelLayoutConstraints* channel_layout_constraints_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 88e5caaf478bc99ecf93ab00ddba4637397b9d78..4b1c9bad41de8030cf14bc6d1c0db21b9c56c3bf 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -590,6 +590,85 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { transpose->shape(), {2, 3, 0, 1})); } +// TransposeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = builder.AddInstruction( + HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1})); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(), + hlo->shape(), hlo->dimensions()), + "LayoutUtil::HasLayout"); +} + +// ReshapeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = + builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param)); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH( + ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()), + "LayoutUtil::HasLayout"); +} + +// Check that the computation below doesn't crash the compiler. +// +// Within a fusion computation, only the parameters and result get assigned a +// layout. When we run the algebraic simplifier on this computation post layout +// assignment, it should not call TransposeIsBitcast on the `transpose` node +// inside the fusion computation as TransposeIsBitcast checks both input_shape +// and output_shape have layouts. +TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { + const char* module_str = R"( + HloModule test_module + + fused_computation { + param_1 = f32[2,2,2]{2,1,0} parameter(1) + transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1} + reduce_1 = f32[] parameter(0) + broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={} + ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1) + } + + ENTRY entry_computation { + fusion.1 = f32[2,2,2]{2,1,0} parameter(1) + reduce.1 = f32[] parameter(0) + fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation + ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2) + } + )"; + + auto module = tools::Parse(module_str).ValueOrDie(); + + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + EXPECT_EQ( + ::tensorflow::Status::OK(), + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); +} + // A GTE inside of a fusion node inherits the layout of its operand (which // should, if we keep following operands, eventually be a parameter). TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { @@ -717,5 +796,26 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } +TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant0->shape(), HloOpcode::kBitcast, constant0)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + LayoutAssignment layout_assignment(&computation_layout); + Status error_status = layout_assignment.Run(module.get()).status(); + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT( + error_status.error_message(), + ::testing::HasSubstr( + "Unexpected bitcast operation seen during layout assignment")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 6384c7f46f5ebbedaeda232b40095611a5d738a4..3312a888443233139841ce7a5e3173f907605e1d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,18 +29,13 @@ limitations under the License. namespace xla { namespace llvm_ir { -IrArray::Index::Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) - : multidim_(ShapeUtil::Rank(shape)), - linear_(linear), - layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - CHECK(LayoutUtil::HasLayout(shape)) - << "Shape " << ShapeUtil::HumanStringWithLayout(shape) - << " should have a layout."; +static void Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) { int64 divisor = 1; - for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) { - int64 dimension = layout_.minor_to_major(i); + const Layout& layout = shape.layout(); + for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { + int64 dimension = layout.minor_to_major(i); int64 size_of_current_dimension = shape.dimensions(dimension); // If i is not the last dimension, compute @@ -54,16 +49,28 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); - if (i < layout_.minor_to_major_size() - 1) { - multidim_[dimension] = ir_builder->CreateURem( + if (i < layout.minor_to_major_size() - 1) { + (*multidim)[dimension] = ir_builder->CreateURem( quot, ir_builder->getInt64(size_of_current_dimension)); } else { - multidim_[dimension] = quot; + (*multidim)[dimension] = quot; } divisor *= size_of_current_dimension; } } +IrArray::Index::Index(llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) + : multidim_(ShapeUtil::Rank(shape)), + linear_(linear), + layout_(shape.layout()), + dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK(LayoutUtil::HasLayout(shape)) + << "Shape " << ShapeUtil::HumanStringWithLayout(shape) + << " should have a layout."; + Delinearize(&multidim_, linear, shape, ir_builder); +} + IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), @@ -83,7 +90,6 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, dims_(shape.dimensions().begin(), shape.dimensions().end()) { CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); - linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder); } IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) @@ -106,16 +112,13 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) } } -// Returns whether given linear index valid on given shape. +// Returns whether the given linear index is valid on the given shape. bool IrArray::Index::LinearValidOnShape(const Shape& a) const { - auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_); + auto b = ShapeUtil::MakeShape(a.element_type(), dims_); *b.mutable_layout() = layout_; return linear_ != nullptr && - ContainersEqual( - ShapeUtil::StripDegenerateDimensions(a).dimensions(), - ShapeUtil::StripDegenerateDimensions(b).dimensions()) && - LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(), - ShapeUtil::StripDegenerateDimensions(b).layout()); + ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) && + ShapeUtil::ReshapeIsBitcast(a, b); } IrArray::Index IrArray::Index::SourceIndexOfReshape( @@ -160,7 +163,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } } - if (linear() != nullptr && + if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape) && ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } @@ -195,13 +199,111 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose( llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); - if (linear() != nullptr && + + if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && + LayoutUtil::HasLayout(shape) && ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { return Index(operand_multidim_index, linear(), operand_shape); } + return Index(operand_multidim_index); } +IrArray::Index IrArray::Index::SourceIndexOfBitcast( + const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const { + CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); + // In case the bitcast is just a reshape, we can use SourceIndexOfReshape() + // instead. This will reuse linear() if possible, so we don't have to build a + // new 'linear_index'. + if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) { + return SourceIndexOfReshape(shape, operand_shape, builder); + } + + // First linearize the index coming from the output of the bitcast. We want + // the physical index of the element in the buffer. This is like Linearize, + // but takes the layout into account. + int64 scale = 1; + llvm::Value* linear_index = builder->getInt64(0); + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { + linear_index = builder->CreateAdd( + linear_index, + builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + /*HasNUW=*/true, /*HasNSW=*/true), + "", /*HasNUW=*/true, /*HasNSW=*/true); + scale *= shape.dimensions(dimension); + } + + // Now delinearize it for the input of the bitcast. + std::vector multi_index(operand_shape.dimensions_size()); + Delinearize(&multi_index, linear_index, operand_shape, builder); + + return Index(multi_index, linear_index, operand_shape); +} + +IrArray::Index IrArray::Index::SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const { + int64 rank = ShapeUtil::Rank(operand_shape); + std::vector source_index(rank); + for (int64 i = 0; i < rank; ++i) { + source_index[i] = multidim_[dimension_mapping[i]]; + } + if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || + !LayoutUtil::HasLayout(shape)) { + return Index(source_index); + } + // High-level idea: we can reuse the linear index if the broadcasted + // dimensions are contiguous, and this part of the operation is a bitcast. + // The other dimensions can be masked out with a div and a mod operation. + std::vector logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(shape.layout()); + int64 output_rank = ShapeUtil::Rank(shape); + // The minimum physical dimension that is broadcasted. + int64 min_broadcasted_dimension = output_rank; + // The maximum physical dimension that is broadcasted. + int64 max_broadcasted_dimension = -1; + for (int64 i = 0; i < rank; ++i) { + int64 physical_dim = logical_to_physical[dimension_mapping[i]]; + min_broadcasted_dimension = + std::min(min_broadcasted_dimension, physical_dim); + max_broadcasted_dimension = + std::max(max_broadcasted_dimension, physical_dim); + } + bool contiguous_broadcast_dimensions = + max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; + if (!contiguous_broadcast_dimensions) { + return Index(source_index); + } + // Check if the mapped dimensions are a bitcast. + std::vector operand_logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(operand_shape.layout()); + for (int64 i = 0; i < rank; ++i) { + if (operand_logical_to_physical[i] != + logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { + return Index(source_index); + } + } + llvm::Value* linear = linear_; + int64 divisor = 1; + for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) { + divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + if (divisor > 1) { + linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + } + if (min_broadcasted_dimension > 0) { + int64 mod = 1; + for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension; + ++i) { + mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + linear = builder->CreateURem(linear, builder->getInt64(mod)); + } + return Index(source_index, linear, operand_shape); +} + llvm::Value* IrArray::Index::Linearize( tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 387d4629125cbb791840e943013188d14159908a..06cfb2a36c56c5fdece7140e469379f8394111fa 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -76,8 +76,7 @@ class IrArray { llvm::IRBuilder<>* ir_builder); // Constructs an index from the given multi-dimensional index and the shape - // that it indexes into. Also, computes the linear index according to - // "shape". + // that it indexes into. // // Precondition: "shape" has a layout. Index(tensorflow::gtl::ArraySlice multidim, @@ -134,6 +133,18 @@ class IrArray { tensorflow::gtl::ArraySlice dimension_mapping, llvm::IRBuilder<>* builder) const; + // Given that "this" is the target index of a bitcast from `operand_shape` + // to `shape`, returns the source index. + Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const; + + // Given that "this" is the target index of a broadcast from `operand_shape` + // to `shape` with the given dimension mapping, returns the source index. + Index SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const; + // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 5c1866311d1ae1e0c33ab061ee326d86d647a908..2a282f3be79f847a6569416794d1a2a3fcd69148 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -106,8 +106,10 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value); return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); } else { - return EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder); + auto cmp_ge = ir_builder->CreateFCmpOGE(lhs_value, rhs_value); + auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = ir_builder->CreateOr(cmp_ge, lhs_is_nan); + return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); } } @@ -117,8 +119,10 @@ llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value); return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); } else { - return EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder); + auto cmp_le = ir_builder->CreateFCmpOLE(lhs_value, rhs_value); + auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = ir_builder->CreateOr(cmp_le, lhs_is_nan); + return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); } } diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 07f989d4faea199e812e54d2ae74d3ff9e7fa19a..499f280211aacd00e79b3ca0ddb3413f933b02da 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} +namespace { + +// Retrieves the parameter metadata for the given computation and parameter +// number. +// +// If the parameter number is invalid for this computation, nullopt is +// returned. When the return value has_value(), nullptr will never be +// the held value. +tensorflow::gtl::optional ParameterMetadata( + const XlaComputation& computation, int parameter_number) { + for (const HloComputationProto& comp : computation.proto().computations()) { + if (comp.id() == computation.proto().entry_computation_id()) { + for (const HloInstructionProto& instr : comp.instructions()) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == parameter_number) { + if (!instr.has_metadata()) { + return tensorflow::gtl::nullopt; + } + return &instr.metadata(); + } + } + } + } + return tensorflow::gtl::nullopt; +} + +ExecutionOptions CreateExecutionOptions( + const ExecutableBuildOptions& build_options, + const ProgramShape* program_shape) { + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (build_options.hlo_profile().has_value()) { + execution_options.mutable_debug_options()->set_xla_hlo_profile( + *build_options.hlo_profile()); + } + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.dump_optimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_optimized_hlo_proto_to( + build_options.dump_optimized_hlo_proto_to().value()); + } + if (build_options.dump_per_pass_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_per_pass_hlo_proto_to( + build_options.dump_per_pass_hlo_proto_to().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); + } else { + *execution_options.mutable_shape_with_output_layout() = + program_shape->result(); + LayoutUtil::SetToDefaultLayout( + execution_options.mutable_shape_with_output_layout()); + } + return execution_options; +} + +} // namespace + StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -118,30 +180,78 @@ StatusOr> LocalService::CompileExecutable( *build_options.result_layout(), program_shape->result())); } - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, program_shape.get()); + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(*program_shape, argument_layouts, + &execution_options, user_computation)); + + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(build_options.device_ordinal())); + + return BuildExecutable(versioned_handle, std::move(module_config), + execute_backend_.get(), executor, + build_options.device_allocator()); +} + +StatusOr> LocalService::CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& build_options) { + const HloModuleProto& proto = computation.proto(); + TF_RET_CHECK(proto.has_program_shape()); + const ProgramShape& program_shape = proto.program_shape(); + + // Validate incoming layouts. + if (argument_layouts.size() != program_shape.parameters_size()) { + return InvalidArgument( + "Invalid number of arguments for computation: expected %d, got %zu.", + program_shape.parameters_size(), argument_layouts.size()); + } + + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_shape = *argument_layouts[i]; + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { + tensorflow::gtl::optional metadata = + ParameterMetadata(computation, /*parameter_number=*/i); + auto metadata_string = [&metadata]() -> string { + if (!metadata.has_value()) { + return ""; + } + CHECK(metadata.value() != nullptr); + const OpMetadata& m = *metadata.value(); + if (!m.source_file().empty()) { + return tensorflow::strings::Printf( + " (%s:%d)", m.source_file().c_str(), m.source_line()); + } + return ""; + }; + return InvalidArgument( + "Invalid argument shape for argument %d%s, expected %s, got %s.", i, + metadata_string().c_str(), + ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(argument_shape).c_str()); + } } if (build_options.result_layout() != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); - } else { - *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape.result())); } + + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, &program_shape); + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - *user_computation)); + CreateModuleConfig(program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); - return BuildExecutable(versioned_handle, std::move(module_config), + return BuildExecutable(proto, std::move(module_config), execute_backend_.get(), executor, build_options.device_allocator()); } diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 15e120685e1be9190d49fdaf5ed6706bdf991a6c..06567cabd6eb28aae53881613cd6beb78e25e222 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -50,6 +51,18 @@ class LocalService : public Service { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); + // Builds an Executable with the given XlaComputation, argument layouts and + // options. If result_layout is non-null, then the executable is compiled to + // produce a result of the given layout. If device_allocator is non-null, + // then the compiler may use it to allocate temp space on the device. The + // compiler is responsible for freeing any memory it allocates this way. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& build_options); + // Returns the device ordinal that corresponds to the given replica number. // // This returns an error if there is not a one-to-one correspondence of diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index e62bafc50b0e1270702621c9ea7b2ee43e001fe0..f15117f45c689f2d717fbfe6191b510586449bc4 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,6 +53,14 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } +// Returns true if `a` is a broadcast instruction to target shape `shape` and +// its operand is a scalar. +bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) { + return a->opcode() == HloOpcode::kBroadcast && + ShapeUtil::SameDimensions(a->shape(), shape) && + ShapeUtil::IsScalar(a->operand(0)->shape()); +} + // Returns true iff `instruction` can change its shape simply by adjusting // metadata. bool CanTriviallyChangeShape(const HloInstruction* instruction) { @@ -88,6 +96,7 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { instruction->user_count() == 1) { return true; } + return false; } @@ -148,6 +157,8 @@ bool AllOperandsHaveEasyShapeChanges( // or // 2. Are one of kConstant, kRng, and scalars that can change shape // trivially, + // or + // 3. Are broadcast with a scalar operand. for (const HloInstruction* operand : instruction->operands()) { if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { VLOG(5) << "Operand shape differs from output shape; may be " @@ -158,6 +169,12 @@ bool AllOperandsHaveEasyShapeChanges( return false; } + // Skip the rest checks if the current operand is first_reshape_operand + // itself. + if (first_reshape_operand == operand) { + continue; + } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " << first_reshape_operand->ToString(print_no_metadata) @@ -171,6 +188,12 @@ bool AllOperandsHaveEasyShapeChanges( continue; } + if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) { + VLOG(5) << "Broadcast scalar to shape: " + << operand->ToString(print_no_metadata); + continue; + } + // TODO(someone): Look into supporting general ops for the operands as // well. VLOG(5) << "Operand is neither equalivant to the first Reshape operand" @@ -222,6 +245,12 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } + case HloOpcode::kBroadcast: + CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape())); + VLOG(5) << "Changing broadcast"; + return computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index aac8638a54f744f0c230ec6c5ca071c1daf45ab2..4e0a0a8832379402edfc231ea84221448d70bac2 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -560,5 +560,25 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); } +TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { + const string hlo_string = R"( + HloModule TransposeMulInversedTransposeModule + ENTRY TransposeMulInversedTranspose { + src0 = f32[1,20,8,32]{3,2,1,0} parameter(0) + transpose0 = f32[1,8,20,32]{3,2,1,0} transpose(src0), dimensions={0,2,1,3} + src1 = f32[] parameter(1) + broadcast0 = f32[1,8,20,32]{3,2,1,0} broadcast(src1), dimensions={} + ROOT multiply0 = f32[1,8,20,32]{3,2,1,0} multiply(transpose0, broadcast0) + } + )"; + + ParseAndVerifyModule(hlo_string.c_str()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Transpose(op::Multiply())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 43d0f605985819afdaf2db2309a0bfb86f230fe3..1d379f0d03fa509173ffaf7a69f21da62e9b44e0 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -232,10 +232,14 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr>> +Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - int device_ordinal) { - std::vector shaped_buffers; + tensorflow::gtl::ArraySlice + stream_executors) { + CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); + std::vector> replicated_arguments; + replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); if (!buffer_status.ok()) { @@ -243,29 +247,32 @@ StatusOr> Service::ResolveAndValidateArguments( StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); - - // Verify allocation is same platform and device as the execution. - if (shaped_buffer->platform() != execute_backend_->platform() || - shaped_buffer->device_ordinal() != device_ordinal) { - return InvalidArgument( - "argument %lu is on device %s:%d but computation will be executed " - "on device %s", - i, shaped_buffer->platform()->Name().c_str(), - shaped_buffer->device_ordinal(), - execute_backend_->device_name(device_ordinal).c_str()); + auto replicated_buffers = buffer_status.ValueOrDie(); + CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); + for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { + const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; + int replica_device_ordinal = stream_executors[replica]->device_ordinal(); + // Verify allocation is same platform and device as the execution. + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != replica_device_ordinal) { + return InvalidArgument( + "argument %lu is on device %s:%d but computation will be executed " + "on device %s", + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(replica_device_ordinal).c_str()); + } + replicated_arguments[replica].push_back(shaped_buffer); } - - shaped_buffers.push_back(shaped_buffer); } - return shaped_buffers; + return replicated_arguments; } StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -279,8 +286,15 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { + if (user_computation == nullptr) { + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", + i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + } return InvalidParameterArgument( - *user_computation.ParameterMetadata(i).value(), + *user_computation->ParameterMetadata(i).value(), "Argument does not match shape of computation parameter %d: want %s, " "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), @@ -307,8 +321,6 @@ StatusOr> Service::CreateModuleConfig( if (execution_options != nullptr) { config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); - config->enable_hlo_profiling( - execution_options->debug_options().xla_hlo_profile()); } else { config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); } @@ -325,7 +337,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); @@ -490,7 +502,8 @@ StatusOr> Service::BuildAndCacheExecutable( StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice> arguments, + tensorflow::gtl::ArraySlice>> + arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { @@ -513,6 +526,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + std::vector> result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(replicas[replica])); @@ -545,23 +560,20 @@ Service::ExecuteParallelAndRegisterResult( backend->StreamBorrower()); // Asynchronously launch the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr result, - executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + executables[i]->ExecuteAsyncOnStream( + &run_options, arguments[i][replica])); if (replica == 0 && profile != nullptr) { streams.back()->ThenStopTimer(timers.back().get()); } - // All replicas share the same device address for the result allocation, - // so only one of the replicas need to register the result handle. - if (replica == 0) { - TF_ASSIGN_OR_RETURN( - GlobalDataHandle handle, - allocation_tracker_.Register(std::move(result), result_tags[i])); - result_handles.push_back(handle); - } + result_buffers.emplace_back(std::move(result)); } + TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, + allocation_tracker_.RegisterReplicatedBuffers( + std::move(result_buffers), result_tags[i])); + result_handles.push_back(handle); } // Wait for all executions to complete. @@ -627,9 +639,9 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - const string& result_tag, ExecutionProfile* profile) { + const tensorflow::gtl::ArraySlice> + arguments, + Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector::SmartPtr> streams; @@ -662,21 +674,26 @@ StatusOr Service::ExecuteAndRegisterResult( backend->inter_op_thread_pool()); } - std::unique_ptr result; if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN(result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); - } else { - // TODO(b/69985541): Support profiling also on this path. - std::vector> - repeated_arguments(options_.number_of_replicas(), arguments); + TF_ASSIGN_OR_RETURN( + auto result, executable->ExecuteOnStreamWrapper(&run_options[0], + profile, arguments[0])); + return allocation_tracker_.Register(std::move(result), result_tag); + } - TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( - run_options, repeated_arguments)); - TF_RET_CHECK(!results.empty()); - result = std::move(results[0]); + // TODO(b/69985541): Support profiling also on this path. + + std::vector> + replicated_arguments; + for (const auto& arg : arguments) { + replicated_arguments.emplace_back(arg); } - return allocation_tracker_.Register(std::move(result), result_tag); + + TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( + run_options, replicated_arguments)); + TF_RET_CHECK(!results.empty()); + return allocation_tracker_.RegisterReplicatedBuffers(std::move(results), + result_tag); } tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, @@ -690,7 +707,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - std::vector> all_arguments; + std::vector>> all_arguments; std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; @@ -718,6 +735,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, return FailedPrecondition( "device handles must be given to execute parallel computations"); } + if (arg->requests_size() > 1 && + execution_options.device_handles_size() > 1) { + return InvalidArgument( + "Parallel requests with multiple device handles is not supported. " + "Found %d parallel requests, with request %lld containing %d device " + "handles.", + arg->requests_size(), i, execution_options.device_handles_size()); + } std::vector executors; for (const auto& device_handle : execution_options.device_handles()) { TF_ASSIGN_OR_RETURN(auto replicas, @@ -747,22 +772,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // In the case of partitioned computations, assume all arguments go on the // zeroth core. TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(request.arguments(), - executors[0]->device_ordinal())); + auto replicas, + Replicas(*execute_backend_, execution_options.device_handles(0))); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(request.arguments(), replicas)); // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. + // the program and the argument allocations. Here, we care only about the + // shapes of the arguments, so, it is sufficient to use the arguments of + // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, - request.execution_options(), *user_computation)); + CreateModuleConfig(*program_shape, replicated_arguments.front(), + request.execution_options(), user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {}); + all_arguments.push_back(replicated_arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); versioned_handles.push_back(versioned_handle); module_configs.push_back(std::move(module_config)); computation_names.insert(computation_names.end(), executors.size(), @@ -832,6 +861,33 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + // The "result device" selection is a bit hacky, but better than assuming it + // is device 0. We have b/76035356 for restructuring the client API to clean + // up the current asymmetries and support more functionalities. + for (int64 i = 0; i < parallel_result.responses_size(); ++i) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, + allocation_tracker_.ResolveForReplica( + parallel_result.responses(i).output(), 0)); + const Shape& shape = buffer->on_host_shape(); + if (!ShapeUtil::IsEmptyTuple(shape)) { + *result = parallel_result.responses(i); + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return Status::OK(); + } + } + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + VLOG(1) << "Defaulting to device 0 result"; + return Status::OK(); +} + tensorflow::Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); @@ -848,28 +904,25 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, // If we received multiple device handles, we must partition the module. if (arg->execution_options().device_handles_size() > 1) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - return Status::OK(); + return ExecuteOneToN(arg, result); } TF_ASSIGN_OR_RETURN( std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(arg->arguments(), - execute_backend_->default_device_ordinal())); + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + // Since we care only about the shapes of the arguments, it is sufficient to + // use the arguments of replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options(), - *user_computation)); + CreateModuleConfig(*program_shape, replicated_arguments.front(), + arg->execution_options(), user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -885,20 +938,21 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, executable->session_module()->set_execution_platform( execute_backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments( - arguments, execute_backend_->default_stream_executor(), + replicated_arguments.front(), + execute_backend_->default_stream_executor(), execute_backend_->transfer_manager(), executable->session_module())); } TF_ASSIGN_OR_RETURN( *result->mutable_output(), ExecuteAndRegisterResult( - executable.get(), arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), + executable.get(), replicated_arguments, execute_backend_.get(), "result of " + user_computation->name(), result->mutable_profile())); if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, - allocation_tracker_.Resolve(result->output())); + TF_ASSIGN_OR_RETURN( + const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(result->output(), 0)); TF_RETURN_IF_ERROR(RecordResult( *result_buffer, execute_backend_->default_stream_executor(), execute_backend_->transfer_manager(), executable->session_module())); @@ -909,6 +963,68 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return tensorflow::Status::OK(); } +StatusOr> Service::BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf( + "BuildExecutable on service %p with serialized module proto: %s", this, + module_proto.name().c_str()); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(module_proto, *module_config)); + + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); + + return std::move(executable); +} + +tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + VLOG(1) << "running execute-graph request"; + + if (!arg->has_computation()) { + return InvalidArgument("computations may not be empty"); + } + + // TODO(b/74197823): Handle partitioning. + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(arg->computation().program_shape(), + replicated_arguments.front(), + arg->execution_options())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + BuildExecutable(arg->computation(), std::move(module_config), + execute_backend_.get(), + execute_backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + "result of " + arg->computation().name(), result->mutable_profile())); + + VLOG(1) << "successfully completed 'execute-graph' request"; + return tensorflow::Status::OK(); +} + tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) { VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); @@ -926,15 +1042,17 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(arg->arguments(), - execute_backend_->default_device_ordinal())); + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options(), - *user_computation)); + CreateModuleConfig(*program_shape, replicated_arguments.front(), + arg->execution_options(), user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -947,21 +1065,17 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, versioned_handle, std::move(module_config), execute_backend_.get(), execute_backend_->default_stream_executor(), &profile)); - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - // Set up streams. std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, execute_backend_->BorrowStream(executor)); streams.push_back(std::move(stream)); } - std::unique_ptr result_buffer; - for (const Pool::SmartPtr& stream : streams) { + std::vector> result_buffers; + for (size_t i = 0; i < streams.size(); ++i) { + const auto& stream = streams[i]; ExecutableRunOptions options; options.set_stream(stream.get()); options.set_allocator(execute_backend_->memory_allocator()); @@ -972,20 +1086,17 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ServiceExecutableRunOptions service_options( options, execute_backend_->StreamBorrower()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr this_result_buffer, - executable->ExecuteAsyncOnStream(&service_options, arguments)); + TF_ASSIGN_OR_RETURN(std::unique_ptr this_result_buffer, + executable->ExecuteAsyncOnStream( + &service_options, replicated_arguments[i])); - // Take the first result. - if (result_buffer == nullptr) { - result_buffer = std::move(this_result_buffer); - } + result_buffers.emplace_back(std::move(this_result_buffer)); } TF_ASSIGN_OR_RETURN( GlobalDataHandle output, - allocation_tracker_.Register(std::move(result_buffer), - "result of " + user_computation->name())); + allocation_tracker_.RegisterReplicatedBuffers( + std::move(result_buffers), "result of " + user_computation->name())); *result->mutable_execution() = execution_tracker_.Register( execute_backend_.get(), std::move(streams), profile, output); @@ -1013,7 +1124,7 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - allocation_tracker_.Resolve(arg->data())); + allocation_tracker_.ResolveForReplica(arg->data(), 0)); const Shape* return_shape; if (arg->has_shape_with_layout()) { @@ -1074,37 +1185,24 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // All memory allocation is done on the first replica. The allocations in all - // other replicas mirror the firsts'. - int master_device_ordinal = replicas[0]->device_ordinal(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr shaped_buffer, - execute_backend_->transfer_manager()->AllocateShapedBuffer( - shape, execute_backend_->memory_allocator(), master_device_ordinal)); - - // Transfer the data to the replicas. + // Allocate memory in each replica and transfer the data to all replicas. + std::vector> replicated_buffers; for (se::StreamExecutor* executor : replicas) { - if (executor->device_ordinal() == master_device_ordinal) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *shaped_buffer)); - } else { - // The replica is not the master. Create an cloned shaped buffer with - // the replica's device ordinal. This is required because - // TransferLiteralToDevice verifies that the device ordinal of the shaped - // buffer matches that of the executor. - std::unique_ptr clone = - CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *clone)); - } + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + execute_backend_->transfer_manager()->AllocateShapedBuffer( + shape, execute_backend_->memory_allocator(), + executor->device_ordinal())); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, *shaped_buffer)); + replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN( - *result->mutable_data(), - allocation_tracker_.Register(std::move(shaped_buffer), - StrCat("TransferToServer literal of shape ", - ShapeUtil::HumanString(shape)))); + TF_ASSIGN_OR_RETURN(*result->mutable_data(), + allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); return tensorflow::Status::OK(); } @@ -1255,7 +1353,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, CreateModuleConfig(program_shape, {}, execution_options, - *user_computation)); + user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( @@ -1287,7 +1385,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, tensorflow::Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.Resolve(arg->data())); + allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 6ce241971156599aaa25aea1b0caac0e1bd5379c..773f0a642dc93899828ef7b2dd4e271fc3d50d05 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -112,6 +112,14 @@ class Service : public ServiceInterface { tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; + // Executes a computation with the provided global data passed as + // immutable arguments. The request contains the whole computation graph. + // Returns global data output and execution timing. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; + // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. @@ -252,7 +260,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); protected: friend class LocalExecutable; @@ -265,11 +273,14 @@ class Service : public ServiceInterface { static StatusOr> CreateComputeConstantBackend(); // Resolves the given argument handles in the allocation tracker and returns - // the corresponding allocations. The function also verifies that each - // allocation matches the execution platform and device ordinal. - StatusOr> ResolveAndValidateArguments( + // the corresponding allocations for every replica. The function also verifies + // that each allocation matches the execution platform and device ordinal of + // the corresponding replica. + StatusOr>> + ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - int device_ordinal); + tensorflow::gtl::ArraySlice + stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -277,7 +288,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); // Builds an Executable for the given parameters. // @@ -290,6 +301,15 @@ class Service : public ServiceInterface { perftools::gputools::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); + // Builds an Executable for the given HLO module proto. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator = nullptr); + // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. StatusOr>> BuildExecutables( @@ -314,16 +334,17 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - const string& result_tag, ExecutionProfile* profile); + const tensorflow::gtl::ArraySlice> + arguments, + Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice> arguments, + tensorflow::gtl::ArraySlice>> + arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, @@ -336,6 +357,12 @@ class Service : public ServiceInterface { const std::function(UserComputation*)>& adder); + // Executes a single computation which has more than one target device. + // The N devices are expected to all return an empty tuple, but one, which + // will be the result of this computation. + tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result); + // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c9692757b27980b10a5ca562223c3d0f6462d820..36456d552d1ed41e192308fec7489a44f8dd5051 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -169,11 +169,11 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s. Got: %s", + return InvalidArgument("Expected non-tuple argument for %s, but got %s.", op_type.ToString().c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s. Got: %s", + return InvalidArgument("Expected non-opaque argument for %s, but got %s.", op_type.ToString().c_str(), ShapeUtil::HumanString(shape).c_str()); } else { @@ -193,8 +193,10 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, const Shape& accumulator_shape = reducer_shape.result(); if (ShapeUtil::Rank(accumulator_shape) != 0) { - return Unimplemented( - "Reduction function currently must have rank-0 result."); + return InvalidArgument( + "Reduction function must have rank 0 (rank %lld reduction function " + "given).", + ShapeUtil::Rank(accumulator_shape)); } // Check that the accumulator can be passed in as the first argument. @@ -235,8 +237,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, reducer_shape.parameters(1))) { return InvalidArgument( - "Reduction function's second parameter shape currently must " - "match the result shape. Got %s vs %s", + "Reduction function's second parameter shape must " + "match the result shape, but got %s vs %s.", ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), ShapeUtil::HumanString(accumulator_shape).c_str()); } @@ -258,29 +260,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { - return InvalidArgument("Window has a non-positive dimension. Window: %s", + return InvalidArgument("Window %s has a non-positive dimension.", window.DebugString().c_str()); } if (dim.stride() <= 0) { - return InvalidArgument("Window has a non-positive stride. Window: %s", + return InvalidArgument("Window %s has a non-positive stride.", window.DebugString().c_str()); } if (!allow_negative_padding && dim.padding_low() < 0) { - return InvalidArgument("Window has a negative low padding. Window: %s", + return InvalidArgument("Window %s has a negative low padding.", window.DebugString().c_str()); } if (!allow_negative_padding && dim.padding_high() < 0) { - return InvalidArgument("Window has a negative high padding. Window: %s", + return InvalidArgument("Window %s has a negative high padding.", window.DebugString().c_str()); } if (dim.base_dilation() < 1) { return InvalidArgument( - "Window has a non-positive base area dilation factor. Window: %s", + "Window %s has a non-positive base area dilation factor.", window.DebugString().c_str()); } if (dim.window_dilation() < 1) { return InvalidArgument( - "Window has a non-positive window dilation factor. Window: %s", + "Window %s has a non-positive window dilation factor.", window.DebugString().c_str()); } @@ -302,12 +304,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const HloInstruction* operand) { + return InferUnaryOpShape(opcode, operand->shape()); +} + +/* static */ StatusOr ShapeInference::InferUnaryOpShape( + HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. if (opcode == HloOpcode::kCopy) { - return operand->shape(); + return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape()); + return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); } /* static */ StatusOr ShapeInference::InferUnaryOpShape( @@ -320,8 +327,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_CEIL: if (!ShapeUtil::ElementIsFloating(arg)) { return InvalidArgument( - "expected element type in shape to be floating for floor/ceil " - "operation; got %s", + "Expected element type in shape to be floating for floor/ceil " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -333,8 +340,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(arg) && !ShapeUtil::ElementIsComplex(arg)) { return InvalidArgument( - "expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s", + "Expected element type in shape to be floating or complex for " + "sin/cos/exp/log/tanh operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -342,8 +349,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_IMAG: if (!ShapeUtil::ElementIsComplex(arg)) { return InvalidArgument( - "expected element type in shape to be complex for real/imag " - "operation; got %s", + "Expected element type in shape to be complex for real/imag " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return ShapeUtil::ChangeElementType(arg, F32); @@ -363,8 +370,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (arg.element_type() != PRED && !primitive_util::IsIntegralType(arg.element_type())) { return InvalidArgument( - "expected pred or an integral element type in argument to not " - "operation; got %s", + "Expected pred or an integral element type in argument to Not " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -372,8 +379,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_IS_FINITE: if (!ShapeUtil::ElementIsFloating(arg)) { return InvalidArgument( - "expected element type in shape to be floating point for IsFinite " - "operation; got %s", + "Expected element type in shape to be floating point for IsFinite " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return ShapeUtil::ChangeElementType(arg, PRED); @@ -389,10 +396,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, tensorflow::gtl::ArraySlice arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { - return InvalidArgument("Concatenate expects at least one argument"); + return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("dimension to concatenate along out of bounds: %lld", + return InvalidArgument("Concatenate dimension out of bounds: %lld.", dimension); } const Shape* arg_shape = nullptr; @@ -408,14 +415,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " - "(%s)", + "(%s).", ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( - "cannot concatenate arrays with different element types: %s vs %s", + "Cannot concatenate arrays with different element types: %s vs %s.", PrimitiveType_Name(arg_shape->element_type()).c_str(), PrimitiveType_Name(shape->element_type()).c_str()); } @@ -428,9 +435,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // concatenating. } return InvalidArgument( - "cannot concatenate arrays that differ in dimensions other than " + "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld", + "the same): %s vs %s in dimension %lld.", ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::HumanString(*shape).c_str(), dimension); } @@ -452,7 +459,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) && !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( - "Unsupported conversion from complex to real type: %s => %s", + "Conversion from complex to real type %s => %s is not implemented.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -461,7 +468,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "cannot convert from or to tuple type; requested conversion: %s => %s", + "Convert does not allow tuples, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -474,24 +481,23 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, auto old_element_type = operand_shape.element_type(); if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { - return Unimplemented( - "Unsupported conversion between real and complex types: %s => %s", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + return InvalidArgument("Conversion from complex to real type %s => %s.", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); } if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "cannot convert from or to tuple type; requested conversion: %s => %s", + "Cannot convert from or to tuple type; requested conversion: %s => %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( - "cannot bitcast types with different bit-widths: %s => %s", + "Cannot bitcast types with different bit-widths: %s => %s.", PrimitiveType_Name(old_element_type).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -504,20 +510,20 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const int mantissa_bits) { if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( - "expected element type in shape to be floating point for " - "ReducePrecision operation; got %s", + "Expected element type in shape to be floating point for " + "ReducePrecision operation; got %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. Having // no exponent bits doesn't produce a sensible number, so we require at // least one. - return InvalidArgument("expected exponent_bits >= 1; got %d", + return InvalidArgument("Expected exponent_bits >= 1; got %d.", exponent_bits); } if (mantissa_bits < 0) { // A number with no mantissa bits is still meaningful, however. - return InvalidArgument("expected non-negative mantissa_bits; got %d", + return InvalidArgument("Expected non-negative mantissa_bits; got %d.", mantissa_bits); } return operand_shape; @@ -528,23 +534,23 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const PaddingConfig& padding_config) { if (ShapeUtil::IsTuple(operand_shape)) { return InvalidArgument( - "pad operation does not support tuple-shape operands"); + "Pad operation does not support tuple-shape operands."); } if (!ShapeUtil::IsScalar(padding_value_shape)) { return InvalidArgument( - "pad operation does not support non-scalar padding values"); + "Pad operation does not support non-scalar padding values."); } if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " - "%s vs %s", + "%s vs %s.", ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( - "the element types of the operands to pad do not match"); + "The element types of the operands to Pad do not match."); } std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { @@ -605,7 +611,7 @@ Status ValidateDotDimensionNumbers( lhs_batch_dimensions) || !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { - return InvalidArgument("A dimension number is out of range in dot: %s", + return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString().c_str()); } @@ -623,7 +629,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { - return InvalidArgument("A dimension number is not unique in dot: %s", + return InvalidArgument("A dimension number is not unique in Dot: %s.", dimension_numbers.DebugString().c_str()); } @@ -641,8 +647,7 @@ Status ValidateDotDimensionNumbers( rhs_non_contracting_non_batch_dims < 0 || rhs_non_contracting_non_batch_dims > 1) { return InvalidArgument( - "batch and contracting dimension number mismatch " - "with rank "); + "Batch and contracting dimension number mismatch with rank."); } // Check that batch dimension numbers are ordered before all others, and @@ -654,7 +659,7 @@ Status ValidateDotDimensionNumbers( !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), rhs_batch_dimensions.begin())) { return InvalidArgument( - "batch dimension numbers must precede non-batch dimensions and be" + "Batch dimension numbers must precede non-batch dimensions and be" "monotonically increasing."); } @@ -671,22 +676,22 @@ Status ValidateDotDimensionNumbers( auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( - "cannot infer shape for dot operation: %s %s", + "Cannot infer shape for dot operation: %s %s.", ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); if (!addendum.empty()) { - message += ": " + addendum; + message += " " + addendum; } return InvalidArgument("%s", message.c_str()); }; // Check if both element types are the same. if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return fail("element types do not match"); + return fail("Element types do not match."); } if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { - return fail("dot only supports rank 1 or above."); + return fail("Dot only supports rank 1 or above."); } // Validate basic properties of dot dimension numbers. @@ -696,7 +701,7 @@ Status ValidateDotDimensionNumbers( if (dimension_numbers.lhs_contracting_dimensions_size() != dimension_numbers.rhs_contracting_dimensions_size() || dimension_numbers.lhs_contracting_dimensions_size() != 1) { - return fail("must specify one contracting dimension for both lhs and rhs."); + return fail("Must specify one contracting dimension for both lhs and rhs."); } // Check that contracting dimension sizes match. @@ -706,13 +711,13 @@ Status ValidateDotDimensionNumbers( dimension_numbers.rhs_contracting_dimensions(0); if (lhs.dimensions(lhs_contracting_dimension) != rhs.dimensions(rhs_contracting_dimension)) { - return fail("contracting dimension sizes do not match."); + return fail("Contracting dimension sizes do not match."); } // Check that number of batch dimensions match. if (dimension_numbers.lhs_batch_dimensions_size() != dimension_numbers.rhs_batch_dimensions_size()) { - return fail("must the same number of batch dimensions for lhs and rhs."); + return fail("Must the same number of batch dimensions for lhs and rhs."); } // Check that batch dimension numbers and sizes match. @@ -721,7 +726,7 @@ Status ValidateDotDimensionNumbers( dimension_numbers.rhs_batch_dimensions(i) || lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("batch dimension numbers and sizes must match for lhs/rhs."); + return fail("Batch dimension numbers and sizes must match for lhs/rhs."); } } @@ -770,10 +775,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); } else { - return InvalidArgument("binary op %s with incompatible shapes: %s and %s", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + return InvalidArgument( + "Binary op %s with incompatible shapes: %s and %s.", + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -788,15 +794,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. - return InvalidArgument("automatic shape inference not supported: %s and %s", + return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape).c_str(), ShapeUtil::HumanString(larger_shape).c_str()); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( - "size of broadcast_dimensions has to match lower-rank operand's " + "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu", + "%zu.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -846,13 +852,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "broadcast dimension number (%lld) cannot be negative", + "Broadcast dimension number (%lld) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "broadcast dimension number (%lld) too large; higher-rank " - "operand has rank %d", + "Broadcast dimension number (%lld) too large; higher-rank " + "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } int64 small_dimension_size = smaller_shape.dimensions(i); @@ -863,7 +869,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i, + "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, small_dimension_size, large_dimension_size, ShapeUtil::HumanString(smaller_shape).c_str(), ShapeUtil::HumanString(larger_shape).c_str()); @@ -872,7 +878,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "broadcast dimensions order is wrong: %lld comes after %lld", + "Broadcast dimensions order is wrong: %lld comes after %lld.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -892,7 +898,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( - "binary op %s with different element types: %s and %s", + "Binary op %s with different element types: %s and %s.", BinaryOperation_Name(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); @@ -904,8 +910,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { return InvalidArgument( - "broadcast dimensions field must either be not set or be the " - "identity on binary operations with operands of the same rank"); + "Broadcast dimensions field must either be not set or be the " + "identity on binary operations with operands of the same rank."); } } @@ -943,6 +949,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( rhs->shape(), /*broadcast_dimensions=*/{}); } +/* static */ StatusOr ShapeInference::InferBinaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, + broadcast_dimensions); +} + /* static */ StatusOr ShapeInference::InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -979,8 +992,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_COMPLEX: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( - "expected element type in shape to be floating for complex compose " - "operation; got %s", + "Expected element type in shape to be floating for complex compose " + "operation; got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, @@ -989,7 +1002,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { - return Unimplemented("complex component type not supported"); + return Unimplemented("Complex component type is not implemented."); } } case BINOP_AND: @@ -997,8 +1010,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( - "expected pred or integral type in argument to and/or operation; " - "got %s", + "Expected pred or integral type in argument to and/or operation; " + "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } return InferElementwiseBinaryOpShape(operation, lhs, rhs, @@ -1016,7 +1029,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } default: return Unimplemented( - "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s", + "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", BinaryOperation_Name(operation).c_str(), lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); } @@ -1025,8 +1038,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, const HloInstruction* ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(), - rhs->shape(), ehs->shape()); + return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape()); +} + +/* static */ StatusOr ShapeInference::InferTernaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { + return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); } /* static */ StatusOr ShapeInference::InferTernaryOpShape( @@ -1041,7 +1058,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case TRIOP_SELECT: return InferSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("unknown operation %s", + return InvalidArgument("Unknown operation %s.", TernaryOperation_Name(operation).c_str()); } } @@ -1072,7 +1089,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return result; } default: - return InvalidArgument("unknown operation %s", + return InvalidArgument("Unknown operation %s.", VariadicOperation_Name(operation).c_str()); } } @@ -1082,7 +1099,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const ProgramShape& to_apply, tensorflow::gtl::ArraySlice dimensions) { if (arg_shapes.empty()) { - return InvalidArgument("Map expects at least one argument"); + return InvalidArgument("Map expects at least one argument."); } // All arguments must have the same shape. @@ -1113,7 +1130,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } return InvalidArgument( "Map operation requires all operands to have the same shape; got: " - "%s", + "%s.", Join(pieces, ", ").c_str()); } @@ -1122,7 +1139,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu", + "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1130,7 +1147,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int i = 0; i < dimensions.size(); ++i) { if (dimensions[i] != i) { return InvalidArgument( - "Map requires monotonically increasing dimension numbers, found: %s ", + "Map requires monotonically increasing dimension numbers; got: %s.", Join(dimensions, ", ").c_str()); } } @@ -1139,7 +1156,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu", + "arity: %d, arguments: %zu.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1147,8 +1164,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& output_shape = to_apply.result(); if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( - "mapped computation's result has to be a scalar; " - "got: %s", + "Mapped computation's result has to be a scalar; got: %s.", ShapeUtil::HumanString(output_shape).c_str()); } @@ -1157,16 +1173,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::IsScalar(parameter_shape)) { return InvalidArgument( - "mapped computation's parameter has to be a scalar; " - "got parameter %d shape: %s", + "Mapped computation's parameter has to be a scalar; " + "got parameter %d shape: %s.", i, ShapeUtil::HumanString(parameter_shape).c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, *arg_shape)) { return InvalidArgument( - "mapped computation's parameter type has to match argument element " - "type; got parameter %d shape: %s, argument shape: %s", + "Mapped computation's parameter type has to match argument element " + "type; got parameter %d shape: %s, argument shape: %s.", i, ShapeUtil::HumanString(parameter_shape).c_str(), ShapeUtil::HumanString(*arg_shape).c_str()); } @@ -1197,21 +1213,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld", + "be a non-negative number, got %lld.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld", + "batch-norm-training to be at least 1; got %lld.", ShapeUtil::Rank(operand_shape)); } @@ -1232,7 +1248,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-training must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1241,7 +1257,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(offset_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1251,7 +1267,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1264,7 +1280,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of offset factor should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } @@ -1272,7 +1288,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1307,21 +1323,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld", + "be a non-negative number, got %lld.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld", + "batch-norm-inference to be at least 1; got %lld.", ShapeUtil::Rank(operand_shape)); } @@ -1342,7 +1358,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-inference must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1352,7 +1368,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of offset factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(offset_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1363,7 +1379,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of scale factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1374,7 +1390,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of mean is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1385,7 +1401,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of variance is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(variance_shape.element_type()).c_str()); } @@ -1398,7 +1414,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of offset factor should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } @@ -1406,7 +1422,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1414,7 +1430,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of mean should be the same as feature count," "but the size of mean is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } @@ -1422,7 +1438,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of variance should be the same as feature count," "but the size of variance is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1455,7 +1471,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } @@ -1463,7 +1479,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld", + " rank(output_grad_shape) %lld.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } @@ -1491,14 +1507,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-grad must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } @@ -1507,7 +1523,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(output_grad_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1517,7 +1533,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1527,7 +1543,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1537,7 +1553,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1551,7 +1567,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of mean should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } @@ -1559,7 +1575,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1567,7 +1583,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of variance should be the same as feature count," "but the size of variance is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1578,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld", + "and the bound of output_grad_shape is %lld.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1596,7 +1612,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( - "Convolution with different element types: %s and %s", + "Convolution with different element types: %s and %s.", ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -1612,21 +1628,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" - "Window: %s\nDimension numbers: %s", + "Window: %s\nDimension numbers: %s.", window.DebugString().c_str(), dnums.DebugString().c_str()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( - "The LHS argument to a convolution should have rank %d.\n" - "lhs: %s", + "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs).c_str()); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( - "The RHS argument to a convolution should have rank %d.\n" - "lhs: %s", + "The RHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs).c_str()); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); @@ -1663,26 +1677,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( - "A dimension number is out of range in convolution: %s", + "A dimension number is out of range in convolution: %s.", dnums.DebugString().c_str()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } @@ -1706,7 +1720,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected LHS feature dimension (value %lld) to match RHS " "input feature dimension (value %lld); got (%s, %s)\n" - "Dimension numbers: {%s}", + "Dimension numbers: {%s}.", input_features, kernel_input_features, ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); @@ -1720,7 +1734,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "Window dimensions do not match RHS shape:\n\t" "RHS shape: %s\n\t" "Window: {%s}\n\t" - "Dimension numbers: {%s}", + "Dimension numbers: {%s}.", ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), dnums.ShortDebugString().c_str()); } @@ -1748,8 +1762,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const tensorflow::gtl::ArraySlice fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3, but got %lld", - fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); } #define RET_CHECK_RANK(x) \ if (x.dimensions_size() < fft_rank) { \ @@ -1762,7 +1775,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s", + return InvalidArgument("%s requires C64 input type, found %s.", FftType_Name(fft_type).c_str(), PrimitiveType_Name(in.element_type()).c_str()); } @@ -1770,7 +1783,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return in; case RFFT: { if (in.element_type() != F32) { - return InvalidArgument("RFFT requires F32 input type, found %s", + return InvalidArgument("RFFT requires F32 input type, found %s.", PrimitiveType_Name(in.element_type()).c_str()); } RET_CHECK_RANK(in); @@ -1779,7 +1792,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld", + "dimension %lld is %lld and should be %lld.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1792,7 +1805,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } case IRFFT: { if (in.element_type() != C64) { - return InvalidArgument("IRFFT requires C64 input type, found %s", + return InvalidArgument("IRFFT requires C64 input type, found %s.", PrimitiveType_Name(in.element_type()).c_str()); } RET_CHECK_RANK(in); @@ -1802,7 +1815,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld", + "fft_length, but dimension %lld is %lld and should be %lld.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1812,7 +1825,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld", + "dimension %d is %lld and should be %lld.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1850,8 +1863,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( - "attempting to reduce out-of-bounds dimension %lld in shape %s", - dimension, ShapeUtil::HumanString(arg).c_str()); + "Reducing out-of-bounds dimension %lld in shape %s.", dimension, + ShapeUtil::HumanString(arg).c_str()); } } TF_RETURN_IF_ERROR( @@ -1891,30 +1904,30 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { return InvalidArgument( - "select function must take 2 parameters, but " + "Select function must take 2 parameters, but " "takes %d parameter(s).", select_shape.parameters_size()); } const Shape& select_result_shape = select_shape.result(); if (!ShapeUtil::Compatible(select_result_shape, ShapeUtil::MakeShape(PRED, {}))) { - return Unimplemented("select function must have rank-0 PRED result."); + return InvalidArgument("Select function must have rank-0 PRED result."); } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(0))) { return InvalidArgument( - "select function's first parameter shape currently must " - "match the operand element shape. Got %s vs %s", + "Select function's first parameter shape currently must " + "match the operand element shape, but got %s vs %s.", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( - "select function's second parameter shape currently must " - "match the operand element shape. Got %s vs %s", + "Select function's second parameter shape currently must " + "match the operand element shape, but got %s vs %s.", ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } @@ -1931,8 +1944,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, window_result_shape)) { return InvalidArgument( - "source shape does not match the shape of window-reduced operand: " - "source(%s), window-reduced operand(%s)", + "Source shape does not match the shape of window-reduced operand: " + "source(%s), window-reduced operand(%s).", ShapeUtil::HumanString(source_shape).c_str(), ShapeUtil::HumanString(window_result_shape).c_str()); } @@ -1946,7 +1959,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " - "{%s}; strides: {%s}", + "{%s}; strides: {%s}.", message.c_str(), ShapeUtil::HumanString(arg).c_str(), Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); @@ -1969,7 +1982,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "slice index count does not match argument rank: %zu vs %lld", + "Slice index count does not match argument rank: %zu vs %lld.", starts.size(), ShapeUtil::Rank(arg)); } @@ -1979,7 +1992,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("negative start index to slice: %lld", + return InvalidArgument("Negative start index to slice: %lld.", start_index); } if (limit_index > arg.dimensions(dimension)) { @@ -1999,7 +2012,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("stride (%lld) must be positive", stride); + return InvalidArgument("Stride (%lld) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2023,20 +2036,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %lld must be rank1.", ShapeUtil::Rank(start_indices_shape)); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( - "dynamic slice start indices must be of integral type."); + "Dynamic slice start indices must be of integral type."); } const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s)", + "Dynamic slice start number of dimensions %lld (%s) must match rank " + "%lld of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape).c_str()); @@ -2044,7 +2057,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "dynamic slice index count does not match argument rank: %zu vs %lld", + "Dynamic slice index count does not match argument rank: %zu vs %lld.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2052,12 +2065,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("negative size index to dynamic slice: %lld", + return InvalidArgument("Negative size index to dynamic slice: %lld.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "slice dim size %lld greater than dynamic slice dimension: %lld", + "Slice dim size %lld greater than dynamic slice dimension: %lld.", slice_dim_size, input_dim_size); } VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, @@ -2086,20 +2099,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %lld must be rank1.", ShapeUtil::Rank(start_indices_shape)); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( - "dynamic update slice start indices must be of integral type."); + "Dynamic update slice start indices must be of integral type."); } const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s)", + "Dynamic update slice start number of dimensions %lld (%s) must match " + "rank %lld of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape).c_str()); @@ -2107,16 +2120,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "dynamic update slice update rank does not match argument rank: " - "%lld vs %lld", + "Dynamic update slice update rank does not match argument rank: " + "%lld vs %lld.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, update_shape)) { return InvalidArgument( - "dynamic update slice update element type does not match argument. " - "operand.element_type: %s vs update.element_type: %s", + "Dynamic update slice update element type does not match argument. " + "operand.element_type: %s vs update.element_type: %s.", PrimitiveType_Name(operand_shape.element_type()).c_str(), PrimitiveType_Name(update_shape.element_type()).c_str()); } @@ -2126,12 +2139,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "size index %lld to dynamic update slice must be >= 0", + "Size index %lld to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "update dim size %lld greater than dynamic slice dimension: %lld", + "Update dim size %lld greater than dynamic slice dimension: %lld.", update_dim_size, input_dim_size); } VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, @@ -2151,7 +2164,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "one of the reverse dimensions (%lld) is out-of-bounds in shape %s", + "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape).c_str()); } } @@ -2162,14 +2175,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& arg, int64 index) { if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( - "cannot infer shape: attempting to index into non-tuple: %s", + "Cannot infer shape: attempting to index into non-tuple: %s.", ShapeUtil::HumanString(arg).c_str()); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "cannot infer shape: attempt to index out of tuple bounds: %lld " - ">= %d in shape %s", + "Cannot infer shape: attempt to index out of tuple bounds: %lld " + ">= %d in shape %s.", index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); } @@ -2181,17 +2194,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& init) { // Check the number of parameters for given computations. if (condition.parameters_size() != 1) { - return InvalidArgument("condition must take 1 arguments; got %d", + return InvalidArgument("Condition must take 1 arguments; got %d.", condition.parameters_size()); } if (body.parameters_size() != 1) { - return InvalidArgument("body must take 1 arguments; got %d", + return InvalidArgument("Body must take 1 arguments; got %d.", body.parameters_size()); } auto shape_string = [&]() { return tensorflow::strings::Printf( - "condition: %s; body: %s; init: %s", + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition).c_str(), ShapeUtil::HumanString(body).c_str(), ShapeUtil::HumanString(init).c_str()); @@ -2199,15 +2212,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Check the shapes of computation parameters and return types. if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { - return InvalidArgument("condition must return a boolean; got %s", + return InvalidArgument("Condition must return a boolean; got %s.", shape_string().c_str()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || !ShapeUtil::Compatible(body.result(), init)) { return InvalidArgument( - "the parameter of condition and body, the result of the body, and init " - "must all have the same shape; got %s", + "The parameter of condition and body, the result of the body, and init " + "must all have the same shape; got %s.", shape_string().c_str()); } @@ -2219,7 +2232,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { - return InvalidArgument("predicate must be a boolean; got %s.", + return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate).c_str()); } @@ -2302,8 +2315,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s)", + "Reshape operation has mismatched element counts: from=%lld (%s) " + "to=%lld (%s).", ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::ElementsIn(inferred_shape), ShapeUtil::HumanString(inferred_shape).c_str()); @@ -2351,7 +2364,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { - return InvalidArgument("clamp op with different operand types: %s, %s, %s", + return InvalidArgument("Clamp with different operand types: %s, %s, %s.", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); @@ -2372,7 +2385,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } return Unimplemented( - "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(), + "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); } @@ -2391,25 +2404,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } if (!compatible) { return InvalidArgument( - "operands to select must be the same shape; got %s and %s", + "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), ShapeUtil::HumanString(on_false).c_str()); } if (pred.element_type() != PRED) { return InvalidArgument( - "select's pred operand must have PRED element type; got %s", + "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred).c_str()); } - if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) { + if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + ShapeUtil::Rank(pred) == 0) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. return ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { - return Unimplemented( - "select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s", + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + " different from the other operands: %s.", ShapeUtil::HumanString(pred).c_str()); } } @@ -2427,7 +2441,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Call applied function arity must match number of arguments; got: " "arity: %d, arguments: %zu; computation signature: %s; argument " - "shapes: [%s]", + "shapes: [%s].", to_apply.parameters_size(), arg_shapes.size(), computation_signature.c_str(), argument_shapes.c_str()); } @@ -2439,7 +2453,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::Compatible(arg_shape, param_shape)) { return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " - "argument shape: %s", + "argument shape: %s.", i, ShapeUtil::HumanString(param_shape).c_str(), ShapeUtil::HumanString(arg_shape).c_str()); } @@ -2454,40 +2468,40 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& dim_numbers) { if (!c_is_sorted(dim_numbers.output_window_dims())) { return InvalidArgument( - "Output window dimensions in gather op must be ascending; got: %s", + "Output window dimensions in gather op must be ascending; got: %s.", Join(dim_numbers.output_window_dims(), ", ").c_str()); } if (c_adjacent_find(dim_numbers.output_window_dims()) != dim_numbers.output_window_dims().end()) { return InvalidArgument( - "Output window dimensions in gather op must not repeat; got: %s", + "Output window dimensions in gather op must not repeat; got: %s.", Join(dim_numbers.output_window_dims(), ", ").c_str()); } const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size(); + output_window_dim_count + gather_indices_shape.size() - 1; for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { int64 window_index = dim_numbers.output_window_dims(i); if (window_index < 0 || window_index >= output_shape_rank) { return InvalidArgument( "Window index %d in gather op is out of bounds; got %lld, but should " - "have been in" - "[0,%lld)", + "have been in [0,%lld).", i, window_index, output_shape_rank); } } if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape.back()) { + gather_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "There must be exactly as many elements in gather_dims_to_operand_dims " - "as there are elements in the last dimension of %%gather_indices; got: " - "%d, expected %lld", + "Gather op has %d elements in gather_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of gather_indices is " + "%lld. These two numbers must be equal.", dim_numbers.gather_dims_to_operand_dims_size(), - gather_indices_shape.back()); + dim_numbers.index_vector_dim(), + gather_indices_shape[dim_numbers.index_vector_dim()]); } for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { @@ -2496,7 +2510,7 @@ static Status ValidateGatherDimensionNumbers( gather_dim_to_input_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld", + "got: %d->%lld.", input_shape.dimensions_size(), i, gather_dim_to_input_dim); } } @@ -2511,7 +2525,7 @@ static Status ValidateGatherDimensionNumbers( sorted_gather_dims_to_operand_dims.end()) { return InvalidArgument( "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " - "got: %s", + "got: %s.", Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); } @@ -2519,7 +2533,7 @@ static Status ValidateGatherDimensionNumbers( if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid elided_window_dims set in gather op; valid range is [0, " - "%d), got: %lld", + "%d), got: %lld.", input_shape.dimensions_size(), elided_dim); } } @@ -2534,7 +2548,7 @@ static Status ValidateGatherDimensionNumbers( dim_numbers.elided_window_dims().end()) { return InvalidArgument( "Repeated dimensions not allowed in elided_window_dims in gather op; " - "got: %s", + "got: %s.", Join(dim_numbers.elided_window_dims(), ", ").c_str()); } @@ -2550,24 +2564,33 @@ static Status ValidateGatherDimensionNumbers( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( gather_indices_shape, "gather indices operand of gather op")); - if (gather_indices_shape.dimensions_size() < 1) { + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( - "Gather indices parameter must at least of rank 1; got %s", + "Gather indices parameter must be an integral tensor; got %s.", ShapeUtil::HumanString(gather_indices_shape).c_str()); } - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if + // index_vector_dim is rank(P). The bounds of this expanded shape is + // stored in expanded_gather_indices_shape. + + if (gather_indices_shape.dimensions_size() < + gather_dim_numbers.index_vector_dim() || + gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather indices parameter must be an integral tensor; got %s", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + "Gather index leaf dimension must be within [0, rank(gather_indices) + " + "1). rank(gather_indices) is %d and gather index leaf dimension is " + "%lld.", + gather_indices_shape.dimensions_size(), + gather_dim_numbers.index_vector_dim()); } std::vector expanded_gather_indices_shape; - // We implicitly reshape gather indices of shape P[N] to P[N,1]. expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); c_copy(gather_indices_shape.dimensions(), std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == 1) { + if (expanded_gather_indices_shape.size() == + gather_dim_numbers.index_vector_dim()) { expanded_gather_indices_shape.push_back(1); } @@ -2577,7 +2600,7 @@ static Status ValidateGatherDimensionNumbers( if (window_bounds.size() != input_shape.dimensions_size()) { return InvalidArgument( "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d", + "len(window_bounds)=%lu, input_shape.rank=%d.", window_bounds.size(), input_shape.dimensions_size()); } @@ -2587,7 +2610,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "All components of the window index in a gather op must either be a " "output window index or explicitly elided; got len(window_bounds)=%lu, " - "output_window_bounds=%s, elided_window_bounds=%s", + "output_window_bounds=%s, elided_window_bounds=%s.", window_bounds.size(), Join(gather_dim_numbers.output_window_dims(), ",").c_str(), Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); @@ -2600,7 +2623,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Window bound at index %d in gather op is out of range, must be " "within " - "[0, %lld), got %lld", + "[0, %lld), got %lld.", i, corresponding_input_bound + 1, window_bound); } } @@ -2609,7 +2632,7 @@ static Status ValidateGatherDimensionNumbers( if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { return InvalidArgument( "Gather op can only elide window indices with bound 1, but bound is " - "%lld for index %lld at position %d", + "%lld for index %lld at position %d.", window_bounds[gather_dim_numbers.elided_window_dims(i)], gather_dim_numbers.elided_window_dims(i), i); } @@ -2632,6 +2655,9 @@ static Status ValidateGatherDimensionNumbers( } current_bound = window_bounds[window_dims_seen++]; } else { + if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { + gather_dims_seen++; + } current_bound = expanded_gather_indices_shape[gather_dims_seen++]; } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0d3045213db2230da3e18ffcb1a9923250560b64..88830e6d2516cd664dd4e632adf0bdc72e451880 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -48,6 +48,8 @@ class ShapeInference { // given input shape. static StatusOr InferUnaryOpShape(UnaryOperation operation, const Shape& arg); + static StatusOr InferUnaryOpShape(HloOpcode opcode, + const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, const HloInstruction* operand); @@ -56,6 +58,9 @@ class ShapeInference { static StatusOr InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + static StatusOr InferBinaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -65,6 +70,9 @@ class ShapeInference { static StatusOr InferTernaryOpShape(TernaryOperation operation, const Shape& lhs, const Shape& rhs, const Shape& ehs); + static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, + const Shape& rhs, + const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7eb120843fd841d841048eeaefd895fde96d133c..0e61994a786b53a295ef9c9c2287b28fbf754d9b 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -135,7 +135,7 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("operands to select must be the same shape")); + HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); @@ -340,7 +340,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("source shape does not match")); + HasSubstr("Source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -351,7 +351,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must take 2 parameters")); + HasSubstr("Select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -362,7 +362,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must have rank-0 PRED")); + HasSubstr("Select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -373,7 +373,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's first parameter")); + HasSubstr("Select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -384,7 +384,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's second parameter")); + HasSubstr("Select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -906,7 +906,7 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("Dot only supports rank")); } // 3D 2D: error @@ -918,7 +918,7 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch and contracting dimension number mismatch")); + HasSubstr("Batch and contracting dimension number mismatch")); } // vector vector -> scalar @@ -1024,7 +1024,7 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("must specify one contracting dimension for both " + HasSubstr("Must specify one contracting dimension for both " "lhs and rhs")); } @@ -1044,7 +1044,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension numbers and sizes must match")); } // BatchMatMul with different batch dimension numbers fails. @@ -1063,7 +1063,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers must precede non-batch")); + HasSubstr("Batch dimension numbers must precede non-batch")); } // BatchMatMul with out-of-range dimension numbers fails. @@ -1166,42 +1166,42 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("automatic")); + HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - HasSubstr("size of broadcast_dimensions has to match")); + HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -1210,13 +1210,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -1242,7 +1242,7 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("condition must take 1 arguments")); + HasSubstr("Condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); @@ -1250,14 +1250,14 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("body must take 1 arguments")); + HasSubstr("Body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("condition must return a boolean")); + HasSubstr("Condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = @@ -1301,13 +1301,13 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: -1")); + HasSubstr("dimension out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: 1")); + HasSubstr("dimension out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( @@ -1315,21 +1315,20 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation.")); + HasSubstr("Expected non-tuple argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT( - inferred_status_error5.status().error_message(), - HasSubstr("cannot concatenate arrays with different element types")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("cannot concatenate arrays that differ in " + HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); } @@ -1467,7 +1466,7 @@ TEST_F(ShapeInferenceTest, Conditional) { ShapeUtil::MakeProgramShape({vector_64_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), - HasSubstr("predicate must be a boolean")); + HasSubstr("Predicate must be a boolean")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, @@ -1530,11 +1529,17 @@ TEST_F(ShapeInferenceTest, BadSlice) { class GatherShapeInferenceTest : public ShapeInferenceTest { protected: + const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); + const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); const Shape s64_4d_tensor_10_9_8_7_1_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); const Shape s64_4d_tensor_10_9_8_7_5_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape s64_4d_tensor_5_10_9_7_6_ = + ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6}); + const Shape s64_4d_tensor_10_9_5_7_6_ = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); const Shape f32_5d_tensor_50_49_48_47_46_ = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -1548,7 +1553,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) @@ -1562,7 +1568,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{1}, /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}), + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) @@ -1576,7 +1583,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4}, /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}), + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) @@ -1591,7 +1599,8 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1599,12 +1608,85 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { + // This is equivalent to a dynamic slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3, 4}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { + // The gather indices "tensor" is a scalar S here that's used to slice out + // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0), + /*window_bounds=*/{1, 30, 29, 28, 27})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) + << ShapeUtil::HumanString(gather_shape); +} + TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1617,7 +1699,8 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { s64_vector_32_, tuple_shape_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1625,25 +1708,13 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) { - StatusOr statusor = ShapeInference::InferGatherShape( - s64_vector_32_, s32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), - /*window_bounds=*/{64, 1}); - ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather indices parameter must at least of rank 1")) - << statusor.status(); -} - TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}), + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1658,7 +1729,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1674,7 +1746,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1690,7 +1763,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1698,6 +1772,22 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 9}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 4 in gather op is out of bounds")) + << statusor.status(); +} + TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( @@ -1705,7 +1795,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1722,7 +1813,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1738,7 +1830,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1755,15 +1848,15 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "There must be exactly as many elements in " - "gather_dims_to_operand_dims " - "as there are elements in the last dimension of %gather_indices")) + HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of " + "gather_indices is 5. These two numbers must be equal.")) << statusor.status(); } @@ -1774,7 +1867,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1791,7 +1885,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1808,7 +1903,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1822,7 +1918,8 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1838,7 +1935,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( @@ -1855,7 +1953,8 @@ TEST_F(GatherShapeInferenceTest, HloInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1864,5 +1963,22 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } +TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/32), + /*window_bounds=*/{30, 29, 28, 27, 26}); + + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather index leaf dimension must be within [0, " + "rank(gather_indices) + 1)")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 06735e9442942f3c69d1cd679857fe22f2fa6756..0dca30a804005c6f536aca5b54af24eb08d4560b 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -3315,20 +3315,23 @@ void ComputationLowerer::Visit( HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && + !ShapeUtil::IsTuple(request.output_shape())) { + if (!ShapeUtil::IsTuple(lhs->shape()) && + !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); } - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { + if (!ShapeUtil::IsTuple(rhs->shape()) && + !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { rhs = ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); } - if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { + if (!ShapeUtil::IsTuple(ehs->shape()) && + !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { ehs = ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index a5f9b01f011ce04f1114c74391a967c62f015221..3ef0cdff6751258e4489ce350deb0931fdf69ef9 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -106,20 +106,12 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kConstant: + case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: + case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; - - case HloOpcode::kTranspose: - return ShapeUtil::TransposeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape(), instruction.dimensions()); - - case HloOpcode::kReshape: - return ShapeUtil::ReshapeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape()); } } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 981de9b2200a9ae8938db21299580f510834d2f0..ec05a74e286c89dd8db5ae07580e461938d7c087 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -212,7 +213,7 @@ static optional GetLoopTripCount(HloInstruction* while_op) { // Now that we know the index of the induction variable, we can we can try to // compute how many times the loop executes. Start by computing the induction // variable's initial value. - HloEvaluator evaluator; + HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); StatusOr> indvar_init_result = @@ -605,6 +606,78 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { return false; } +static StatusOr TryPropagateConstant(HloInstruction* while_op) { + auto while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + return false; + } + + auto while_body = while_op->while_body(); + auto while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + auto while_body_param = while_body->parameter_instruction(0); + const HloInstruction::InstructionVector& root_operands = + while_body_root->operands(); + + // Find the loop invariant tuple elements with scalar constant init value and + // build a map from the tuple element index to the constant value. Limit this + // to scalar constant values because propagating array constants can regress + // performance by forcing us to copy constants. + tensorflow::gtl::FlatMap index_to_constant; + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && instr->operand(0) == while_body_param && + ShapeUtil::IsScalar(instr->shape())) { + auto tuple_element = while_init->operand(i); + if (tuple_element->IsConstant()) { + VLOG(3) << "Found loop invariant tuple element " << i << " " + << tuple_element->ToString(); + index_to_constant[i] = tuple_element; + } + } + } + + if (index_to_constant.empty()) { + return false; + } + + // Replace the use of each constant tuple element in the loop_condition and + // loop_body with the corresponding constant value. + auto propagate_constant = [&](HloComputation* computation) -> StatusOr { + HloInstruction* param = computation->parameter_instruction(0); + bool changed = false; + for (auto instr : param->users()) { + // Since only a while-loop with a tuple result reaches here, we can safely + // assume that `param` is a tuple and the first operand of the + // GetTupleElement instruction is a use of `param`. + if (instr->opcode() == HloOpcode::kGetTupleElement) { + VLOG(3) << "tuple index " << instr->tuple_index() << " " + << instr->ToString(); + auto iter = index_to_constant.find(instr->tuple_index()); + if (iter != index_to_constant.end()) { + const HloInstruction* hlo_constant = (*iter).second; + VLOG(3) << "Replace use of " << instr->ToString() << " with " + << hlo_constant->ToString(); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith( + computation->AddInstruction(hlo_constant->Clone()))); + changed = true; + } + } + } + return changed; + }; + + TF_ASSIGN_OR_RETURN(bool changed_cond, + propagate_constant(while_op->while_condition())); + TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); + + return changed_cond || changed_body; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -635,7 +708,11 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } - StatusOr result = TryRemoveWhileLoop(while_op); + StatusOr result = TryPropagateConstant(while_op); + TF_RETURN_IF_ERROR(result.status()); + changed |= result.ValueOrDie(); + + result = TryRemoveWhileLoop(while_op); TF_RETURN_IF_ERROR(result.status()); if (result.ValueOrDie()) { changed = true; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index d3d55634c97bbdf3f81321d8089bb808c411340b..3d3e1d60f294c3a2574513c1c2f071805a341ad1 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. -// - A while loops with static trip count of 1 is replaced by its body (sans +// - A while loop with static trip count of 1 is replaced by its body (sans // loop). // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index c5183f8d3aee99696ed4114c3f7e451888222137..619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -26,112 +27,140 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { - public: - // Makes a computation that contains a loop that runs num_iters times. - HloComputation* MakeSimpleLoop(int num_iters, HloModule* module); - - // Makes a computation which has one parameter, of the given shape, and always - // returns PRED[]{true}. This is useful as a dummy loop condition. - HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, - HloModule* module); + protected: + // Makes an HloModule that contains a loop with `num_iters` iteration. + void MakeModuleWithSimpleLoop(int num_iters); + + // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to + // the loop-condition through an element of a tuple which is the + // loop-condition parameter. + void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters, - HloModule* module) { - HloComputation::Builder builder(TestName()); - - auto loop_iter_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - auto loop_data_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1, 2}))); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto loop_var = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(42 + num_iters))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, - limit)); - condition = module->AddEmbeddedComputation(cond_builder.Build()); +void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { + string hlo_string_template = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto loop_var = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto new_loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto loop_data = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - loop_data_init->shape(), loop_var, 1)); - auto new_loop_data = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_data_init->shape(), HloOpcode::kMultiply, loop_data, - loop_data)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); - body = module->AddEmbeddedComputation(body_builder.Build()); + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant({{LOOP_BOUND}}) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - return module->AddEntryComputation(builder.Build()); + string hlo_string = tensorflow::str_util::StringReplace( + hlo_string_template, "{{LOOP_BOUND}}", + tensorflow::strings::StrCat(42 + num_iters), + /*replace_all=*/true); + ParseAndVerifyModule(hlo_string); } -HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation( - const Shape& param_shape, HloModule* module) { - HloComputation::Builder builder(TestName() + ".always_true"); - builder.AddInstruction( - HloInstruction::CreateParameter(0, param_shape, "param")); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); - return module->AddEmbeddedComputation(builder.Build()); +void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( + int num_iters) { + string hlo_string_template = R"( + HloModule SimpleLoopWithIndirectLoopBound + SimpleLoopWithIndirectLoopBound.body { + loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + limit = s32[] get-tuple-element(loop_var.1), index=2 + ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit) + } + SimpleLoopWithIndirectLoopBound.condition { + loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 + ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) + } + ENTRY SimpleLoopWithIndirectLoopBound { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + constant.2 = s32[] constant({{LOOP_BOUND}}) + tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4, + constant.2) + ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1), + condition=SimpleLoopWithIndirectLoopBound.condition, + body=SimpleLoopWithIndirectLoopBound.body + } + )"; + + string hlo_string = tensorflow::str_util::StringReplace( + hlo_string_template, "{{LOOP_BOUND}}", + tensorflow::strings::StrCat(42 + num_iters), + /*replace_all=*/true); + ParseAndVerifyModule(hlo_string); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module()); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), +TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/0); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), +TEST_F(WhileLoopSimplifierTest, + LoopWithZeroIterationTupleElementLoopBoundSimplified) { + MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), + op::Tuple(op::Constant(), op::Constant(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) { - MakeSimpleLoop(/*num_iters=*/2, &module()); +TEST_F(WhileLoopSimplifierTest, + LoopWithOneIterationTupleELementLoopBoundSimplified) { + MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), + op::Tuple(op::Add(), op::Multiply(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/2); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, + LoopWithControlDependencySimplifiedDependencyPreserved) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -139,8 +168,10 @@ TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -149,11 +180,13 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(Literal::CreateR0(true))), /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -161,247 +194,217 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and // other non-removable instructions) isn't fundamental -- it just stems from the // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); while_body->AddInstruction( HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } -// Check that we don't crash when given a loop whose shape is not a tuple. -TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(100))))); - condition = module().AddEmbeddedComputation(cond_builder.Build()); +// A non-tuple shaped loop shouldn't be simplified or crash the compiler. +TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { + const string hlo_string = R"( + HloModule NonTupleShapedLoop + NonTupleShapedLoop.body { + loop_var.1 = s32[] parameter(0) + constant.1 = s32[] constant(-1) + ROOT add = s32[] add(s32[] loop_var.1, s32[] constant.1) + } + NonTupleShapedLoop.condition { + loop_var = s32[] parameter(0) + constant = s32[] constant(100) + ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant) + } + ENTRY INonTupleShapedLoop { + constant.2 = s32[] constant(42) + ROOT while = s32[] while(s32[] constant.2), + condition=NonTupleShapedLoop.condition, + body=NonTupleShapedLoop.body } + )"; - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))))); - body = module().AddEmbeddedComputation(body_builder.Build()); - } - - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -// Construct a loop where we swap the tuple elements in each iteration. -// Although the tuple elements aren't used in the loop, we don't eliminate them, -// because the swapping side-effect is visible to users of the loop. -TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - body_builder.AddInstruction(HloInstruction::CreateTuple({ - body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)), - body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)), - })); - body = module().AddEmbeddedComputation(body_builder.Build()); +// A while loop that does nothing else besides swapping tuple elements +// can't be simplified as the result of the swapping is visible to users of the +// loop. +TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { + const string hlo_string = R"( + HloModule SwappingTupleElements + SwappingTupleElements.body { + loop_var = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) loop_var),index=1 + get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[]) loop_var), + index=0 + ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element, + s32[] get-tuple-element.1) } + SwappingTupleElements.always_true { + param = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y) + ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each // iteration. We can't eliminate tuple element 0, even though we never use its // value. -TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0)))})); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateTuple({ - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, + LoopWithUnusedButModifiedTupleElementNotSimplified) { + const string hlo_string = R"( + HloModule UnusedButModifiedTupleElement + UnusedButModifiedTupleElement.body { + loop_var = (s32[]) parameter(0) + constant.1 = s32[] constant(1) + ROOT tuple = (s32[]) tuple(s32[] constant.1) } + UnusedButModifiedTupleElement.always_true { + param = (s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY UnusedButModifiedTupleElement { + constant.2 = s32[] constant(0) + tuple.1 = (s32[]) tuple(s32[] constant.2) + ROOT while = (s32[]) while((s32[]) tuple.1), + condition=UnusedButModifiedTupleElement.always_true, + body=UnusedButModifiedTupleElement.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // Nothing to simplify in a while loop whose tuple has 0 elements. -TEST_F(WhileLoopSimplifierTest, EmptyTuple) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({})); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateTuple({})); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { + const string hlo_string = R"( + HloModule EmptyTuple + EmptyTuple.body { + loop_var = () parameter(0) + ROOT tuple = () tuple() + } + EmptyTuple.always_true { + param = () parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY EmptyTuple { + tuple.1 = () tuple() + ROOT while = () while(() tuple.1), condition=EmptyTuple.always_true, + body=EmptyTuple.body } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't // be simplified away. -TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto* param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "param0")); - auto* gte0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); - // get0 is used twice in the loop body's tuple. - body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0})); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { + const string hlo_string = R"( + HloModule ElemUsedTwice + ElemUsedTwice.body { + param0 = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param0), index=0 + ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element, + s32[] get-tuple-element) + } + ElemUsedTwice.always_true { + param = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) } + ENTRY ElemUsedTwice { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y) + ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1), + condition=ElemUsedTwice.always_true, body=ElemUsedTwice.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // This while loop has three tuple elements. Element 0 is unused and should be // removed. Element 1 is used by the loop body, and element 2 is used by the // loop condition; these two should stay. -TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - })); - auto loop_shape = loop_init->shape(); - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".loop_condition"); - auto param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_shape, "param0")); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - scalar_s32, param, /*index=*/2)))); - condition = module().AddEmbeddedComputation(cond_builder.Build()); +TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { + const string hlo_string = R"( + HloModule RemoveUnusedOperands + RemoveUnusedOperands.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[], + s32[]) loop_var), index=0 + get-tuple-element.2 = s32[] get-tuple-element((s32[], s32[], + s32[]) loop_var), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1) + get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[]) + loop_var), index=2 + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element.1, + s32[] add, s32[] get-tuple-element.3) } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto* param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_shape, "loop_var")); - - auto* tuple0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); - auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary( - scalar_s32, HloOpcode::kAdd, - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - scalar_s32, param, /*index=*/1)), - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto* tuple2 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({tuple0, tuple1, tuple2})); - - body = module().AddEmbeddedComputation(body_builder.Build()); + RemoveUnusedOperands.loop_condition { + constant.2 = s32[] constant(0) + param0 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0), + index=2 + ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element) } + ENTRY RemoveUnusedOperands { + x = s32[] parameter(0) + constant.3 = s32[] constant(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] constant.3, + s32[] y) + ROOT while = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) tuple.1), + condition=RemoveUnusedOperands.loop_condition, + body=RemoveUnusedOperands.body + } + )"; + + ParseAndVerifyModule(hlo_string); + HloModule* the_module = &module(); + EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + + // The original while instruction is still left in the module as a dead + // instruction, find a while instruction with a different name as the new + // while instruction. + HloInstruction* new_while_op = + *std::find_if(the_module->entry_computation()->instructions().begin(), + the_module->entry_computation()->instructions().end(), + [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); - auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - - // We leave most of the checking to HloVerifiedTestBase, which runs the - // verifier on module() at the end of this test. - HloInstruction* new_while_op = *std::find_if( - module().entry_computation()->instructions().begin(), - module().entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return instr != while_op && instr->opcode() == HloOpcode::kWhile; - }); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( ShapeUtil::Equal(new_while_op->shape(), ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}))) @@ -418,31 +421,91 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } -TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); - - HloComputation* while_body = [&]() { - HloComputation::Builder builder(TestName() + ".passthrough"); - HloInstruction* param = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); - - result->AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); - return result; - }(); - - HloComputation::Builder builder(TestName()); - auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); - builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopSimplifier{}.Run(&module())); - EXPECT_FALSE(simplified_loop); +TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { + const string hlo_string = R"( + HloModule BodyHasNonTupleRoot + BodyHasNonTupleRoot.passthrough { + ROOT param = (s32[], s32[]) parameter(0) + } + BodyHasNonTupleRoot.always_true { + param.1 = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY BodyHasNonTupleRoot { + init_value = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while((s32[], s32[]) init_value), + condition=BodyHasNonTupleRoot.always_true, + body=BodyHasNonTupleRoot.passthrough + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, + LoopWithNonTupleBodyRootInstructionNotSimplified) { + const string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply), + custom_call_target="x" + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(44) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { + const string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s32[3]{0} get-tuple-element(loop_var.1), index=2 + add.2 = s32[3]{0} add(get-tuple-element.2, get-tuple-element.3) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, add.2, get-tuple-element.3) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(47) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4, constant.4) + ROOT while = (s32[], s32[3]{0}, s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index e20b25e4a08a946f6b58575a4d4e557744f8035c..bd0794184328b7926543c4275b3b915f51e7b812 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,18 +15,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { +using tensorflow::strings::StrCat; + static StatusOr WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { const Shape& narrow_shape = narrow_condition->parameter_instruction(0)->shape(); HloComputation* wide_while_cond = [&]() { - HloComputation::Builder builder( - tensorflow::strings::StrCat("wide.", narrow_condition->name())); + HloComputation::Builder builder(StrCat("wide.", narrow_condition->name())); builder.AddInstruction( HloInstruction::CreateParameter(0, wide_shape, "wide_param")); @@ -57,8 +60,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); HloComputation* wide_while_body = [&]() { - HloComputation::Builder builder( - tensorflow::strings::StrCat("wide.", narrow_body->name())); + HloComputation::Builder builder(StrCat("wide.", narrow_body->name())); builder.AddInstruction( HloInstruction::CreateParameter(0, wide_shape, "wide_param")); return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); @@ -137,4 +139,109 @@ WhileUtil::MakeInstructionsLiveIn( return std::move(result); } + +static StatusOr> +MakeCountedLoopConditionComputation(const Shape& loop_state_shape, + int32 trip_count) { + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + + TF_ASSIGN_OR_RETURN(std::unique_ptr cond_computation, + CreateComputationWithSignature( + {&loop_state_shape}, scalar_pred, "while_cond")); + + HloInstruction* trip_count_constant = cond_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + + HloInstruction* param = cond_computation->parameter_instruction(0); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * compare, + MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); + cond_computation->set_root_instruction(compare); + return std::move(cond_computation); +} + +static StatusOr> MakeCountedLoopBodyComputation( + const Shape& loop_state_shape, + const std::function( + HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) { + TF_ASSIGN_OR_RETURN(std::unique_ptr body_computation, + CreateComputationWithSignature( + {&loop_state_shape}, loop_state_shape, "while_body")); + HloInstruction* one = body_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction* param = body_computation->parameter_instruction(0); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar, + MakeBinaryHlo(HloOpcode::kAdd, indvar, one)); + + std::vector loop_body_generator_args; + for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) { + TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element, + MakeGetTupleElementHlo(param, i)); + loop_body_generator_args.push_back(tuple_element); + } + TF_ASSIGN_OR_RETURN(std::vector next_state, + loop_body_generator(indvar, loop_body_generator_args)); + next_state.insert(next_state.begin(), next_indvar); + HloInstruction* next_state_tuple = + body_computation->AddInstruction(HloInstruction::CreateTuple(next_state)); + body_computation->set_root_instruction(next_state_tuple); + + return std::move(body_computation); +} + +static StatusOr MakeInitTupleFromInitValues( + HloComputation* computation, const WhileUtil::LoopStateTy& init_values) { + std::vector init_values_with_indvar; + init_values_with_indvar.reserve(init_values.size() + 1); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + init_values_with_indvar.push_back(zero); + c_copy(init_values, std::back_inserter(init_values_with_indvar)); + return computation->AddInstruction( + HloInstruction::CreateTuple(init_values_with_indvar)); +} + +static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { + std::vector loop_state_shape_components; + loop_state_shape_components.reserve(init_values.size() + 1); + loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); + c_transform(init_values, std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); + return ShapeUtil::MakeTupleShape(loop_state_shape_components); +} + +/*static*/ StatusOr WhileUtil::MakeCountedLoop( + HloComputation* computation, int32 trip_count, + const WhileUtil::LoopStateTy& init_values, + const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) { + CHECK_GE(trip_count, 0); + + Shape loop_state_shape = MakeLoopStateShape(init_values); + TF_ASSIGN_OR_RETURN( + std::unique_ptr cond, + MakeCountedLoopConditionComputation(loop_state_shape, trip_count)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr body, + MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator)); + TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple, + MakeInitTupleFromInitValues(computation, init_values)); + HloModule* module = computation->parent(); + HloInstruction* while_instr = + computation->AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, module->AddEmbeddedComputation(std::move(cond)), + module->AddEmbeddedComputation(std::move(body)), init_tuple)); + + std::vector result; + for (int64 i = 0, e = init_values.size(); i < e; i++) { + TF_ASSIGN_OR_RETURN(HloInstruction * user_state, + MakeGetTupleElementHlo(while_instr, i + 1)); + result.push_back(user_state); + } + return result; +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 3600b5a80d26e37fdb7d5173c3b8743734306390..1688d4674269c36c5b356f262dbd5d958572e101 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -52,6 +52,28 @@ class WhileUtil { static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); + + using LoopStateTy = std::vector; + using LoopBodyGeneratorTy = std::function( + HloInstruction* /*induction_var*/, + const LoopStateTy& /*current_values*/)>; + + // Creates a while loop in `computation` that runs for `trip_count` + // iterations. The structure of the while loop is as follows, in pseudocode: + // + // loop_state while_loop() { + // indvar = 0; + // loop_state = init_values + // while (indvar < trip_count) { + // loop_state = loop_body_generator(loop_state) + // indvar++; + // } + // return loop_state; + // } + static StatusOr MakeCountedLoop( + HloComputation* computation, int32 trip_count, + const LoopStateTy& init_values, + const LoopBodyGeneratorTy& loop_body_generator); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 063e312df66ce9cba0fa9f49c2fc6026ba6b74aa..8763e588c484011ba2ccbc7cad8f29817347a605 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -// HLO pass that replaces zero sized Hlos with an zero sized constant literal. +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 809941d8fe1f63d66bf104e66eea66167a0f509d..d8235113dd800f7bab5ceb70272a598b9dcb1fbe 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -54,6 +54,9 @@ class ServiceInterface { virtual tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) = 0; + virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; + virtual tensorflow::Status ExecuteParallel( const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 280f02e88675381bd75108bfae0dd22c462ba718..ffaa40c2d673a2365342371ed8dab59565d1d08f 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -53,7 +53,7 @@ struct ShapeTreeNode { ShapeTreeNode(const ShapeTreeNode& other) : data(other.data), children(other.children.size()) { for (size_t i = 0; i < children.size(); ++i) { - children[i] = MakeUnique(*other.children[i]); + children[i] = ::xla::MakeUnique(*other.children[i]); } } @@ -62,7 +62,7 @@ struct ShapeTreeNode { data = other.data; children.resize(other.children.size()); for (size_t i = 0; i < children.size(); ++i) { - children[i] = MakeUnique(*other.children[i]); + children[i] = ::xla::MakeUnique(*other.children[i]); } } return *this; @@ -445,7 +445,7 @@ class ShapeTreeIterator : public std::iterator(index, node_->data); + current_ = ::xla::MakeUnique(index, node_->data); return *current_; } @@ -492,7 +492,7 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { template ShapeTree::ShapeTree(Shape shape) : root_(), - shape_storage_(MakeUnique(std::move(shape))), + shape_storage_(::xla::MakeUnique(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. @@ -508,7 +508,7 @@ ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { template ShapeTree::ShapeTree(Shape shape, const T& init_value) : root_(init_value), - shape_storage_(MakeUnique(std::move(shape))), + shape_storage_(::xla::MakeUnique(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 604e0173e789348923316174873f58058eaf2815..4f604e6f7cb18c1aaf844967d54e3b0e07e54b34 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -609,6 +609,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { + CHECK(ShapeUtil::IsArray(lhs)); + CHECK(ShapeUtil::IsArray(rhs)); return ContainersEqual(lhs.dimensions(), rhs.dimensions()); } @@ -617,7 +619,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); } - return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); + if (lhs.element_type() == OPAQUE) { + return rhs.element_type() == OPAQUE; + } + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -627,7 +632,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); } - return SameDimensions(lhs, rhs); + if (lhs.element_type() == OPAQUE) { + return rhs.element_type() == OPAQUE; + } + return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, @@ -637,6 +645,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); } + if (lhs.element_type() == OPAQUE) { + return rhs.element_type() == OPAQUE; + } if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return CompatibleIgnoringElementType(lhs, rhs); } @@ -1073,9 +1084,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping) { - // Can't insert bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) && - !LayoutUtil::HasLayout(output_shape)) { + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); + + if (!SameElementType(input_shape, output_shape)) { return false; } @@ -1106,9 +1118,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - // Can't convert reshapes into bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) || - !LayoutUtil::HasLayout(output_shape)) { + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); + + if (!SameElementType(input_shape, output_shape)) { return false; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 19b1aa93bd373ebd5f502d0dca56c9b31ab4fd7f..3e130a02e2ce853ee157e46afb9760f5ff5a5026 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -208,6 +209,7 @@ class ShapeUtil { // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. + // Precondition: IsArray(lhs) && IsArray(rhs) static bool SameDimensions(const Shape& lhs, const Shape& rhs); // Returns whether the lhs and rhs shapes have the same element type. @@ -320,6 +322,15 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions); + // Creates a Shape with element type corresponding to T and the given + // dimensions + template + static Shape MakeShapeWithType( + tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + dimensions); + } + // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). static Shape MakeShapeWithLayout( @@ -522,12 +533,16 @@ class ShapeUtil { // Returns whether a transpose from input_shape to output_shape with dimension // mapping "dimension_mapping" produces a result which is bit-wise identical // to its input and thus may be replaced with a bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); @@ -560,16 +575,16 @@ class ShapeUtil { // The visitor_function visitor function should return true if it wants to // continue, or false otherwise. // - // visitor_function must be a callable of type bool(const std::vector&) - // or compatible. + // visitor_function must be a callable of type + // StatusOr(ArraySlice) or compatible. template - static void ForEachIndex(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, - const FnType& visitor_function) { + static Status ForEachIndexWithStatus(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { if (ShapeUtil::HasZeroElements(shape)) { - return; + return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); @@ -579,7 +594,11 @@ class ShapeUtil { // once with the proper empty indexes. int64 n = -1; std::vector indexes(base.begin(), base.end()); - while (n < rank && visitor_function(indexes)) { + while (n < rank) { + TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); + if (!should_continue) { + break; + } // Increments dimensions in minor to major order. for (n = 0; n < rank; ++n) { int64 dim = LayoutUtil::Minor(shape.layout(), n); @@ -590,6 +609,37 @@ class ShapeUtil { indexes[dim] = base[dim]; } } + + return Status::OK(); + } + + // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus. + struct IndexIterationSpace { + std::vector index_base; + std::vector index_count; + std::vector index_incr; + }; + + template + static Status ForEachIndexWithStatus( + const Shape& shape, const IndexIterationSpace& iteration_space, + FnTy&& function) { + return ShapeUtil::ForEachIndexWithStatus( + shape, iteration_space.index_base, iteration_space.index_count, + iteration_space.index_incr, std::forward(function)); + } + + template + static void ForEachIndex(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { + ForEachIndexWithStatus(shape, base, count, incr, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); } private: diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4db97d45b20b86dc60531845c6e28a223203ff7f..424cfe37ea44d64884e08695fd1f49ca1970ca62 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -238,6 +238,18 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, IncompatibleScalarVsTuple) { + Shape shape1 = ShapeUtil::MakeShape(F32, {}); + Shape shape2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(U32, {})}); + EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); +} + TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -573,10 +585,11 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = [&invocations](const std::vector& indexes) { - invocations++; - return true; - }; + auto increment_func = + [&invocations](tensorflow::gtl::ArraySlice indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -588,6 +601,29 @@ TEST(ShapeUtilTest, ForEachIndex) { } } +TEST(ShapeUtilTest, ForEachIndexWithStatus) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); + // Increments at every invocation. + int invocations = 0; + auto increment_func = + [&invocations]( + tensorflow::gtl::ArraySlice indexes) -> StatusOr { + if (++invocations == 5) { + return Unimplemented("Cannot increment beyond 5."); + } + return true; + }; + + Status error_status = ShapeUtil::ForEachIndexWithStatus( + shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1}, + increment_func); + + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT(error_status.error_message(), + ::testing::HasSubstr("Cannot increment beyond 5.")); + EXPECT_EQ(invocations, 5); +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 1958e5abf6b9ea971e6e4c498ff291f822ca930c..5ab25f226415efb3736e2626173b0ebcc182f312 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:test", ], + alwayslink = True, ) cc_library( @@ -138,6 +139,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -188,6 +190,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -334,6 +337,9 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -381,6 +387,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -478,6 +485,7 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -494,6 +502,7 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -550,6 +559,9 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -586,6 +598,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -662,6 +675,20 @@ xla_test( ], ) +xla_test( + name = "gather_operation_test", + srcs = ["gather_operation_test.cc"], + deps = [ + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + # Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", @@ -942,6 +969,9 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -968,6 +998,9 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -979,6 +1012,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -1011,6 +1045,7 @@ xla_test( shard_count = 40, tags = [ "enable_for_xla_interpreter", + "optonly", ], deps = [ "//tensorflow/compiler/xla:array2d", @@ -1142,6 +1177,9 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1290,6 +1328,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1336,6 +1375,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1645,6 +1685,9 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1767,9 +1810,8 @@ tf_cc_test( deps = [ ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 8b35259013200e96807446803c696451a8db80a9..03c91745b978f80801e0da5ac44d31959659b20c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,28 +51,28 @@ class ArrayElementwiseOpTestParamCount public ::testing::WithParamInterface {}; XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1, 324, std::numeric_limits::min(), std::numeric_limits::max()}); - auto result = builder.Neg(a); + builder.Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -83,18 +84,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, @@ -102,7 +103,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({ -1, 1, @@ -112,7 +113,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { static_cast(0x8000000000000000LL), static_cast(0x8000000000000001LL), }); - auto result = builder.Neg(a); + builder.Neg(a); LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); ComputeAndCompareR1(&builder, @@ -129,9 +130,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -140,64 +141,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.IsFinite(builder.ConstantR0(NAN)); + XlaBuilder builder(TestName()); + builder.IsFinite(builder.ConstantR0(NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto result_non_canonical = - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - auto result_inf = builder.IsFinite(builder.ConstantR0(inf)); + builder.IsFinite(builder.ConstantR0(inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_neg_inf = builder.IsFinite(builder.ConstantR0(-inf)); + builder.IsFinite(builder.ConstantR0(-inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_zero = builder.IsFinite(builder.ConstantR0(0.0f)); + builder.IsFinite(builder.ConstantR0(0.0f)); ComputeAndCompareR0(&builder, true, {}); } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); auto a = builder.ConstantR1( {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); auto b = builder.ConstantR1( {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -205,10 +205,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -244,7 +244,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - auto add = b.Add(lhs_param, rhs_param); + b.Add(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -295,7 +295,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector a_values; std::vector b_values; for (int i = 0; i < count; ++i) { @@ -334,49 +334,49 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); auto b = builder.ConstantR1({-1, 2, 1, -1}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -384,29 +384,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -436,9 +436,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -451,8 +451,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -461,9 +461,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -476,8 +476,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -507,9 +507,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -521,8 +521,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -531,9 +531,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -545,8 +545,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -556,33 +556,33 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); auto b = builder.ConstantR1( {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); auto b = builder.ConstantR1( {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -590,21 +590,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); auto b = builder.ConstantR1( {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -612,20 +612,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -648,19 +648,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -679,21 +679,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -701,264 +701,264 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, -8}); auto b = builder.ConstantR1({5, -7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {3, 8}}); auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, 8}); auto b = builder.ConstantR1({5, -7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -1}, {8, 8}}); auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {8, 8}}); auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, true}, {true, false}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 4294967295}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({static_cast(0x12345678), static_cast(0xF0001000), 1, 3, 77, 1, -3, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); - auto out = builder.ShiftLeft(a, b); + builder.ShiftLeft(a, b); ComputeAndCompareR1(&builder, {static_cast(0x23456780), 0x00100000, 0x4, @@ -967,12 +967,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({static_cast(0x92345678), static_cast(0x10001000), 1, 3, 77, 1, -3, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); - auto out = builder.ShiftRightArithmetic(a, b); + builder.ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, @@ -982,45 +982,45 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({static_cast(0x92345678), static_cast(0x10001000), 1, 3, 77, 1, -3, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); - auto out = builder.ShiftRightLogical(a, b); + builder.ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); - auto out = builder.ShiftLeft(a, b); + builder.ShiftLeft(a, b); ComputeAndCompareR1( &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); - auto out = builder.ShiftRightArithmetic(a, b); + builder.ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); - auto out = builder.ShiftRightLogical(a, b); + builder.ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -1028,59 +1028,59 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -1088,10 +1088,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1099,17 +1099,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1120,16 +1120,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1138,7 +1138,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1149,7 +1149,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1158,10 +1158,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1169,10 +1169,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1181,10 +1181,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1193,10 +1193,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1206,10 +1206,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1218,10 +1218,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1230,10 +1230,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1242,10 +1242,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1253,10 +1253,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1264,10 +1264,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1276,10 +1276,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1287,10 +1287,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1299,12 +1299,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1312,20 +1312,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1599,14 +1599,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector values; values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } auto x = builder.ConstantR1(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1618,7 +1618,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 2, 2); std::vector values_vector; @@ -1632,140 +1632,86 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -// GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT -// such -// * fmin(NaN, x) = x -// * fmax(NaN, x) = x -// so we only test NAN on CPU. -// -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); -#endif - auto minimum = builder.Min(lhs, rhs); - - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {1.0f, -5.0f, 1.0f}, -#else - {1.0f, -5.0f, 1.0f, 10.0f, 6.0f}, -#endif - {}, error_spec_); + builder.Min(lhs, rhs); + + ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); -#endif - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {1.0, -5.0, 1.0}, -#else - {1.0, -5.0, 1.0, 10.0, 6.0}, -#endif - {}, error_spec_); + ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, + error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); -#endif - auto maximum = builder.Max(lhs, rhs); - - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {2.0f, 1.0f, 2.25f}, -#else - {2.0f, 1.0f, 2.25f, 10.0f, 6.0f}, -#endif - {}, error_spec_); + builder.Max(lhs, rhs); + + ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); -#endif - auto maximum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {2.0, 1.0, 2.25}, -#else - {2.0, 1.0, 2.25, 10.0, 6.0}, -#endif - {}, error_spec_); + ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1780,7 +1726,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1794,7 +1740,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Max(x, y); @@ -1805,7 +1751,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Min(x, y); @@ -1815,7 +1761,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( @@ -1828,7 +1774,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR1({}); builder.Max(u, v); @@ -1838,7 +1784,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); @@ -1848,7 +1794,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); @@ -1859,7 +1805,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({}); auto m = builder.ConstantR2({{}, {}}); builder.Max(v, m, /*broadcast_dimensions=*/{1}); @@ -1869,7 +1815,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1880,7 +1826,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d(2, 0, 3); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1891,7 +1837,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); @@ -1902,7 +1848,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{}, {}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); builder.Min(m, v, /*broadcast_dimensions=*/{0}); @@ -1912,7 +1858,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); auto array4d = builder.ConstantR4FromArray4D( @@ -1927,7 +1873,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); @@ -1939,7 +1885,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Min(x, y); @@ -1949,7 +1895,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Max(x, y); @@ -1959,110 +1905,107 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR0(0.0f); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto maximum = builder.ConstantR0(5.0f); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -2076,7 +2019,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -2084,7 +2027,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); @@ -2098,7 +2041,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -2106,7 +2049,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -2115,35 +2058,35 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - auto add = builder.Add(a, p); + builder.Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Cos(a); + builder.Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Sin(a); + builder.Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - auto atan = builder.Atan2(a, b); + builder.Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2152,9 +2095,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - auto result = builder.Tanh(a); + builder.Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2164,7 +2107,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, @@ -2203,7 +2146,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. @@ -2239,7 +2182,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, @@ -2279,14 +2222,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // / / // b -----/ / // c---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(a, b); - auto add2 = builder.Add(add, c); + builder.Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2297,14 +2240,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // / / // c -----/ / // a---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(b, c); - auto add2 = builder.Add(a, add); + builder.Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2314,14 +2257,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // a ----- (neg) ----- (add) // / // b ----- (neg) ----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto neg_a = builder.Neg(a); auto neg_b = builder.Neg(b); - auto result = builder.Add(neg_a, neg_b); + builder.Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2335,7 +2278,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // c ------ (add) ------------/ // / // d -----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); @@ -2344,19 +2287,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { auto add_ab = builder.Add(a, b); auto add_cd = builder.Add(c, d); - auto add_all = builder.Add(add_ab, add_cd); + builder.Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2365,11 +2308,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(scalar, a); + builder.Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2377,11 +2320,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(a, scalar); + builder.Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2390,14 +2333,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); // clang-format off auto m = builder.ConstantR2({ {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2423,10 +2366,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({42, 73}); auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + builder.Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2437,10 +2380,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + builder.Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2451,10 +2394,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + builder.Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2465,10 +2408,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1}); + builder.Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2479,10 +2422,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + builder.Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2494,24 +2437,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + builder.Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2519,14 +2462,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2537,13 +2480,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // effectively creates an "outer product" operation. // This is taken from the Numpy docs example at: // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2554,10 +2497,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2565,17 +2508,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0}); + builder.Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { // Binary add of two R3s together - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2583,7 +2526,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto add = builder.Add(a, b); + builder.Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2594,7 +2537,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2607,7 +2550,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2}); + builder.Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2618,7 +2561,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2631,7 +2574,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0}); + builder.Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2649,7 +2592,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2} // for broadcasting. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2664,7 +2607,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2681,7 +2624,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { // Comparison between two 3D arrays of compatible shapes: // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2689,7 +2632,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto compare = builder.Gt(a, b); + builder.Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2705,7 +2648,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { } XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> operand_b_4d(new Array4D(2, 3, 4, 5)); @@ -2726,13 +2669,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> expected_4d(new Array4D(2, 3, 4, 5)); @@ -2754,7 +2697,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR1(operand_b_1d); - auto add = builder.Add(a, b, {1}); + builder.Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2769,7 +2712,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::vector r1(d1); std::iota(r1.begin(), r1.end(), 1.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = builder.ConstantLiteral(*a_literal); @@ -2790,11 +2733,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { // Show that we can't add two opaques. XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); auto x = builder.Parameter(0, shape, "x"); - auto concatenated = builder.Add(x, x); - StatusOr computation_status = builder.Build(); + builder.Add(x, x); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( @@ -2802,12 +2745,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2815,14 +2758,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { } XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); - StatusOr computation_status = builder.Build(); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().error_message(), ::testing::ContainsRegex("must.*be the identity")); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 3f6fd7c65d3360a622dbf754833009fb20410535..ec3b46acfec0ee0ff514a862ce5b1ca74279efa8 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,11 +29,11 @@ namespace { class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { - ComputationBuilder builder(client_, "ax_10"); + XlaBuilder builder("ax_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Mul(alpha, x); + builder.Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -46,7 +47,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { auto x = builder.ConstantR1({}); auto y = builder.ConstantR1({}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -60,7 +61,7 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto y = builder.ConstantR1( {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 03f5e08315bfed2bcb43ebb7098aaa0b97228605..97095f1cc427789845051a8fea24c95475286fe2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -662,7 +662,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); } XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { @@ -675,7 +675,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("binary op BINOP_ADD with incompatible shapes")); + HasSubstr("op BINOP_ADD with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -688,7 +688,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("binary op BINOP_ADD with incompatible shapes")); + HasSubstr("op BINOP_ADD with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a677986cd926cc0054d8f36abc98ccac33dc043d..ec95a68ead055ae3ef301889806ef48982ed76f7 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -95,6 +95,20 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + ExecutionOptions execution_options = execution_options_; + if (shape_with_output_layout != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *shape_with_output_layout; + } + return client_->ExecuteAndTransfer(computation, arguments, + &execution_options); +} + +template <> StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments, @@ -104,6 +118,15 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } +template <> +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { @@ -116,14 +139,31 @@ std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); } +string ClientLibraryTestBase::ExecuteToString( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + auto computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status().ToString(); + } + auto computation = computation_status.ConsumeValueOrDie(); + + auto result = + client_->ExecuteAndTransfer(computation, arguments, &execution_options_); + if (!result.ok()) { + return result.status().ToString(); + } else { + return result.ValueOrDie()->ToString(); + } +} + string ClientLibraryTestBase::ExecuteToString( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { - StatusOr computation_status = builder->Build(); + auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); } - Computation computation = computation_status.ConsumeValueOrDie(); + auto computation = computation_status.ConsumeValueOrDie(); auto result = client_->ExecuteAndTransfer(computation, arguments, &execution_options_); @@ -142,16 +182,18 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, @@ -249,8 +291,28 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice /*arguments*/, + const std::function& /*verify_output*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice /*arguments*/, + const std::function& /*verify_output*/, + const Shape* /*output_with_layout*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -307,8 +369,9 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( return tensorflow::Status::OK(); } +template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -522,33 +585,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - ComputationDataHandle ClientLibraryTestBase::AddParam( const Literal& argument, ComputationBuilder* builder) { ComputationDataHandle data_handle; @@ -563,4 +599,24 @@ ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout); + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ba0319990bc04196386e6812b0a03671676698ec..5ff200be03ebd2aa76144644acc86f85037fff5a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -94,15 +95,25 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> Execute( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); + + // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once + // the migration to XlaBuilder is complete. + + template StatusOr> ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, + BuilderT* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( const Computation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); + // Convenience OrDie variants of above methods. std::unique_ptr ExecuteOrDie( ComputationBuilder* builder, @@ -113,29 +124,31 @@ class ClientLibraryTestBase : public ::testing::Test { // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. + string ExecuteToString(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); string ExecuteToString(ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder for details). For each rank, two forms are provided: one - // for floating point types with an ErrorSpec parameter, and one for integral - // types without the ErrorSpec parameter. - template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + // ComputationBuilder/XlaBuilder for details). For each rank, two forms are + // provided: one for floating point types with an ErrorSpec parameter, and one + // for integral types without the ErrorSpec parameter. + template + void ComputeAndCompareR0(BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + template + void ComputeAndCompareR0(BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR1(ComputationBuilder* builder, + template + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR1(ComputationBuilder* builder, + template + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); @@ -146,55 +159,53 @@ class ClientLibraryTestBase : public ::testing::Test { const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(ComputationBuilder* builder, - const Array2D& expected, + template + void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(ComputationBuilder* builder, - const Array2D& expected, + template + void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR3(ComputationBuilder* builder, - const Array3D& expected, + template + void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR3(ComputationBuilder* builder, - const Array3D& expected, + template + void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR4(ComputationBuilder* builder, - const Array4D& expected, + template + void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR4(ComputationBuilder* builder, - const Array4D& expected, + template + void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. + template void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); + template void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. + template tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); + template tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -266,17 +277,19 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. + template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + BuilderT* builder, HandleT* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. + template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters @@ -323,10 +336,12 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template - std::unique_ptr CreateR0Parameter( - NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + template + std::unique_ptr CreateR0Parameter(NativeT value, + int64 parameter_number, + const string& name, + BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -336,11 +351,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -351,11 +365,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -366,11 +379,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -399,6 +411,18 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); + tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output); + tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output, + const Shape* output_with_layout = nullptr); + // Executes the computation and calculates the expected reference value using // the HloEvaluator. Returns two literal in the order of (expected, actual). StatusOr, std::unique_ptr>> @@ -414,9 +438,9 @@ class ClientLibraryTestBase : public ::testing::Test { std::vector> arguments_; }; -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -424,9 +448,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -440,9 +464,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -450,9 +474,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -466,9 +490,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -476,9 +500,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -492,9 +516,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -502,9 +526,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -518,9 +542,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -528,9 +552,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -544,10 +568,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } -template +template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { + BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -558,11 +582,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -573,11 +596,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -588,11 +610,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -628,6 +649,37 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } +template +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + BuilderT* builder, + HandleT* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +template +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 1bcad5a3f37a37c9d482f3a5a899ac527666cca3..fb0e9c724a69b61801e6e0c2d07ef75b63a00465 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -75,7 +75,7 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), - HasSubstr("dimension to concatenate along out of bounds: 0")); + HasSubstr("out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index bc821674820fb128823786d7149037fc59b22ab6..b917dee77b5400db8f2c0a6a86258fee64723d71 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -571,5 +571,56 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { "only parameter of true_computation")); } +XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_}); + Computation swapper; + { + ComputationBuilder builder(client_, TestName() + ".swapper"); + auto param0 = builder.Parameter(0, tuple_shape, "sp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({y, x}); + swapper = builder.Build().ConsumeValueOrDie(); + } + Computation forwarder; + { + ComputationBuilder builder(client_, TestName() + ".forwarder"); + auto param0 = builder.Parameter(0, tuple_shape, "fp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({x, y}); + forwarder = builder.Build().ConsumeValueOrDie(); + } + Computation main; + { + ComputationBuilder builder(client_, TestName() + ".main"); + auto param0 = builder.Parameter(0, tuple_shape, "mp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + auto lt_pred = builder.Lt(x, y); + auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = builder.Ge(x, y); + builder.Conditional(ge_pred, res, swapper, res, forwarder); + main = builder.Build().ConsumeValueOrDie(); + } + + auto test_swap = [&](float a, float b) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(a); + auto y = builder.ConstantR0(b); + auto tuple_operand = builder.Tuple({x, y}); + builder.Call(main, {tuple_operand}); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(a).get(), + Literal::CreateR0(b).get()}), + {}, error_spec_); + }; + + test_swap(3.11f, 9.4f); + test_swap(11.24f, 5.55f); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 59d6d7a4153be1b76ed8195a12a90cb103baa422..9a899b79141fbc35fabd8d2e5d4195fb589dd84c 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -177,6 +178,24 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0.0f, 1.0f, 16777216.0f, + 16777218.0f, 2147483647.0f, 4294967040.0f}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, U32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { ComputationBuilder builder(client_, TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; @@ -366,5 +385,44 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); } + +XLA_TEST_F(ConvertTest, ConvertC64ToC64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{42.0f, 64.0f}}; + builder.ConvertElementType(builder.ConstantR1(x), C64); + ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConvertTest, ConvertS64S64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{-42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), S64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64U64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), U64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64S64) { + ComputationBuilder builder(client_, TestName()); + std::vector unsigned_x = {{42, UINT64_MAX}}; + builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + std::vector signed_x = {{42, -1}}; + ComputeAndCompareR1(&builder, signed_x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertS64U64) { + ComputationBuilder builder(client_, TestName()); + std::vector signed_x = {{42, -1, INT64_MIN}}; + builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + std::vector unsigned_x = { + {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; + ComputeAndCompareR1(&builder, unsigned_x, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index e2b5c91653fa6db5df86404c6c5f9158b0d484e1..72715398dea468d0000144759454c5f8d8673516 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -53,26 +53,12 @@ class ConvolutionTest : public ClientLibraryTestBase { #endif }; -#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) -using TestTypes = ::testing::Types; -#else +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; #endif -template -Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions); - -template <> -Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions) { - return ShapeUtil::MakeShape(F32, dimensions); -} - -template <> -Shape MakeShapeWrapper( - tensorflow::gtl::ArraySlice dimensions) { - return ShapeUtil::MakeShape(F16, dimensions); -} - template class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { public: @@ -121,8 +107,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { public: void RunTest() { ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 1, 2}); - Shape filter_shape = MakeShapeWrapper({1, 1, 1, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -152,8 +138,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { public: void RunTest() { ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -186,8 +172,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { public: void RunTest() { ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); @@ -222,8 +208,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { public: void RunTest() { ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 3, 3}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 3, 3}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); @@ -280,8 +266,8 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { void RunTest() { ComputationBuilder builder(client_, TestName()); { - Shape input_shape = MakeShapeWrapper({1, 2, 5}); - Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. @@ -381,8 +367,8 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { void RunTest() { ComputationBuilder builder(client_, TestName()); { - Shape input_shape = MakeShapeWrapper({1, 2, 5}); - Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. @@ -486,8 +472,8 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { ComputationBuilder builder(client_, TestName()); std::vector input_dims = {1, 3, 3, 5}; std::vector filter_dims = {3, 3, 5, 3}; - Shape input_shape = MakeShapeWrapper(input_dims); - Shape filter_shape = MakeShapeWrapper(filter_dims); + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); @@ -611,8 +597,8 @@ class Convolve1D1WindowTestBase input_feature}; std::vector filter_dims = {window_size, input_feature, output_feature}; - Shape input_shape = MakeShapeWrapper(input_dims); - Shape filter_shape = MakeShapeWrapper(filter_dims); + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); @@ -737,7 +723,7 @@ INSTANTIATE_TEST_CASE_P( ); #endif -TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { +XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 032c06cd3c9f872f57674d3d7b5adc201c91ea77..3ab0ea4ad48c00724d48e7d285ec024e10d5db31 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -195,7 +195,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("deconstructing nested tuples not yet supported")); + HasSubstr("Deconstructing nested tuples is not implemented")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 815962094ae476c4b15713ad2c1e4f1e0d140fd9..09b1dd283e4d026a2f0007240d88cd9ac38acb19 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -34,169 +34,194 @@ limitations under the License. namespace xla { namespace { -// TODO(b/34468543): use GUnit typed tests when we can do all tests on all -// backends. class DotOperationTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 1e-5}; - - protected: - template - void TestOneElementVectorDot(); - template - void TestVectorDot(); - template - void TestSquareMatrixDot(bool lhs_row_major = false, - bool rhs_row_major = false); - template - void TestNonsquareMatrixDot(bool lhs_row_major = false, - bool rhs_row_major = false); }; -XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); +#if defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = ::testing::Types; +#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = + ::testing::Types; +#else +#error "Situation not handled yet" +#endif + +template +class DotOperationTest_F16F32F64CF64 : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64); + +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 0.0, {}, error_spec_); + this->template ComputeAndCompareR0(&builder, static_cast(0.0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2({{3.0, 4.0}}); - auto rhs = builder.ConstantR1({3.0, 4.0}); - auto result = builder.Dot(lhs, rhs); +template +class DotOperationTest_F16F32F64 : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); - ComputeAndCompareR1(&builder, {25.0}, {}, error_spec_); -} - -template -void DotOperationTest::TestOneElementVectorDot() { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({2.0}); - auto rhs = builder.ConstantR1({3.0}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); + auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 6.0, {}, error_spec_); + this->template ComputeAndCompareR1(&builder, {static_cast(25.0f)}, {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) { - TestOneElementVectorDot(); -} +XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR1({static_cast(2.0f)}); + auto rhs = builder.ConstantR1({static_cast(3.0f)}); + auto result = builder.Dot(lhs, rhs); -XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) { - TestOneElementVectorDot(); + this->template ComputeAndCompareR0(&builder, static_cast(6.0f), {}, + this->error_spec_); } -template -void DotOperationTest::TestVectorDot() { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({1.0, 2.5, 42.0}); - auto rhs = builder.ConstantR1({11.0, -1.0, 0.5}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); + auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 29.5, {}, error_spec_); + this->template ComputeAndCompareR0(&builder, static_cast(29.5f), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot(); } - -XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot(); } - -namespace { - std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -} // namespace - -XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(0, 0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D( + {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(0, 3), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(0, 3), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) { - ComputationBuilder builder(client_, TestName()); - auto lhs = - builder.ConstantR2({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}}); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D( + {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(3, 0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(2, 2, 0.0f), {}, - error_spec_); + this->template ComputeAndCompareR2( + &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TEST_F(DotOperationTest, FusedDot) { - ComputationBuilder builder(client_, TestName()); - auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0"); - auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1"); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto param0 = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); + auto param1 = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); auto exp0 = builder.Exp(param0); auto result = builder.Dot(exp0, param1); - auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}})) - .ConsumeValueOrDie(); - auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2( - {{1.0}, {2.0}, {3.0}, {4.0}})) - .ConsumeValueOrDie(); - - ComputeAndCompareR2( - &builder, Array2D({{296.14560492846033}, {0.8611737683031964}}), - {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - -template -void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, - bool rhs_row_major) { auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 2.0}, {3.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 6.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2D( + {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); + auto rhs_handle = this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2D( + {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) + .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + if (std::is_same::value) { + this->error_spec_ = ErrorSpec{0.0001, 1e-3}; + } - Array2D expected({{15.0, -2.0}, {-25.0, 34.0}}); - ComputeAndCompareR2( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); + this->template ComputeAndCompareR2( + &builder, Array2D({{296.14560492846033f}, {0.8611737683031964f}}), + {lhs_handle.get(), rhs_handle.get()}, this->error_spec_); } +template +class SquareMatrixDot : public DotOperationTest { + public: + void TestImpl(bool lhs_row_major, bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 2.0f}, {3.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(lhs_row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 6.0f}, {7.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(rhs_row_major)))) + .ConsumeValueOrDie(); + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + + Array2D expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, error_spec_); + } +}; + +TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64); +XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); } +XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); } +XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); } +XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); } + struct DotTestParam { int m; int k; @@ -302,14 +327,13 @@ void ParametricDotTest::TestImpl() { if (param.has_addend) { args.push_back(addend_handle.get()); } - - ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); + ErrorSpec error_spec(0.3, 3e-3); + if (std::is_same::value) { + error_spec = ErrorSpec(0.3, 5e-3); + } + ComputeAndCompareR2(&builder, *expected, args, error_spec); } -XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl(); } - -XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl(); } - std::vector CreateDotTestParameters() { std::vector params; @@ -331,6 +355,12 @@ std::vector CreateDotTestParameters() { return params; } +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl(); } +#endif +XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl(); } + INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, ::testing::ValuesIn(CreateDotTestParameters()), PrintDotTestParam); @@ -343,14 +373,6 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { } }; -XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) { - TestImpl(); -} - -XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF64) { - TestImpl(); -} - std::vector CreateNoLayoutAssignmentDotTestParameters() { std::vector params; @@ -407,110 +429,60 @@ std::vector CreateNoLayoutAssignmentDotTestParameters() { return params; } -INSTANTIATE_TEST_CASE_P( - DotTests, ParametricDotTestWithoutLayoutAssignment, - ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()), - PrintDotTestParam); - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - TestSquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { - TestSquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { - TestSquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - TestSquareMatrixDot(true, true); +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) { + TestImpl(); } - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { - TestSquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { - TestSquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { - TestSquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { - TestSquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { - TestSquareMatrixDot(); -} - -template -void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, - bool rhs_row_major) { - auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) - .ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); - - Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); - - ComputeAndCompareR2( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - TestNonsquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - TestNonsquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - TestNonsquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - TestNonsquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { - TestNonsquareMatrixDot(); +#endif +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) { + TestImpl(); } - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { - TestNonsquareMatrixDot(false, false); +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF64) { + TestImpl(); } -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { - TestNonsquareMatrixDot(false, true); -} +INSTANTIATE_TEST_CASE_P( + DotTests, ParametricDotTestWithoutLayoutAssignment, + ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()), + PrintDotTestParam); -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { - TestNonsquareMatrixDot(true, false); -} +template +class NonsquareMatrixDot : public DotOperationTest { + public: + void TestImpl(bool lhs_row_major, bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(lhs_row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(rhs_row_major)))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + + Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, error_spec_); + } +}; -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { - TestNonsquareMatrixDot(true, true); -} +TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64); +XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = @@ -537,25 +509,35 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { - ComputationBuilder builder(client_, TestName()); - auto matrix1 = builder.ConstantR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix2 = builder.ConstantR2({{5.0, 6.0}, {7.0, 8.0}}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { + using T = TypeParam; + + ComputationBuilder builder(this->client_, this->TestName()); + auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); auto matrix12 = builder.Dot(matrix1, matrix2); auto matrix21 = builder.Dot(matrix2, matrix1); builder.Add(matrix12, matrix21); - Array2D expected({{42.0, 56.0}, {74.0, 96.0}}); - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + Array2D expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); + this->template ComputeAndCompareR2(&builder, expected, {}, + this->error_spec_); } +template +class DotOperationTestForBatchMatMul : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); + // Regression test for b/32055648. The root of the graph is a kFusion of 4 // bitcasts. Although bitcasts don't map to thunks, the root should still be // sync-dependent on bitcasts' operands. -XLA_TEST_F(DotOperationTest, BatchMatMul) { - ComputationBuilder builder(client_, TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y"); +XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto x = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); + auto y = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "y"); auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); @@ -576,29 +558,42 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { auto out_flat = builder.ConcatInDim(out_slices, 0); builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); - auto x_data = client_ - ->TransferToServer(*Literal::CreateR4( - {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, - {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) - .ConsumeValueOrDie(); - auto y_data = client_ - ->TransferToServer(*Literal::CreateR4( - {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, - {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) + auto x_data = this->client_ + ->TransferToServer(*Literal::CreateR4FromArray4D( + {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, + {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, + {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, + {{4000.0f, 400.0f}, {40.0f, 4.0f}}}})) .ConsumeValueOrDie(); + auto y_data = + this->client_ + ->TransferToServer(*Literal::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{11.0f, 22.0f}, {33.0f, 44.0f}}, + {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) + .ConsumeValueOrDie(); - ComputeAndCompareR4( + if (std::is_same::value) { + this->error_spec_ = ErrorSpec{0.0001, 1e-3}; + } + this->template ComputeAndCompareR4( &builder, /*expected=*/ - {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, - {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, - {x_data.get(), y_data.get()}, error_spec_); + {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}}, + {{11400.0f, 13600.0f}, {114.0f, 136.0f}}}, + {{{42900.0f, 79200.0f}, {429.0f, 792.0f}}, + {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TEST_F(DotOperationTest, GeneralMatMul) { - ComputationBuilder builder(client_, TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { + using T = TypeParam; + + ComputationBuilder builder(this->client_, this->TestName()); + auto x = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + auto y = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -608,31 +603,34 @@ XLA_TEST_F(DotOperationTest, GeneralMatMul) { auto out = builder.DotGeneral(x, y, dnums); - auto x_data = client_ - ->TransferToServer(*Literal::CreateR3( - {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) - .ConsumeValueOrDie(); + auto x_data = + this->client_ + ->TransferToServer(*Literal::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); - auto y_data = client_ - ->TransferToServer(*Literal::CreateR3( - {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) - .ConsumeValueOrDie(); + auto y_data = + this->client_ + ->TransferToServer(*Literal::CreateR3FromArray3D( + {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) + .ConsumeValueOrDie(); - ComputeAndCompareR3( + this->template ComputeAndCompareR3( &builder, /*expected=*/ - {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, - {x_data.get(), y_data.get()}, error_spec_); + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {x_data.get(), y_data.get()}, this->error_spec_); } -TEST_F(DotOperationTest, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { + using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { for (bool row_major : {false, true}) { - std::unique_ptr> lhs( - new Array2D({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}})); - std::unique_ptr> rhs( - new Array2D({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}})); + std::unique_ptr> lhs( + new Array2D({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}})); + std::unique_ptr> rhs( + new Array2D({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}})); if (transpose_lhs) { lhs = ReferenceUtil::TransposeArray2D(*lhs); @@ -641,22 +639,20 @@ TEST_F(DotOperationTest, TransposeFolding) { rhs = ReferenceUtil::TransposeArray2D(*rhs); } auto lhs_handle = - client_ - ->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = - client_ - ->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); + ComputationBuilder builder(this->client_, this->TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); auto lhs_arg = builder.Parameter( 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), "lhs"); @@ -671,24 +667,27 @@ TEST_F(DotOperationTest, TransposeFolding) { } auto result = builder.Dot(lhs_arg, rhs_arg); - Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); + Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " << transpose_rhs << " " << row_major; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - error_spec_); + this->template ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, + this->error_spec_); } } } } -TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { - auto prim_type = primitive_util::NativeToPrimitiveType(); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, + DotOfConcatOptimizationWithConstLHS) { + using T = TypeParam; + auto prim_type = primitive_util::NativeToPrimitiveType(); - std::unique_ptr> constant_lhs_array(new Array2D( - {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, + {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); - ComputationBuilder builder(client_, TestName()); + ComputationBuilder builder(this->client_, this->TestName()); auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); @@ -699,78 +698,80 @@ TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { auto result = builder.Dot( lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); - std::unique_ptr> arg_0_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}})); - std::unique_ptr> arg_1_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); - std::unique_ptr> arg_2_value_array( - new Array2D({{1.0, 2.0}})); + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}})); + std::unique_ptr> arg_2_value_array(new Array2D({{1.0f, 2.0f}})); TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); - Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); - ComputeAndCompareR2( + Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); + this->template ComputeAndCompareR2( &builder, expected, - {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); -} - -TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { - auto prim_type = primitive_util::NativeToPrimitiveType(); - - std::unique_ptr> constant_rhs_array( - new Array2D({{1.0, 2.0}, - {3.0, 4.0}, - {5.0, 6.0}, - {6.0, 5.0}, - {4.0, 3.0}, - {2.0, 1.0}})); - - ComputationBuilder builder(client_, TestName()); + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, + this->error_spec_); +} + +XLA_TYPED_TEST(DotOperationTest_F16F32F64, + DotOfConcatOptimizationWithConstRHS) { + using T = TypeParam; + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}, + {6.0f, 5.0f}, + {4.0f, 3.0f}, + {2.0f, 1.0f}})); + + ComputationBuilder builder(this->client_, this->TestName()); auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); - auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 3}), "lhs_arg_1"); - auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType({2, 1}), "lhs_arg_2"); auto result = builder.Dot( builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); - std::unique_ptr> arg_0_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}})); - std::unique_ptr> arg_1_value_array( - new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); - std::unique_ptr> arg_2_value_array( - new Array2D({{1.0}, {2.0}})); + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0f}, {2.0f}})); TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); - Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); - ComputeAndCompareR2( + Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); + this->template ComputeAndCompareR2( &builder, expected, - {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, + this->error_spec_); } + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 877dc7db0eec229a7119b3627f177a33ed0d971b..4f354e6aefe70a51c09be1c0ca151af2bb9f0a2c 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -206,19 +206,19 @@ XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } -XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -506,7 +506,7 @@ XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) { diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 6fe7737de7af349dca2931b52d62dbc03b14e0b3..b28fe0c15a89a1331698a29f70b966380bd3fcb9 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -71,8 +71,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #ifdef XLA_TEST_BACKEND_CPU // TODO(b/73141998): The vectorized Log implementation gives results outside // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64): - std::pair known_incorrect_range = {1, 8315654}; + // floats expressed as a zero extended int64). + std::pair known_incorrect_range = {1, 8388608}; #else std::pair known_incorrect_range = {0, 0}; #endif diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9db68ff7a6dcbd9204fb2b3a37734a9aaed35dfd --- /dev/null +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -0,0 +1,461 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +// NB! TODO(b/74360564): These tests do not test out of bounds behavior since +// that hasn't been specced yet. + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class GatherOperationTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* gather_indices) { + RunTest(hlo_text, {operand, gather_indices}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherV1) { + const string hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { + const string hlo_text = R"( +HloModule TensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { + const string hlo_text = R"( +HloModule TensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + ROOT gather = s32[2,1,1,2] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { + const string hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, DynamicSlice) { + const char* hlo_text = R"( +HloModule DynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,0] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 0} +} +)"; + std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, + // we should get rid of the mask and check that backends produce the same + // value for OOB indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndex) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for negative + // accesses, we should get rid of the mask and check that backends produce the + // same value for negative indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), + output_window_dims={0,1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1,3,2} +} +)"; + std::unique_ptr operand = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ScalarResult) { + const char* hlo_text = R"( +HloModule ScalarResult + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[] gather(operand, index), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1} +} +)"; + std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { + const string hlo_text = R"( +HloModule ZeroSizedResult + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[0] parameter(1) + ROOT gather = s32[0,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +class GatherClientLibraryTest : public ClientLibraryTestBase {}; + +// TODO(b/30671675): Asynchronous execution on stream is not yet supported on +// GPU and CPU_PARALLEL. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) { + // We create this HLO, but using the ComputationBuilder API. + // + // ENTRY main { + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // ROOT gather = s32[2,3] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0}, + // gather_dims_to_operand_dims={0}, + // index_vector_dim=1, + // window_bounds={1, 3} + // } + + ComputationBuilder builder(client_, "gather_basic"); + + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = builder.Parameter(0, operand_shape, "operand"); + auto indices = builder.Parameter(1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(1); + dim_numbers.add_elided_window_dims(0); + dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.set_index_vector_dim(1); + builder.Gather(operand, indices, dim_numbers, {1, 3}); + + std::vector expected = {}; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, + client_->TransferToServer(*Literal::CreateR2( + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr indices_arg, + client_->TransferToServer(*Literal::CreateR1({0, 2}))); + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + *execution_options.add_device_handles() = devices[0]; + TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build()); + std::vector computation_instances = { + {computation, + {operand_arg.get(), indices_arg.get()}, + execution_options, + /*execution_profile=*/nullptr}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector> result_data, + client_->ExecuteParallel(computation_instances)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + client_->Transfer(*(result_data[0]))); + LiteralTestUtil::ExpectEqual( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index eded2077fce965ab1c729c610764afa2228ca128..cf971dd61b71ad329b20b0bb7c16166126562681 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" @@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase { metadata_.set_op_name("my_sum_op"); } - void BuildAddComputation(ComputationBuilder* builder) { + void BuildAddComputation(XlaBuilder* builder) { auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder->Add(x, y); @@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase { }; TEST_F(HloMetadataTest, MetadataPropagation) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); BuildAddComputation(&builder); builder.ClearOpMetadata(); @@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { } TEST_F(HloMetadataTest, MetadataClearing) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); // Some other pretend computation here. builder.ClearOpMetadata(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 6723c99edb945492abfbac159bed1959d551ec57..e574644dea7c1ba144ba87fbeb7f28cc52312e26 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -115,6 +115,13 @@ StatusOr> HloTestBase::Execute( return test_runner_.Execute(std::move(module), arguments); } +StatusOr> HloTestBase::ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments, + /*run_hlo_passes=*/false); +} + std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { @@ -140,15 +147,10 @@ StatusOr> HloTestBase::MakeReferenceModule( return std::move(reference_module); } -template StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { - static_assert( - std::is_same::value || - std::is_same, LiteralPtr>::value, - "The LiteralPtr type only accepts Literal* or std::unique_ptr."); TF_RETURN_IF_ERROR( VerifyHloModule(*test_runner_.backend().platform(), module.get())); TF_ASSIGN_OR_RETURN(auto reference_module, @@ -165,9 +167,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( error); } -template ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -179,9 +180,8 @@ template return result.ValueOrDie(); } -template ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -198,8 +198,14 @@ template const std::function& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompare>( - std::move(module), fake_arguments, error, reference_preprocessor); + + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + + return RunAndCompare(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( @@ -207,8 +213,13 @@ template const std::function& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompareNoHloPasses>( - std::move(module), fake_arguments, error, reference_preprocessor); + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompare( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 413bb213fdcb1303f396308d13d9d0b96b47b71f..3e8e2360bb3a87e127920cd222803c0f7b9161f4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -44,7 +44,7 @@ namespace xla { // enables, for one, explicitly building a graph of HLO instructions to run. // // This can also be used to write text/file-based test cases. Note that the test -// target is responsible for linking the needed backends. A covenient way to do +// target is responsible for linking the needed backends. A convenient way to do // this is to make it an xla_test: it will generate test targets linking with // the respective backends, which will be used as the test backend; the // interpreter backend is already linked with hlo_test_base so it will be the @@ -98,14 +98,19 @@ class HloTestBase : public ::testing::Test { std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); + // Same as above, except the module will be executed without running any HLO + // passes on it. + StatusOr> ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + std::unique_ptr ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); // Executes the given hlo module on two backends and compares results. // - // 'arguments': the input of the hlo module. The LiteralPtr type accepts - // Literal* or std::unique_ptr. + // 'arguments': the input of the hlo module. // // 'error': if has value, expects the results to be near (within the error // bound). Otherwise, expects the results to be equal. @@ -114,20 +119,18 @@ class HloTestBase : public ::testing::Test { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - template ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Same as above, except that the module will be executed without Hlo // optimization. - template ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -232,10 +235,9 @@ class HloTestBase : public ::testing::Test { // Runs the module on two platforms with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - template StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 506091ddd8d1d8e6519525bb7031f4e8b296b5fb..da4cf4ae0c31bc194cd2ec9b845df36afbde69b0 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -40,18 +41,22 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; - } + VerifyModule(); } HloTestBase::TearDown(); } +void HloVerifiedTestBase::VerifyModule() { + HloVerifier verifier; + xla::StatusOr mutated = verifier.Run(module_.get()); + if (!mutated.ok()) { + ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); + } else { + EXPECT_FALSE(mutated.ValueOrDie()) + << "HloVerifier should never mutate the HloModule"; + } +} + HloModule& HloVerifiedTestBase::module() { if (!module_) { module_ = CreateNewModule(); @@ -59,4 +64,10 @@ HloModule& HloVerifiedTestBase::module() { return *module_; } +void HloVerifiedTestBase::ParseAndVerifyModule( + tensorflow::StringPiece hlo_text) { + CHECK(!module_) << "Called ParseModule when test already has a module."; + TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); + VerifyModule(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 492688bf7d682cf991cb8c09399492a0437f651b..e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,6 +44,7 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. @@ -55,6 +56,7 @@ class HloVerifiedTestBase : public HloTestBase { std::unique_ptr module_; // Lazily populated. Access via module(). std::unique_ptr shape_verifier_; bool tear_down_called_ = false; + void VerifyModule(); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 99514baf23cafe61adc28a30dfdfe2691ab82d32..3023df47cda33f5d11abc921fd0355d48f761107 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -49,11 +50,11 @@ void LLVMIRGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - ASSERT_TRUE(CompileToExecutable(std::move(hlo_module)).ok()); + TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status()); ResetIrHook(); StatusOr filecheck_result = RunFileCheck(ir_, pattern); - ASSERT_TRUE(filecheck_result.ok()); + TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.ValueOrDie()); } diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2b0f7e6e80c48435ca55432a2afa3b6d69162625..0cd812fd1b4bc69c34b70d3ca0fd0aa6cf57fa4c 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -531,7 +531,7 @@ TEST_F(MapTest, MapOperantionWithBuildError) { ASSERT_TRUE(!computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with " + ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " "different element types: f32[] and u16[]")); } diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 6c86dd5b9ef673c9facffafa37e00a859ce82010..c42f71388baba73e08a361d817e41b03e03bf133 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -29,6 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -38,258 +40,223 @@ limitations under the License. namespace xla { namespace { -class MatOpsSimpleTest : public ClientLibraryTestBase { - protected: - Computation BuildSum() { - // sum(x, y) = x + y - ComputationBuilder builder(client_, "sum"); - auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto y_value = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y_value"); - builder.Add(x_value, y_value); - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return computation_status.ConsumeValueOrDie(); - } - - void TestLinspaceMax(int64 rows, int64 cols) { - float from = -128.0, to = 256.0; - std::unique_ptr> alhs = - MakeLinspaceArray2D(from, to, rows, cols); - auto arhs = MakeUnique>(rows, cols, 1.0); - - ComputationBuilder builder( - client_, - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - auto max = builder.Max(lhs, rhs); - - Array2D aexpected(rows, cols); - for (int row = 0; row < rows; ++row) { - for (int col = 0; col < cols; ++col) { - aexpected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); - } - } - - ComputeAndCompareR2(&builder, aexpected, {}, ErrorSpec(1e-6)); - } -}; - -TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { - ComputationBuilder builder(client_, "exp_2x2"); - auto data = builder.ConstantR2({ - {1.0, 0.0}, // row 0 - {-1.0, 0.5}, // row 1 +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TypesF16F32 = ::testing::Types; +#else +using TypesF16F32 = ::testing::Types; +#endif + +class MatOpsSimpleTest : public ClientLibraryTestBase {}; + +template +class MatOpsSimpleTest_F16F32 : public MatOpsSimpleTest {}; + +// TODO(bixia): This test for F16 failed on GPU 02-25-2018. +#ifdef XLA_TEST_BACKEND_GPU +TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, ::testing::Types); +#else +TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); +#endif + +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { + using T = TypeParam; + ComputationBuilder builder(this->client_, "exp_2x2"); + auto data = builder.ConstantR2FromArray2D({ + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 }); builder.Exp(data); std::unique_ptr expected = - Literal::CreateR2({{2.71828, 1.00000}, // row 0 - {0.36788, 1.64872}}); // row 1 + Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 + {0.36788f, 1.64872f}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->template ComputeAndCompareLiteral(&builder, *expected, {}, + ErrorSpec(1e-5)); } -TEST_F(MatOpsSimpleTest, MapTwoByTwo) { +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { + using T = TypeParam; Computation add_half; { // add_half(x) = x + 0.5 - ComputationBuilder builder(client_, "add_half"); + ComputationBuilder builder(this->client_, "add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto half = builder.ConstantR0(0.5); + builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); + auto half = builder.ConstantR0(static_cast(0.5)); builder.Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); } - ComputationBuilder builder(client_, "map_2x2"); - auto data = builder.ConstantR2({ - {1.0, 0.0}, // row 0 - {-1.0, 0.5}, // row 1 + ComputationBuilder builder(this->client_, "map_2x2"); + auto data = builder.ConstantR2FromArray2D({ + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 }); auto map = builder.Map({data}, add_half, {0, 1}); std::unique_ptr expected = - Literal::CreateR2({{1.5, 0.5}, // row 0 - {-0.5, 1.0}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 + {-0.5f, 1.0f}}); // row 1 + this->template ComputeAndCompareLiteral(&builder, *expected, {}, + ErrorSpec(1e-5)); } -TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { - ComputationBuilder builder(client_, "max_2x2"); - auto lhs = builder.ConstantR2({ - {7.0, 2.0}, // row 0 - {3.0, -4.0}, // row 1 +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { + using T = TypeParam; + ComputationBuilder builder(this->client_, "max_2x2"); + auto lhs = builder.ConstantR2FromArray2D({ + {7.0f, 2.0f}, // row 0 + {3.0f, -4.0f}, // row 1 }); - auto rhs = builder.ConstantR2({ - {5.0, 6.0}, // row 0 - {1.0, -8.0}, // row 1 + auto rhs = builder.ConstantR2FromArray2D({ + {5.0f, 6.0f}, // row 0 + {1.0f, -8.0f}, // row 1 }); auto max = builder.Max(lhs, rhs); std::unique_ptr expected = - Literal::CreateR2({{7.0, 6.0}, // row 0 - {3.0, -4.0}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 + {3.0f, -4.0f}}); // row 1 + this->template ComputeAndCompareLiteral(&builder, *expected, {}, + ErrorSpec(1e-6)); } -TEST_F(MatOpsSimpleTest, Max1x1Linspace) { TestLinspaceMax(1, 1); } - -TEST_F(MatOpsSimpleTest, Max2x2Linspace) { TestLinspaceMax(2, 2); } - -TEST_F(MatOpsSimpleTest, Max3x3Linspace) { TestLinspaceMax(3, 3); } - -TEST_F(MatOpsSimpleTest, Max4x4Linspace) { TestLinspaceMax(4, 4); } - -TEST_F(MatOpsSimpleTest, Max6x6Linspace) { TestLinspaceMax(6, 6); } - -TEST_F(MatOpsSimpleTest, Max8x8Linspace) { TestLinspaceMax(8, 8); } - -TEST_F(MatOpsSimpleTest, Max12x12Linspace) { TestLinspaceMax(12, 12); } - -TEST_F(MatOpsSimpleTest, Max16x16Linspace) { TestLinspaceMax(16, 16); } +struct TestLinspaceMaxParam { + int64 rows; + int64 cols; +}; -TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); } +class TestLinspaceMaxParametric + : public MatOpsSimpleTest, + public ::testing::WithParamInterface { + public: + template + void TestImpl() { + TestLinspaceMaxParam param = GetParam(); + int64 rows = param.rows; + int64 cols = param.cols; + float from = -128.0, to = 256.0; + std::unique_ptr> alhs = + MakeLinspaceArray2D(from, to, rows, cols); + auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); -TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); } + ComputationBuilder builder( + client_, + tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + auto max = builder.Max(lhs, rhs); -class MatOpsDotAddTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface> {}; - -TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { - bool row_major = std::get<0>(GetParam()); - bool add_lhs = std::get<1>(GetParam()); - bool transpose = std::get<2>(GetParam()); - Array2D lhs({{1.0, 2.0}, {3.0, 4.0}}); - Array2D rhs({{10.0, 11.0}, {12.0, 13.0}}); - - auto minor_to_major = [](bool row_major) -> std::vector { - return {row_major ? 1 : 0, row_major ? 0 : 1}; - }; - - auto prim_type = primitive_util::NativeToPrimitiveType(); - Shape lhs_shape = - ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); - Shape rhs_shape = - ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - - TF_ASSERT_OK_AND_ASSIGN( - auto lhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSERT_OK_AND_ASSIGN( - auto rhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - - ComputationBuilder builder(client_, TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); - auto lhs_mat_arg = lhs_arg; - if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); - } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); - Array2D expected; - if (add_lhs) { - result = builder.Add(result, lhs_arg); - if (transpose) { - expected = Array2D({{47, 52}, {71, 78}}); - } else { - expected = Array2D({{35, 39}, {81, 89}}); + Array2D expected(rows, cols); + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + expected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); + } } - } else { - result = builder.Add(result, rhs_arg); - if (transpose) { - expected = Array2D({{56, 61}, {80, 87}}); - } else { - expected = Array2D({{44, 48}, {90, 98}}); + ErrorSpec error_spec(1e-6); + if (std::is_same::value) { + error_spec = ErrorSpec(1e-6, 2e-4); } + ComputeAndCompareR2(&builder, expected, {}, error_spec); } +}; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(1e-6)); +string PrintTestLinspaceMaxParam( + const ::testing::TestParamInfo& test_param) { + const TestLinspaceMaxParam& param = test_param.param; + return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); } -INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, - ::testing::Combine(::testing::Bool(), ::testing::Bool(), - ::testing::Bool())); +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +// TODO(bixia): This test failed on GPU 02-25-2018 +#ifdef XLA_TEST_BACKEND_CPU +XLA_TEST_P(TestLinspaceMaxParametric, TestF16) { TestImpl(); } +#endif +#endif +XLA_TEST_P(TestLinspaceMaxParametric, TestF32) { TestImpl(); } + +INSTANTIATE_TEST_CASE_P( + TestLinspaceMax, TestLinspaceMaxParametric, + ::testing::Values(TestLinspaceMaxParam{1, 1}, TestLinspaceMaxParam{2, 2}, + TestLinspaceMaxParam{3, 3}, TestLinspaceMaxParam{4, 4}, + TestLinspaceMaxParam{6, 6}, TestLinspaceMaxParam{8, 8}, + TestLinspaceMaxParam{12, 12}, + TestLinspaceMaxParam{16, 16}, TestLinspaceMaxParam{32, 8}, + TestLinspaceMaxParam{64, 8}), + PrintTestLinspaceMaxParam); -class MatOpsDotAddTest_bf16 +class MatOpsDotAddTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface> {}; - -TEST_P(MatOpsDotAddTest_bf16, Dot_Add_2x2_2x2) { - bool row_major = std::get<0>(GetParam()); - bool add_lhs = std::get<1>(GetParam()); - bool transpose = std::get<2>(GetParam()); - Array2D lhs( - {{bfloat16(1.0f), bfloat16(2.0f)}, {bfloat16(3.0), bfloat16(4.0)}}); - Array2D rhs( - {{bfloat16(10.0f), bfloat16(11.0f)}, {bfloat16(12.0f), bfloat16(13.0f)}}); - - auto minor_to_major = [](bool row_major) -> std::vector { - return {row_major ? 1 : 0, row_major ? 0 : 1}; - }; - - auto prim_type = primitive_util::NativeToPrimitiveType(); - Shape lhs_shape = - ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); - Shape rhs_shape = - ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - - TF_ASSERT_OK_AND_ASSIGN( - auto lhs_handle, - client_->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSERT_OK_AND_ASSIGN( - auto rhs_handle, - client_->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - - ComputationBuilder builder(client_, TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); - auto lhs_mat_arg = lhs_arg; - if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); - } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); - Array2D expected; - if (add_lhs) { - result = builder.Add(result, lhs_arg); + public ::testing::WithParamInterface> { + public: + template + void TestImpl() { + bool row_major = std::get<0>(GetParam()); + bool add_lhs = std::get<1>(GetParam()); + bool transpose = std::get<2>(GetParam()); + Array2D lhs({{1.0f, 2.0f}, {3.0f, 4.0f}}); + Array2D rhs({{10.0f, 11.0f}, {12.0f, 13.0f}}); + + auto minor_to_major = [](bool row_major) -> std::vector { + return {row_major ? 1 : 0, row_major ? 0 : 1}; + }; + + auto prim_type = primitive_util::NativeToPrimitiveType(); + Shape lhs_shape = + ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); + Shape rhs_shape = + ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); + + TF_ASSERT_OK_AND_ASSIGN( + auto lhs_handle, + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSERT_OK_AND_ASSIGN( + auto rhs_handle, + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + + ComputationBuilder builder(client_, TestName()); + auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_mat_arg = lhs_arg; if (transpose) { - expected = Array2D( - {{bfloat16(47), bfloat16(52)}, {bfloat16(71), bfloat16(78)}}); - } else { - expected = Array2D( - {{bfloat16(35), bfloat16(39)}, {bfloat16(81), bfloat16(89)}}); + lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); } - } else { - result = builder.Add(result, rhs_arg); - if (transpose) { - expected = Array2D( - {{bfloat16(56), bfloat16(61)}, {bfloat16(80), bfloat16(87)}}); + auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); + auto result = builder.Dot(lhs_mat_arg, rhs_arg); + Array2D expected; + if (add_lhs) { + result = builder.Add(result, lhs_arg); + if (transpose) { + expected = Array2D({{47.0f, 52.0f}, {71.0f, 78.0f}}); + } else { + expected = Array2D({{35.0f, 39.0f}, {81.0f, 89.0f}}); + } } else { - expected = Array2D( - {{bfloat16(44), bfloat16(48)}, {bfloat16(90), bfloat16(98)}}); + result = builder.Add(result, rhs_arg); + if (transpose) { + expected = Array2D({{56.0f, 61.0f}, {80.0f, 87.0f}}); + } else { + expected = Array2D({{44.0f, 48.0f}, {90.0f, 98.0f}}); + } } + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + ErrorSpec(1e-6)); } +}; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(1e-6)); -} +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2BF16) { TestImpl(); } +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2F16) { TestImpl(); } +#endif +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2F32) { TestImpl(); } -INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest_bf16, +INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool())); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 50d7b5074d201d2292cf90224ef4cd37efdbb8d3..3a097a01ab095b8a21a39f0d738a43c3d6a4d1d7 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -884,5 +884,47 @@ XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) { RunR2ToR1PredTest(/*and_reduce=false*/ false, /*rows=64*/ 64); } +// Tests reductions with different initial values. There's no test macro that +// combines TYPED_TEST and TYPED_P, so we have to do it manually. +class ReduceInitializerTest : public ReduceTest { + protected: + template + void DoTest(T initializer, int num_elems) { + ComputationBuilder builder(client_, TestName()); + Computation max_fn = CreateScalarMaxComputation( + primitive_util::NativeToPrimitiveType(), &builder); + + auto init = builder.ConstantR0(initializer); + std::vector input_arr(num_elems, std::numeric_limits::lowest()); + auto input_literal = Literal::CreateR1(input_arr); + auto input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, + max_fn, {0}); + + ComputeAndCompareR0(&builder, initializer, {input_data.get()}); + } +}; + +XLA_TEST_F(ReduceInitializerTest, U8Small) { DoTest(42, 2); } + +XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) { DoTest(42, 4096); } + +XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) { + DoTest(42, 4095); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) { + DoTest(0, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) { + DoTest(1, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { + DoTest(1234556789123, 1024); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index b11b64e40a582150d6adf29e915cd70b4bcb982b..9c317fe579394c5b7a1d599169f471d484950199 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -960,45 +960,76 @@ struct R2ReduceWindowTestData { int64 base_bounds[2]; int64 window_bounds[2]; int64 strides[2]; + int64 pad_low[2]; + int64 pad_high[2]; int64 layout[2]; - Padding padding; Reducer reducer; } kR2TestCases[] = { {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 2}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, - /*strides=*/{2, 99}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, +// TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a +// ptxas bug. +#ifndef XLA_TEST_BACKEND_GPU {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, - /*strides=*/{5, 4}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, +#endif {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, - /*strides=*/{3, 3}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, - /*strides=*/{4, 5}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, - /*strides=*/{2, 2}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + // Regression test for b/73903312: bf16 lacks precision to store result of + // very large windows. Testing with a reasonable window larger than 128. + {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, +// TODO(b/76025683): These tests fail on TPU. +#if defined(XLA_TEST_BACKEND_CPU) || defined(XLA_TEST_BACKEND_GPU) + {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, +#endif }; string R2ReduceWindowTestDataToString( @@ -1008,10 +1039,11 @@ string R2ReduceWindowTestDataToString( string str = tensorflow::strings::StrCat( "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__padding_", param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", param.layout[0], "_", param.layout[1], // + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__layout_", param.layout[0], "_", param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { str = tensorflow::strings::StrCat(str, "_bfloat16"); @@ -1039,17 +1071,29 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ComputationDataHandle parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); + std::vector> padding(2); + for (int i = 0; i < 2; ++i) { + padding[i] = {param.pad_low[i], param.pad_high[i]}; + } + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + auto reduce_func = param.reducer == kAdd + ? +[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + auto expected = ReferenceUtil::ReduceWindow2DGeneric( + /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/padding); ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); @@ -1074,8 +1118,9 @@ XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -1315,5 +1360,41 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } +TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[1,1]{1,0} parameter(0) + negate = f32[1,1]{1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { + const string& hlo_string = R"( +HloModule R3Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R3Window { + operand = f32[1,1,1]{2,1,0} parameter(0) + negate = f32[1,1,1]{2,1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index f7b04debd4f5c40a904e32c832b6fc384a03c33b..02272d60171c70896f44b0d6b96f176ea52e686f 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -207,9 +208,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -221,10 +222,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { // Splits a vector into a matrix. XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -241,9 +242,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // // Transposes a 2x0 array to a 0x2 array. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -255,10 +256,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { // Transposes a 2-dimensional row vector to a column vector. XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = Literal::CreateFromArray(*simple); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -272,10 +273,10 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { // Transposes a 2-dimensional array. XLA_TEST_P(ReshapeTest, TransposeAsReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -291,11 +292,11 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. // -// Transposes a 0x4 array with ComputationBuilder::Trans. +// Transposes a 0x4 array with XlaBuilder::Transpose. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -306,10 +307,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { // Transposes a 2-dimensional array with ComputationBuilder::Trans. XLA_TEST_P(ReshapeTest, Transpose4x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -327,9 +328,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index d7bda77e87f33938162f94dbee42b160906b4087..0c88bef69dfc522fef52422b0bd3a825fa173d44 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -860,6 +860,12 @@ XLA_TEST_F(ScalarComputationsTest, MinF32Below) { TestMinMax(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); } +XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { + SetFastMathDisabled(true); + TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Min); + TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Min); +} + XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { TestMinMax(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); } @@ -868,6 +874,12 @@ XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { TestMinMax(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); } +XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { + SetFastMathDisabled(true); + TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Max); + TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Max); +} + XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. ComputationBuilder b(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 9ee94b8571e5fc8789b60501462986967ce909a0..d268fdcacebcb162bf61bc7dd4b208f4db6c4a5f 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -252,6 +252,21 @@ XLA_TEST_F(SelectAndScatterTest, R2S32) { ComputeAndCompareR2(&builder_, expected, {}); } +// Test for tie breaking rule in ge_f32_. When a tie is present, the operand +// that has the lower lexicographical order (smaller index) should be chosen. +XLA_TEST_F(SelectAndScatterTest, R2F32Tie) { + const auto operand = builder_.ConstantR2( + {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); + const auto source = builder_.ConstantR2( + {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + Array2D expected( + {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); +} + // Similar to SelectAndScatterTest.R2S32 but the input is transposed. XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { const auto operand = builder_.ConstantR2( diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index fe36df160daacc4fdfbdb0b75f8304f91e1a4245..a14a365bd0529ba82a25cdfacfe3902a655c4876 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -193,7 +193,9 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - std::vector input(spec.input_dim0); + // This can't be an std::vector, since you can't grab an ArraySlice of a + // vector. + tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); ComputationBuilder builder(client_, TestName()); @@ -201,7 +203,8 @@ class SliceR1Test : public ClientLibraryTestBase, builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); - std::vector expected; + // Ditto. + tensorflow::gtl::InlinedVector expected; for (int i = spec.slice_start; i < spec.slice_limit; i += spec.slice_stride) { expected.push_back(i); @@ -230,6 +233,8 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } +XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } + // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 978a669bcab720bddec5c4bcd0144810ba3c8477..be35ec6c6ee4c015755622b2dc9bb92e23af7c85 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 2029312f94a14bc81706368b9ecfc2727fd9fe4c..fa60af4b6a7d4f249b28be14357b8cad9a42c783 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -514,5 +515,33 @@ XLA_TEST_F(TupleTest, ComplexTuples) { error_spec_); } +class TupleHloTest : public HloTestBase {}; + +// Disabled on CPU parallel because that's broken and will be removed soon. +// Disabled on the interpreter because bitcast doesn't exist on the interpreter. +TEST_F(TupleHloTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_CPU_PARALLEL(BitcastAfterGTE))) { + const char* testcase = R"( + HloModule m + + ENTRY test { + name.1 = (f32[3]{0}) parameter(0) + get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0 + bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1) + copy = f32[1,3]{1,0} copy(bitcast) + ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + TF_ASSERT_OK_AND_ASSIGN(auto result, + ExecuteNoHloPasses(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 52157b837c383205f77a030ef98b2fd03a41aff5..33d457c70bac84c2da10e3cf9302c2c952cf1bc2 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -910,7 +910,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -TEST_F(WhileTest, WhileWithPrngScalarResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. @@ -1166,7 +1166,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithCallInsideCondition) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 9ad2a1985331b80625dd0687ea052300bc99e440..24b9f37a8008b6f774634f2dbff9d3296ec0585b 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -144,7 +144,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions())); + ExecutableBuildOptions().set_hlo_profile(true))); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 091fa0c3ec807a66449eca0bfbb141285b8eb532..2e55f609d17bf42e410f97c51c7b9c6c0e85576d 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -75,6 +75,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index cd2b843ad36013ae83818ecbc184fb823093f037..e60a5a4919f2207939821e787c3c59a08ff3ba4e 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1049,9 +1049,40 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } - case HloOpcode::kGather: - // TODO(b/72710576): HLO parsing is not implemented for Gather. - return TokenError("HLO parsing is not implemented for Gather"); + case HloOpcode::kGather: { + optional> output_window_dims; + attrs["output_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; + optional> elided_window_dims; + attrs["elided_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; + optional> gather_dims_to_operand_dims; + attrs["gather_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &gather_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + optional> window_bounds; + attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, + &window_bounds}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateGather( + shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], + dim_numbers, *window_bounds)); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index b8c6b59204f897c7dc07b846370b5b776a19a808..863081d654390440aa6506bab4576b3cc5c1cbd1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -716,6 +716,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] { ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) } +)" +}, +{ +"gather", +R"(HloModule StringifyGather + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); @@ -860,6 +872,18 @@ ENTRY dot { ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} } +)" +}, +{ +"gather", +R"(HloModule gather + +ENTRY Gather { + input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index eda5effbb92db92c9317a956497a00c0ec15c27c..62a353ad09af009e4abf47664a5c5f7bd70a049e 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -66,6 +67,7 @@ struct Options { bool use_fake_data = false; bool print_result = true; int num_runs = 1; + bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) @@ -122,16 +124,21 @@ StatusOr> ReplayComputation( std::unique_ptr result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { + execution_options.mutable_debug_options()->set_xla_hlo_profile(true); + } + if (opts.print_result) { - TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( - computation, execute_arguments, - /*execution_options=*/nullptr, &profile)); + TF_ASSIGN_OR_RETURN( + result, client->ExecuteAndTransfer(computation, execute_arguments, + &execution_options, &profile)); } else { // If we're not printing the result, execute the computation but don't // bother retrieving the result. This can be a significant speedup. TF_RETURN_IF_ERROR(client ->Execute(computation, execute_arguments, - /*execution_options=*/nullptr, &profile) + &execution_options, &profile) .status()); } LOG(INFO) << "Execution took " @@ -191,6 +198,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag( + "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, + "Pass --xla_hlo_profile the last time we run the computation."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1f0c626bbb2d64ef4e67c9ec51485ae96ae73d04..dc4f7a1cb436183f5acfa360fb092795258b6a75 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" -#include #include #include @@ -292,7 +291,8 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, } int64 Product(tensorflow::gtl::ArraySlice xs) { - return std::accumulate(xs.begin(), xs.end(), 1, std::multiplies()); + return std::accumulate(xs.begin(), xs.end(), static_cast(1), + std::multiplies()); } std::vector> CommonFactors( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index e14c8cefa1d16e0a749e7a2c022a24a1c5083b15..2da9f9ed6f40fcf5b2512f974519df0b355da10f 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/compiler/xla/status.h" @@ -427,30 +428,37 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name); template -bool c_all_of(Container container, Predicate&& predicate) { +bool c_all_of(const Container& container, Predicate&& predicate) { return std::all_of(std::begin(container), std::end(container), std::forward(predicate)); } +template +bool c_any_of(const Container& container, Predicate&& predicate) { + return std::any_of(std::begin(container), std::end(container), + std::forward(predicate)); +} + template -OutputIterator c_transform(InputContainer input_container, +OutputIterator c_transform(const InputContainer& input_container, OutputIterator output_iterator, - UnaryOperation unary_op) { + UnaryOperation&& unary_op) { return std::transform(std::begin(input_container), std::end(input_container), - output_iterator, unary_op); + output_iterator, + std::forward(unary_op)); } template -OutputIterator c_copy_if(InputContainer input_container, +OutputIterator c_copy_if(const InputContainer& input_container, OutputIterator output_iterator, - UnaryPredicate predicate) { + UnaryPredicate&& predicate) { return std::copy_if(std::begin(input_container), std::end(input_container), - output_iterator, predicate); + output_iterator, std::forward(predicate)); } template -OutputIterator c_copy(InputContainer input_container, +OutputIterator c_copy(const InputContainer& input_container, OutputIterator output_iterator) { return std::copy(std::begin(input_container), std::end(input_container), output_iterator); @@ -468,7 +476,7 @@ void c_sort(InputContainer& input_container, Comparator&& comparator) { } template -bool c_binary_search(Sequence& sequence, T&& value) { +bool c_binary_search(const Sequence& sequence, T&& value) { return std::binary_search(std::begin(sequence), std::end(sequence), std::forward(value)); } @@ -487,6 +495,39 @@ template auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { return std::find_if(std::begin(c), std::end(c), std::forward(pred)); } + +template +auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) { + return std::find(std::begin(c), std::end(c), std::forward(value)); +} + +template +void c_reverse(Sequence& sequence) { + std::reverse(std::begin(sequence), std::end(sequence)); +} + +template +typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, + BinaryOp&& binary_op) { + return std::accumulate(std::begin(sequence), std::end(sequence), + std::forward(init), + std::forward(binary_op)); +} + +template +int64 FindIndex(const C& c, Value&& value) { + auto it = c_find(c, std::forward(value)); + return std::distance(c.begin(), it); +} + +// Returns true if `x` fits in 32-bits. +template +bool IsInt32(T x) { + // Following conversion rules: "the value is unchanged if it can be + // represented in the destination type (and bit-field width); otherwise, the + // value is implementation-defined." + return static_cast(x) == x; +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 56162ab44e2e0e3e4478fe631888f243332dc1d8..edf1b07af82b5d43fe67c6efdabdb0a9b4b1edea 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -342,6 +343,14 @@ message ExecuteRequest { ExecutionOptions execution_options = 5; } +message ExecuteGraphRequest { + HloModuleProto computation = 1; + repeated GlobalDataHandle arguments = 2; + + // Options that affect how XLA compiles and runs code to service this request. + ExecutionOptions execution_options = 3; +} + message ExecuteParallelRequest { repeated ExecuteRequest requests = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 28620c3b86349281573eaf57d2838bee1488d838..1f16e6d25178fd9c10a30b0c500e090ee2e08117 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -418,6 +418,10 @@ message GatherDimensionNumbers { // transforms the gather index looked up from the gather_indices tensor into // the starting index in the input space. repeated int64 gather_dims_to_operand_dims = 3; + + // The dimension in the gather_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; } // Operation requests that are all collected as a tagged union with a oneof diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index bab37e8906e5c648acdc1556da7e5f4601776ff5..fb81b50fe8e29a2e4cb7d127fd4b2b6778da763c 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -51,7 +51,6 @@ py_library( "//tensorflow/contrib/image:single_image_random_dot_stereograms_py", "//tensorflow/contrib/input_pipeline:input_pipeline_py", "//tensorflow/contrib/integrate:integrate_py", - "//tensorflow/contrib/kafka", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", "//tensorflow/contrib/kfac", @@ -79,7 +78,7 @@ py_library( "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/py2tf", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", @@ -110,7 +109,13 @@ py_library( "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", - ]), + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka", + ], + "//conditions:default": [], + }), ) cc_library( @@ -119,7 +124,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", @@ -133,7 +137,13 @@ cc_library( "//tensorflow/contrib/text:all_kernels", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", - ]), + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_kernels", + ], + "//conditions:default": [], + }), ) cc_library( @@ -142,12 +152,10 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", - "//tensorflow/contrib/kafka:kafka_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", @@ -158,7 +166,13 @@ cc_library( "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", - ], + ] + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_ops_op_lib", + ], + "//conditions:default": [], + }), ) filegroup( diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index a115d1610e2334a6626f29674f3dd195e3a3c648..ecf1a103d2981f409a4598d762fb26100217f779 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -75,7 +75,6 @@ target_link_libraries(tensorflow_inference include_directories( ${PREBUILT_DIR}/proto ${PREBUILT_DIR}/protobuf/include - ${PREBUILT_DIR}/nsync/public ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen ${TENSORFLOW_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/autograph/BUILD similarity index 75% rename from tensorflow/contrib/py2tf/BUILD rename to tensorflow/contrib/autograph/BUILD index d91220f6ddb859ff52d4e5853948cb667981009b..30dd846893c30b9205972bd5216cc1871ab03d76 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -15,16 +15,16 @@ filegroup( ) py_library( - name = "py2tf", + name = "autograph", srcs = [ "__init__.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/impl", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], diff --git a/tensorflow/contrib/py2tf/README.md b/tensorflow/contrib/autograph/README.md similarity index 87% rename from tensorflow/contrib/py2tf/README.md rename to tensorflow/contrib/autograph/README.md index cd50675ad57316b9c749c137e6acd30b91c10073..7e84f237dc9a83098f142a54c48cf5b6ba35aaaa 100644 --- a/tensorflow/contrib/py2tf/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,4 +1,4 @@ -# Py2TF +# Autograph A compiler for generating TensorFlow numeric and control flow ops from Python code. diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/autograph/__init__.py similarity index 59% rename from tensorflow/contrib/py2tf/__init__.py rename to tensorflow/contrib/autograph/__init__.py index 379fa7fd5c2a22b5b16a21cca8c2ea8afdcaeefa..a39f44b21aa0ddf683b30c18bbe15a43262f7db2 100644 --- a/tensorflow/contrib/py2tf/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Py2TF compiles Python code into equivalent TensorFlow code. +"""Autograph compiles Python code into equivalent TensorFlow code. Equivalent here means that they have the same effect when executed. """ @@ -21,16 +21,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.impl.api import convert -from tensorflow.contrib.py2tf.impl.api import graph_ready -from tensorflow.contrib.py2tf.impl.api import to_code -from tensorflow.contrib.py2tf.impl.api import to_graph -from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl.api import convert +from tensorflow.contrib.autograph.impl.api import converted_call +from tensorflow.contrib.autograph.impl.api import do_not_convert +from tensorflow.contrib.autograph.impl.api import RunMode +from tensorflow.contrib.autograph.impl.api import to_code +from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError' + 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', + 'to_code', 'to_graph', 'AutographParseError' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD similarity index 78% rename from tensorflow/contrib/py2tf/converters/BUILD rename to tensorflow/contrib/autograph/converters/BUILD index 42baaaaba72c6c3dbc896e87b6e0f3c62b7f06fc..608bd82722fa45a7009bd597cfd74060b1239a3b 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -25,10 +25,13 @@ py_library( "control_flow.py", "decorators.py", "for_loops.py", + "ifexp.py", "list_comprehension.py", + "lists.py", "logical_expressions.py", "name_scopes.py", "side_effect_guards.py", + "single_return.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -46,8 +49,9 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ ":converters", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -59,7 +63,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -70,7 +73,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -81,18 +83,18 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( name = "call_trees_test", + size = "large", srcs = ["call_trees_test.py"], srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], ) @@ -103,7 +105,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -114,7 +115,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -125,7 +125,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -136,7 +135,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -146,7 +144,7 @@ py_test( srcs = ["name_scopes_test.py"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -157,7 +155,16 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "lists_test", + srcs = ["lists_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", "//tensorflow/python:client_testlib", ], ) @@ -168,7 +175,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -184,7 +190,28 @@ py_test( ], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "single_return_test", + srcs = ["single_return_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "ifexp_test", + srcs = ["ifexp_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/__init__.py rename to tensorflow/contrib/autograph/converters/__init__.py index ca10896ee5c6c23d9b20ff23add9945de68e5bf9..e4e8eda42f655e204310eaa9defdd5c90bf06e15 100644 --- a/tensorflow/contrib/py2tf/converters/__init__.py +++ b/tensorflow/contrib/autograph/converters/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Code converters used by Py2TF.""" +"""Code converters used by Autograph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/py2tf/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/asserts.py rename to tensorflow/contrib/autograph/converters/asserts.py index 5b9b8e772bed82df2429fd6cb94dbf7b565e22b3..f011a97ade94f2979486ef6329673a0160dd9bac 100644 --- a/tensorflow/contrib/py2tf/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class AssertsTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py similarity index 90% rename from tensorflow/contrib/py2tf/converters/asserts_test.py rename to tensorflow/contrib/autograph/converters/asserts_test.py index 6611f2777a93a7e819c8becfa06a09b27f4e6aaf..cc913febe8d0f411588af69b87ec52ce58f4469c 100644 --- a/tensorflow/contrib/py2tf/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/break_statements.py rename to tensorflow/contrib/autograph/converters/break_statements.py index bfb709c5e32c6f19dc0fd109df61ece925d701a3..721bc0ccd0a00d09d7b308df867ef3839bb08d43 100644 --- a/tensorflow/contrib/py2tf/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -20,10 +20,10 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class BreakCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/break_statements_test.py rename to tensorflow/contrib/autograph/converters/break_statements_test.py index 095fcdff07d44ecc6b9bb7f8d3e2c7c43df72a02..dd4914a022f57b3bb4a19ec132f311f12269fa9e 100644 --- a/tensorflow/contrib/py2tf/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import break_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py similarity index 80% rename from tensorflow/contrib/py2tf/converters/builtin_functions.py rename to tensorflow/contrib/autograph/converters/builtin_functions.py index e69038acedd6b7a251c3328a14f36ed107bde746..0349ce29ceb097fbebc36a0378b9072750772416 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class BuiltinFunctionTransformer(transformer.Base): @@ -36,23 +36,24 @@ class BuiltinFunctionTransformer(transformer.Base): # pylint:disable=invalid-name - def _convert_len(self, node): + def _convert_builtin(self, node): template = """ - py2tf_utils.dynamic_len(args) + autograph_utils.dynamic_builtin(func, args) """ - return templates.replace(template, args=node.args)[0].value + return templates.replace(template, func=node.func, args=node.args)[0].value def _convert_print(self, node): template = """ - py2tf_utils.call_print(args) + autograph_utils.dynamic_print(args) """ return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. - if isinstance(node.func, gast.Name) and node.func.id == 'len': - return self._convert_len(node) + if isinstance(node.func, gast.Name) and node.func.id in ('len', 'range'): + return self._convert_builtin(node) + # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': return self._convert_print(node) return node diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py similarity index 96% rename from tensorflow/contrib/py2tf/converters/builtin_functions_test.py rename to tensorflow/contrib/autograph/converters/builtin_functions_test.py index eb60a1d8ae2b56907df8f3ffafe7604883cfc2a9..ac7e756c47c31816ad34a7ea6926917712afa6c3 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -22,8 +22,8 @@ import sys import six -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py similarity index 70% rename from tensorflow/contrib/py2tf/converters/call_trees.py rename to tensorflow/contrib/autograph/converters/call_trees.py index 1050ba654c63bb52c1c5e71c981a6a0baa3fc987..61f6bfd7e733fc3e2e0bea35a955509c39d57bc9 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -22,17 +22,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple import types import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect +class FunctionInfo(namedtuple('FunctionInfo', ('dtype',))): + pass + + +# TODO(mdan): Move this to config.py. +KNOWN_NUMPY_FUNCTIONS = { + ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'), +} + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -72,9 +85,8 @@ class CallTreeTransformer(transformer.Base): self.uncompiled_modules = uncompiled_modules self.nocompile_decorators = nocompile_decorators - # pylint:disable=invalid-name - def _resolve_name(self, node): + """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): @@ -99,7 +111,19 @@ class CallTreeTransformer(transformer.Base): (owner_type, node.attr)) return None + def _function_is_compilable(self, target_entity): + """Determines whether an entity can be compiled at all.""" + # TODO(mdan): This is just a placeholder. Implement. + return not isinstance(target_entity, types.BuiltinFunctionType) + def _should_compile(self, node, fqn): + """Determines whether an entity should be compiled in the context.""" + # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. + module_name = fqn[0] + for mod in self.uncompiled_modules: + if module_name.startswith(mod[0] + '.'): + return False + for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False @@ -141,33 +165,6 @@ class CallTreeTransformer(transformer.Base): return True - def _determine_function_owner(self, m): - # TODO(mdan): The parent type should be known at analysis. Use that instead. - if hasattr(m, 'im_class'): # Python 2 - return m.im_class - if hasattr(m, '__qualname__'): # Python 3 - # Object attributes: should be bound to "self". - if hasattr(m, '__self__'): - return type(m.__self__) - - # Class attributes: should have the owner name in their namespace. - qn = m.__qualname__.split('.') - if len(qn) < 2: - return None - owner_name, func_name = qn[-2:] - if func_name != m.__name__: - raise ValueError('Inconsistent names detected ' - '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % - (func_name, m.__name__, m)) - if owner_name == '': - return None - if owner_name not in self.context.namespace: - raise ValueError( - 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % - (owner_name, m, self.context.namespace)) - return self.context.namespace[owner_name] - return None - def _rename_compilable_function(self, node): assert anno.hasanno(node.func, 'live_val') assert anno.hasanno(node.func, 'fqn') @@ -182,7 +179,11 @@ class CallTreeTransformer(transformer.Base): target_fqn, live_entity=target_entity) do_rename = True else: - owner_type = self._determine_function_owner(target_entity) + if anno.hasanno(node.func, 'parent_type'): + owner_type = anno.getanno(node.func, 'parent_type') + else: + # Fallback - not reliable. + owner_type = inspect_utils.getmethodclass(target_entity) new_name, do_rename = self.context.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) @@ -196,15 +197,56 @@ class CallTreeTransformer(transformer.Base): return node def _wrap_to_py_func_no_return(self, node): - # TODO(mdan): Properly handle varargs, kwargs, etc. + # TODO(mdan): Properly handle varargs, etc. + template = """ + autograph_utils.wrap_py_func(func, None, (args,), kwargs, True) + """ + return templates.replace( + template, + func=node.func, + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _wrap_to_py_func_single_return(self, node, dtype): + # TODO(mdan): Properly handle varargs, etc. + template = """ + autograph_utils.wrap_py_func(func, dtype, (args,), kwargs, False) + """ + return templates.replace_as_expression( + template, + func=node.func, + dtype=parser.parse_expression(dtype), + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _insert_dynamic_conversion(self, node): + """Inlines a dynamic conversion for a dynamic function.""" + # TODO(mdan): Pass information on the statically compiled functions. + # Having access to the statically compiled functions can help avoid + # unnecessary compilation. + # For example, this would lead to function `a` being compiled twice: + # + # def a(): + # v = b + # b() + # def b(): + # a() + # + # This is really a problem with recursive calls, which currently can + # only be gated by a static condition, and should be rare. + # TODO(mdan): It probably makes sense to use dynamic conversion every time. + # Before we could convert all the time though, we'd need a reasonable + # caching mechanism. template = """ - py2tf_utils.wrap_py_func(func, None, (original_args,), True) + autograph_api.converted_call(func, True, False, {}, args) """ - return templates.replace(template, func=node.func, original_args=node.args) + call_expr = templates.replace(template, func=node.func, args=node.args) + new_call = call_expr[0].value + # TODO(mdan): Improve the template mechanism to better support this. + new_call.keywords = node.keywords + return new_call - def _function_is_compilable(self, target_entity): - # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_entity, types.BuiltinFunctionType) + # pylint:disable=invalid-name def visit_Expr(self, node): if isinstance(node.value, gast.Call): @@ -239,15 +281,24 @@ class CallTreeTransformer(transformer.Base): self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') + if anno.hasanno(node.func, 'fqn'): + target_fqn = anno.getanno(node.func, 'fqn') + else: + target_fqn = None if self._function_is_compilable(target_entity): node = self._rename_compilable_function(node) + elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS: + # TODO(mdan): Should we replace these with equivalent TF ops instead? + node = self._wrap_to_py_func_single_return( + node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype) else: - raise NotImplementedError('py_func with return values') + raise NotImplementedError( + 'py_func with return values (unknown function)') else: if self.context.recursive: - raise NotImplementedError('Could not resolve target function.') + node = self._insert_dynamic_conversion(node) else: - # TODO(mdan): Double check. Is this reachable code? + # Unresolved functions are allowed in non-recursive mode. pass return node diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py similarity index 75% rename from tensorflow/contrib/py2tf/converters/call_trees_test.py rename to tensorflow/contrib/autograph/converters/call_trees_test.py index 777648dc0b31863227262fbf931aba680bb4ed98..c666dcb73b232ce443898cfe3359f74605af98f2 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import converter_test_base +import numpy as np + +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -47,6 +51,21 @@ class CallTreesTest(converter_test_base.TestCase): result.renamed_test_fn_1 = renamed_test_fn_1 self.assertEquals(3, result.test_fn_2(1)) + def test_dynamic_function(self): + + def test_fn_1(): + raise ValueError('This should be masked by the mock.') + + def test_fn_2(f): + return f() + 3 + + node = self.parse_and_analyze(test_fn_2, {}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + # 10 = 7 (from the mock) + 3 (from test_fn_2) + self.assertEquals(10, result.test_fn_2(test_fn_1)) + def test_simple_methods(self): class TestClass(object): @@ -59,6 +78,7 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, + namer=converter_test_base.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) node = call_trees.transform(node, self.ctx, (), ()) @@ -89,6 +109,20 @@ class CallTreesTest(converter_test_base.TestCase): sess.run(sess.graph.get_operations()[0]) self.assertEquals('bar', a.foo) + def test_py_func_wrap_known_function(self): + + def test_fn(): + return np.random.binomial(2, 0.5) + + node = self.parse_and_analyze(test_fn, {'np': np}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node, dtypes.int64) as result: + result.np = np + with self.test_session() as sess: + self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) + self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) + def test_uncompiled_modules(self): def test_fn(a): diff --git a/tensorflow/contrib/py2tf/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py similarity index 94% rename from tensorflow/contrib/py2tf/converters/continue_statements.py rename to tensorflow/contrib/autograph/converters/continue_statements.py index 4069a678b118b56b59d2e5491bb80cf52efd8143..4299a8a9d59715d032222c47794bbb4393f34ce6 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class ContinueCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/continue_statements_test.py rename to tensorflow/contrib/autograph/converters/continue_statements_test.py index a598dcd1aed29478b7e3fe27e3c1b20010247dd9..bcbb316d7459aa5a25bb0bd128cd6e359a393288 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py similarity index 88% rename from tensorflow/contrib/py2tf/converters/control_flow.py rename to tensorflow/contrib/autograph/converters/control_flow.py index d53e3e4fd6d87004cbe55bd430346ad263e898ea..49d932026ffa9e79e7ddc640f7d3deaec0f4b8a6 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -82,7 +82,7 @@ class ControlFlowTransformer(transformer.Base): def _create_cond_expr(self, results, test, body_name, orelse_name): if results is not None: template = """ - results = py2tf_utils.run_cond(test, body_name, orelse_name) + results = autograph_utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, @@ -92,7 +92,7 @@ class ControlFlowTransformer(transformer.Base): orelse_name=orelse_name) else: template = """ - py2tf_utils.run_cond(test, body_name, orelse_name) + autograph_utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) @@ -171,6 +171,14 @@ class ControlFlowTransformer(transformer.Base): all_referenced = body_scope.referenced state = list(body_closure) + if not state: + # TODO(mdan): Implement this properly. + # To complete this statement, we need to check whether any variable + # created inside the body scope is used before being modified outside the + # scope. This should be done during activity analysis, and in general + # should cover the case where variables may not be initialized. + raise ValueError('cannot convert while loop: no outputs') + state_ssf = [ self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state ] @@ -196,7 +204,7 @@ class ControlFlowTransformer(transformer.Base): def body_name(state_ssf): body return state_ssf, - state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state]) + state_ast_tuple = autograph_utils.run_while(test_name, body_name, [state]) """ node = templates.replace( template, diff --git a/tensorflow/contrib/py2tf/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/control_flow_test.py rename to tensorflow/contrib/autograph/converters/control_flow_test.py index b785b284a7fb7a0257551326c88b44a341b295ba..86fed51f27bee07f772633f3928ac5263bf57652 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import control_flow -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py similarity index 62% rename from tensorflow/contrib/py2tf/converters/converter_test_base.py rename to tensorflow/contrib/autograph/converters/converter_test_base.py index afa5c2f96fb55302e67e5ecac3532cb87827871a..3ea2cfd668270a69427c24cdf1bbf11d32d66ebe 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/converters/converter_test_base.py @@ -21,14 +21,15 @@ from __future__ import print_function import contextlib import imp -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test @@ -52,26 +53,49 @@ class FakeNamer(object): return ('renamed_%s' % '_'.join(original_fqn)), True +class FakeNoRenameNamer(FakeNamer): + + def compiled_function_name(self, original_fqn, **_): + return str(original_fqn), False + + class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager def compiled(self, node, *symbols): - source = '' + source = None + + self.dynamic_calls = [] + def converted_call(*args): + """Mock version of api.converted_call.""" + self.dynamic_calls.append(args) + return 7 + try: result, source = compiler.ast_to_object(node) - result.tf = self.make_fake_tf(*symbols) - result.py2tf_utils = utils + result.tf = self.make_fake_mod('fake_tf', *symbols) + result.autograph_utils = utils + result.autograph_api = self.make_fake_mod('fake_api', converted_call) yield result except Exception: # pylint:disable=broad-except - print('Offending compiled code:\n%s' % source) + if source is None: + print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False)) + else: + print('Offending compiled code:\n%s' % source) raise - def make_fake_tf(self, *symbols): - fake_tf = imp.new_module('fake_tf') + def make_fake_mod(self, name, *symbols): + fake_mod = imp.new_module(name) for s in symbols: - setattr(fake_tf, s.__name__, s) - return fake_tf + if hasattr(s, '__name__'): + setattr(fake_mod, s.__name__, s) + elif hasattr(s, 'name'): + # This is a bit of a hack, but works for things like tf.int32 + setattr(fake_mod, s.name, s) + else: + raise ValueError('can not attach %s - what should be its name?' % s) + return fake_mod def attach_namespace(self, module, **ns): for k, v in ns.items(): @@ -94,7 +118,8 @@ class TestCase(test.TestCase): arg_values=None, arg_types=arg_types, owner_type=owner_type, - recursive=recursive) + recursive=recursive, + type_annotation_func=utils.set_element_type) node = qual_names.resolve(node) node = activity.resolve(node, ctx) node = live_values.resolve(node, ctx, {}) diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py similarity index 96% rename from tensorflow/contrib/py2tf/converters/decorators.py rename to tensorflow/contrib/autograph/converters/decorators.py index 68bf241ef33292f0581ccb3c44f313f853c92ba7..92445f31746cf94856ea43893f99a2ba60355fb5 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,8 +24,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import pretty_printer class DecoratorsTransformer(gast.NodeTransformer): diff --git a/tensorflow/contrib/py2tf/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/decorators_test.py rename to tensorflow/contrib/autograph/converters/decorators_test.py index c75e5461746f27d14a54b7ac06e7f77d868372c8..e67ab1cd6a15ceb66fe75140419c7abca9653ae4 100644 --- a/tensorflow/contrib/py2tf/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -20,9 +20,9 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.pyct import compiler from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/for_loops.py b/tensorflow/contrib/autograph/converters/for_loops.py similarity index 62% rename from tensorflow/contrib/py2tf/converters/for_loops.py rename to tensorflow/contrib/autograph/converters/for_loops.py index 935dade0ed30975dd29c8ffe5be875993936d241..4999c47bdc79ec0ea352472cfd3e97b94ebc7cce 100644 --- a/tensorflow/contrib/py2tf/converters/for_loops.py +++ b/tensorflow/contrib/autograph/converters/for_loops.py @@ -22,10 +22,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class ForLoopCanonicalizationTransformer(transformer.Base): @@ -37,42 +37,48 @@ class ForLoopCanonicalizationTransformer(transformer.Base): def visit_For(self, node): self.generic_visit(node) body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - + i_var = self.context.namer.new_symbol('i', body_scope.referenced) + smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter', + body_scope.referenced) + cont_var = self.context.namer.new_symbol('cont', body_scope.referenced) + # TODO(mdan): Use TensorListFromTensor(loop_iter) here. if anno.hasanno(node, 'extra_cond'): template = """ i = 0 - n = len(loop_iter) - while i < n and extra_cond: - # TODO(mdan): Use TensorListFromTensor(loop_iter) here. - target = loop_iter[i] + smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter) + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) + while cont and extra_cond: body i += 1 + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) """ return templates.replace( template, loop_iter=node.iter, target=node.target, body=node.body, - i=self.context.namer.new_symbol('i', body_scope.referenced), - n=self.context.namer.new_symbol('n', body_scope.referenced), + i=i_var, + smart_loop_iter=smart_loop_iter_var, + cont=cont_var, extra_cond=anno.getanno(node, 'extra_cond')) else: template = """ i = 0 - n = len(loop_iter) - while i < n: - # TODO(mdan): Use TensorListFromTensor(loop_iter) here. - target = loop_iter[i] - body # pylint:disable=pointless-statement + smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter) + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) + while cont: + body i += 1 + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) """ repl = templates.replace( template, loop_iter=node.iter, target=node.target, body=node.body, - i=self.context.namer.new_symbol('i', body_scope.referenced), - n=self.context.namer.new_symbol('n', body_scope.referenced)) + i=i_var, + smart_loop_iter=smart_loop_iter_var, + cont=cont_var) return repl def visit_Continue(self, node): diff --git a/tensorflow/contrib/py2tf/converters/for_loops_test.py b/tensorflow/contrib/autograph/converters/for_loops_test.py similarity index 64% rename from tensorflow/contrib/py2tf/converters/for_loops_test.py rename to tensorflow/contrib/autograph/converters/for_loops_test.py index 70a367d3b517e528b67f260d607431d324d2ab7d..943f52de55a3629fdb18e6188e42269a4cb06275 100644 --- a/tensorflow/contrib/py2tf/converters/for_loops_test.py +++ b/tensorflow/contrib/autograph/converters/for_loops_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import for_loops +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import for_loops from tensorflow.python.platform import test @@ -42,6 +42,29 @@ class ControlFlowTest(converter_test_base.TestCase): l = [] self.assertEqual(test_fn(l), result.test_fn(l)) + def test_for_with_iterated_expression(self): + + eval_count = [0] + + def count_evals(x): + eval_count[0] += 1 + return x + + def test_fn(n): + s = 0 + for e in count_evals(range(n)): + s += e + return s + + node = self.parse_and_analyze(test_fn, {'count_evals': count_evals}) + node = for_loops.transform(node, self.ctx) + + with self.compiled(node) as result: + result.count_evals = count_evals + self.assertEqual(test_fn(5), result.test_fn(5)) + # count_evals ran twice, once for test_fn and another for result.test_fn + self.assertEqual(eval_count[0], 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/autograph/converters/ifexp.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/printing.py rename to tensorflow/contrib/autograph/converters/ifexp.py index 95a62bd80b5f4854e6a062df18d882f7bd495555..bb0c0a36a7827e5c73e0fa67f09aa4f54d497a2c 100644 --- a/tensorflow/contrib/py2tf/utils/printing.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -12,36 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow printing support utilities.""" +"""Canonicalizes the ternary conditional operator.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import py_func -from tensorflow.python.ops import logging_ops +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. - del value - return False +class IfExp(transformer.Base): + """Canonicalizes all IfExp nodes into plain conditionals.""" + def visit_IfExp(self, node): + template = """ + autograph_utils.run_cond(test, lambda: (body,), lambda: (orelse,)) + """ + desugared_ifexp = templates.replace_as_expression( + template, test=node.test, body=node.body, orelse=node.orelse) + return desugared_ifexp -def call_print(*values): - """Compiled counterpart of the print builtin. - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. +def transform(node, context): + """Desugar IfExp nodes into plain conditionals. Args: - *values: values to print + node: an AST node to transform + context: a context object + Returns: - A dummy value indicating the print completed. If tf. + new_node: an AST with no IfExp nodes, only conditionals. """ - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - return py_func.wrap_py_func(print, None, values, use_dummy_return=True) + node = IfExp(context).visit(node) + return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e --- /dev/null +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -0,0 +1,106 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ifexp module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.python.platform import test + + +class IfExpTest(converter_test_base.TestCase): + + def compiled_fn(self, test_fn, *args): + node = self.parse_and_analyze(test_fn, {}) + node = ifexp.transform(node, self.ctx) + module = self.compiled(node, *args) + return module + + def test_simple(self): + + def test_fn(x): + return 1 if x else 0 + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [0, 1]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_fn(self): + + def f(x): + return 3 * x + + def test_fn(x): + y = f(x * x if x > 0 else x) + return y + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + result.f = f + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_exp(self): + + def test_fn(x): + return x * x if x > 0 else x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_nested(self): + + def test_fn(x): + return x * x if x > 0 else x if x else 1 + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 0, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_in_cond(self): + + def test_fn(x): + if x > 0: + return x * x if x < 5 else x * x * x + return -x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_assign_in_cond(self): + + def test_fn(x): + if x > 0: + x = -x if x < 5 else x + return x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension.py rename to tensorflow/contrib/autograph/converters/list_comprehension.py index e8744831100e4852919b5cd1253b74acea4d790d..d7f292015164e047d054c5d1fb0b391e960bb73d 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,9 +31,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class ListCompCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension_test.py rename to tensorflow/contrib/autograph/converters/list_comprehension_test.py index 025fac11e41e6771fbb9b80ff3da70dc3ceec73e..4758671f5ec83c26cfa54be0ef68f5f564094f6c 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import list_comprehension +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import list_comprehension from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..234a0a7487d5fc9e068acf4a19af3bac84f4737e --- /dev/null +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -0,0 +1,106 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for list operations. + +This includes converting Python lists to TensorArray/TensorList. +""" + +# TODO(mdan): Elaborate the logic here. +# TODO(mdan): Does it even make sense to attempt to try to use TAs? +# The current rule (always convert to TensorArray) is naive and insufficient. +# In general, a better mechanism could look like: +# * convert to TensorList by default +# * leave as Python list if the user explicitly forbids it +# * convert to TensorArray only when complete write once behavior can be +# guaranteed (e.g. list comprehensions) + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.python.framework import dtypes + + +class ListTransformer(transformer.Base): + """Converts lists and related operations to their TF counterpart.""" + + def _empty_list(self, node): + if not anno.hasanno(node, 'element_type'): + raise NotImplementedError( + 'type inference for empty lists is not yet supported; ' + 'use utils.set_element_type(, ) to continue') + dtype = anno.getanno(node, 'element_type') + if not isinstance(dtype, dtypes.DType): + # TODO(mdan): Allow non-TF dtypes? + # That would be consistent with the dynamic dispatch pattern, but + # we must make sure that doesn't become confusing. + raise NotImplementedError('element type "%s" not yet supported' % dtype) + + dtype_name = dtype.name + # TODO(mdan): Does it ever make sense not to use tensor lists? + template = """ + tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + """ + return templates.replace_as_expression(template, dtype_name=dtype_name) + + def _pre_populated_list(self, node): + raise NotImplementedError('pre-populated lists') + + def visit_Expr(self, node): + node = self.generic_visit(node) + if isinstance(node.value, gast.Call): + call_node = node.value + + if not anno.hasanno(call_node.func, anno.Basic.QN): + return node + qn = anno.getanno(call_node.func, anno.Basic.QN) + + if qn.qn[-1] == 'append' and (len(call_node.args) == 1): + template = """ + target = autograph_utils.dynamic_list_append(target, element) + """ + node = templates.replace( + template, + target=qn.parent.ast(), + element=call_node.args[0]) + return node + + def visit_Assign(self, node): + node = self.generic_visit(node) + + # Only convert lists when they are assigned to a variable, e.g.: + # l = [] + # TODO(mdan): This rule should be improved. + if len(node.targets) != 1: + return node + if not isinstance(node.value, gast.List): + return node + if not isinstance(node.value.ctx, gast.Load): + return node + + if node.value.elts: + node.value = self._pre_populated_list(node.value) + else: + node.value = self._empty_list(node.value) + return node + + +def transform(node, context): + return ListTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/utils/printing_test.py b/tensorflow/contrib/autograph/converters/lists_test.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/printing_test.py rename to tensorflow/contrib/autograph/converters/lists_test.py index 2070deb304d8df2433fb9a95ae36d48415578482..749ba14347314f975c5a6e1111133336e2f5c5e6 100644 --- a/tensorflow/contrib/py2tf/utils/printing_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -12,41 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for printing module.""" +"""Tests for lists module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import lists +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test -import six -from tensorflow.contrib.py2tf.utils import printing -from tensorflow.python.platform import test +class ListTest(converter_test_base.TestCase): + def test_empty_annotated_list(self): -class ContextManagersTest(test.TestCase): + def test_fn(): + l = [] + utils.set_element_type(l, dtypes.int32) + l.append(1) + return l - def test_call_print_tf(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(printing.call_print('test message', 1)) - self.assertEqual(out_capturer.getvalue(), 'test message 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_call_print_py_func(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer + node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = lists.transform(node, self.ctx) + + with self.compiled(node, tensor_array_ops.TensorArray, + dtypes.int32) as result: + # TODO(mdan): Attach these additional modules automatically. + result.utils = utils + result.dtypes = dtypes with self.test_session() as sess: - sess.run(printing.call_print('test message', [1, 2])) - self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') - finally: - sys.stdout = sys.__stdout__ + self.assertEqual(test_fn(), sess.run(result.test_fn().stack())) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..3a795a315a3c2aa08ac1577a204102755b6e849c --- /dev/null +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for logical expressions. + +e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +# TODO(mdan): Properly extrack boolean ops according to lazy eval rules. +# Note that this isn't completely safe either, because tensors may have control +# dependencies. +# Note that for loops that should be done after the loop was converted to +# tf.while_loop so that the expanded conditionals are properly scoped. + +# Used to signal that an operand is safe for non-lazy evaluation. +SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' + + +class LogicalExpressionTransformer(transformer.Base): + """Converts logical expressions to corresponding TF calls.""" + + def __init__(self, context): + super(LogicalExpressionTransformer, self).__init__(context) + # TODO(mdan): Look into replacing with bitwise operators instead. + # TODO(mdan): Skip replacing if the function is trivial. + self.op_mapping = { + gast.And: 'tf.logical_and', + gast.Eq: 'tf.equal', + gast.Gt: 'tf.greater', + gast.GtE: 'tf.greater_equal', + gast.Lt: 'tf.less', + gast.LtE: 'tf.less_equal', + gast.Not: 'tf.logical_not', + gast.NotEq: 'tf.not_equal', + gast.Or: 'tf.logical_or', + gast.USub: 'tf.negative', + gast.Is: 'autograph_utils.dynamic_is', + gast.IsNot: 'autograph_utils.dynamic_is_not' + } + + def _expect_simple_symbol(self, operand): + if isinstance(operand, gast.Name): + return + if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND): + return + raise NotImplementedError( + 'only simple local variables are supported in logical and compound ' + 'comparison expressions; for example, we support "a or b" but not ' + '"a.x or b"; for a workaround, assign the expression to a local ' + 'variable and use that instead, for example "tmp = a.x", "tmp or b"') + + def _matching_func(self, operator): + op_type = type(operator) + mapped_op = self.op_mapping.get(op_type) + if not mapped_op: + raise NotImplementedError('operator %s is not yet supported' % op_type) + return mapped_op + + def _as_function(self, func_name, args): + template = """ + func_name(args) + """ + replacement = templates.replace_as_expression( + template, func_name=parser.parse_expression(func_name), args=args) + anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True) + return replacement + + def visit_Compare(self, node): + node = self.generic_visit(node) + ops_and_comps = list(zip(node.ops, node.comparators)) + left = node.left + op_tree = None + + # Repeated comparisons are converted to conjunctions: + # a < b < c -> a < b and b < c + while ops_and_comps: + op, right = ops_and_comps.pop(0) + binary_comparison = self._as_function( + self._matching_func(op), (left, right)) + if isinstance(left, gast.Name) and isinstance(right, gast.Name): + anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True) + if op_tree: + self._expect_simple_symbol(right) + op_tree = self._as_function('tf.logical_and', + (binary_comparison, op_tree)) + else: + op_tree = binary_comparison + left = right + assert op_tree is not None + return op_tree + + def visit_UnaryOp(self, node): + node = self.generic_visit(node) + return self._as_function(self._matching_func(node.op), node.operand) + + def visit_BoolOp(self, node): + node = self.generic_visit(node) + node_values = node.values + right = node.values.pop() + self._expect_simple_symbol(right) + while node_values: + left = node_values.pop() + self._expect_simple_symbol(left) + right = self._as_function(self._matching_func(node.op), (left, right)) + return right + + +def transform(node, context): + return LogicalExpressionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py similarity index 86% rename from tensorflow/contrib/py2tf/converters/logical_expressions_test.py rename to tensorflow/contrib/autograph/converters/logical_expressions_test.py index a28326c517d468230f35e45f0fbfe5257d769895..2814060c4d831e4dddacb3dcbcbe1db42160db20 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import logical_expressions +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -32,7 +32,7 @@ class GradientsFunctionTest(converter_test_base.TestCase): return a == b node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, self.ctx) with self.compiled(node, math_ops.equal) as result: with self.test_session() as sess: @@ -45,7 +45,7 @@ class GradientsFunctionTest(converter_test_base.TestCase): return (a or b) and (a or b or c) node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, self.ctx) with self.compiled(node, math_ops.logical_or, math_ops.logical_and) as result: diff --git a/tensorflow/contrib/py2tf/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/name_scopes.py rename to tensorflow/contrib/autograph/converters/name_scopes.py index c702823fcf047fcad3254318bd323d2b8fddd700..2a3f474360e94635470bf9581222e4c79f46b7a1 100644 --- a/tensorflow/contrib/py2tf/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -21,8 +21,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class FunctionNameScopeTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/name_scopes_test.py rename to tensorflow/contrib/autograph/converters/name_scopes_test.py index a8ca341602ee5f06dbb812643a58794339d98afe..61e5db2af826d0c2238f1af0f3240411596f7429 100644 --- a/tensorflow/contrib/py2tf/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import name_scopes +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py similarity index 91% rename from tensorflow/contrib/py2tf/converters/side_effect_guards.py rename to tensorflow/contrib/autograph/converters/side_effect_guards.py index 30976b3ec6db5a6607023ac804d9d54cfb296190..1c1293d2c411b51b563ac3965284a48725ed3278 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,12 +36,12 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -160,8 +160,8 @@ class SideEffectGuardTransformer(transformer.Base): [alias_map.get(s, s).ast() for s in guarded_args], None) template = """ - with py2tf_utils.control_dependency_on_returns(call): - aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args) + with autograph_utils.control_dependency_on_returns(call): + aliased_guarded_args = autograph_utils.alias_tensors(guarded_args) """ control_deps_guard = templates.replace( template, @@ -172,7 +172,7 @@ class SideEffectGuardTransformer(transformer.Base): alias_map = {} template = """ - with py2tf_utils.control_dependency_on_returns(call): + with autograph_utils.control_dependency_on_returns(call): pass """ control_deps_guard = templates.replace(template, call=node.value)[-1] diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py similarity index 97% rename from tensorflow/contrib/py2tf/converters/side_effect_guards_test.py rename to tensorflow/contrib/autograph/converters/side_effect_guards_test.py index 463db2e770213ba9636d2537b095a77dece5d8f6..ce0ce33243a1352107eb8121050ee76474869809 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -0,0 +1,317 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes functions with multiple returns to use just one.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# TODO(mdan): Move this logic into transformer_base. +class BodyVisitor(transformer.Base): + """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" + + def __init__(self, context, depth_first=False): + self.depth_first = depth_first + self.changes_made = False + super(BodyVisitor, self).__init__(context) + + def visit_nodelist(self, nodelist): + for node in nodelist: + if isinstance(node, list): + node = self.visit_nodelist(node) + else: + node = self.generic_visit(node) + return nodelist + + def visit_If(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_For(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_While(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_Try(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + node.finalbody = self.visit_nodelist(node.finalbody) + for i in range(len(node.handlers)): + node.handlers[i].body = self.visit_nodelist(node.handlers[i].body) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_With(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_FunctionDef(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + self.generic_visit(node) + if not self.depth_first: + node = self.generic_visit(node) + return node + + +class FoldElse(BodyVisitor): + + def visit_nodelist(self, nodelist): + for i in range(len(nodelist)): + node = nodelist[i] + if isinstance(node, gast.If): + true_branch_returns = isinstance(node.body[-1], gast.Return) + false_branch_returns = len(node.orelse) and isinstance( + node.orelse[-1], gast.Return) + # If the last node in the if body is a return, + # then every line after this if statement effectively + # belongs in the else. + if true_branch_returns and not false_branch_returns: + for j in range(i + 1, len(nodelist)): + nodelist[i].orelse.append(ast_util.copy_clean(nodelist[j])) + if nodelist[i + 1:]: + self.changes_made = True + return nodelist[:i + 1] + elif not true_branch_returns and false_branch_returns: + for j in range(i + 1, len(nodelist)): + nodelist[i].body.append(ast_util.copy_clean(nodelist[j])) + if nodelist[i + 1:]: + self.changes_made = True + return nodelist[:i + 1] + elif true_branch_returns and false_branch_returns: + if nodelist[i + 1:]: + raise ValueError( + 'Unreachable code after conditional where both branches return.' + ) + return nodelist + elif isinstance(node, gast.Return) and nodelist[i + 1:]: + raise ValueError( + 'Cannot have statements after a return in the same basic block') + return nodelist + + +def contains_return(node): + for n in gast.walk(node): + if isinstance(n, gast.Return): + return True + return False + + +class LiftReturn(transformer.Base): + """Move return statements out of If and With blocks.""" + + def __init__(self, context): + self.changes_made = False + self.common_return_name = None + super(LiftReturn, self).__init__(context) + + def visit_If(self, node): + # Depth-first traversal of if statements + node = self.generic_visit(node) + + # We check if both branches return, and if so, lift the return out of the + # conditional. We don't enforce that the true and false branches either + # both return or both do not, because FoldElse might move a return + # into a branch after this transform completes. FoldElse and LiftReturn + # are alternately run until the code reaches a fixed point. + true_branch_returns = isinstance(node.body[-1], gast.Return) + false_branch_returns = len(node.orelse) and isinstance( + node.orelse[-1], gast.Return) + if true_branch_returns and false_branch_returns: + node.body[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.body[-1].value)[0] + node.orelse[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.orelse[-1].value)[0] + return_node = templates.replace('return a', a=self.common_return_name)[0] + self.changes_made = True + return [node, return_node] + else: + return node + + def visit_With(self, node): + # Depth-first traversal of syntax + node = self.generic_visit(node) + + # If the with statement returns, lift the return + if isinstance(node.body[-1], gast.Return): + node.body[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.body[-1].value)[0] + return_node = templates.replace('return a', a=self.common_return_name)[0] + node = self.generic_visit(node) + self.changes_made = True + return [node, return_node] + else: + return node + + def visit_FunctionDef(self, node): + # Ensure we're doing depth-first traversal + last_return_name = self.common_return_name + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + referenced_names = body_scope.referenced + self.common_return_name = self.context.namer.new_symbol( + 'return_', referenced_names) + node = self.generic_visit(node) + self.common_return_name = last_return_name + return node + + +class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): + """Throws an error if code returns inside loops or try/except.""" + + # First, throw an error if we detect a return statement in a loop. + # TODO(alexbw): we need to learn to handle returns inside a loop, + # but don't currently have the TF constructs to do so (need something + # that looks vaguely like a goto). + + def __init__(self): + self.cant_return = False + super(DetectReturnInUnsupportedControlFlow, self).__init__() + + def visit_While(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_For(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Try(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Return(self, node): + if self.cant_return: + raise ValueError( + '`return` statements are not supported in loops. ' + 'Try assigning to a variable in the while loop, and returning ' + 'outside of the loop') + + +class DetectReturnInConditional(gast.NodeVisitor): + """Assert that no return statements are present in conditionals.""" + + def __init__(self): + self.cant_return = False + super(DetectReturnInConditional, self).__init__() + + def visit_If(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Return(self, node): + if self.cant_return: + raise ValueError( + 'After transforms, a conditional contained a `return `statement, ' + 'which is not allowed. This is a bug, and should not happen.') + + +class DetectReturnInFunctionDef(gast.NodeVisitor): + + def visit_FunctionDef(self, node): + self.generic_visit(node) + if not contains_return(node): + raise ValueError( + 'Each function definition should contain at least one return.') + + +def transform(node, context): + """Ensure a function has only a single return. + + This transforms an AST node with multiple returns successively into containing + only a single return node. + There are a few restrictions on what we can handle: + - An AST being transformed must contain at least one return. + - No returns allowed in loops. We have to know the type of the return value, + and we currently don't have either a type inference system to discover it, + nor do we have a mechanism for late type binding in TensorFlow. + - After all transformations are finished, a Return node is not allowed inside + control flow. If we were unable to move a return outside of control flow, + this is an error. + + Args: + node: an AST node to transform + context: a context object + + Returns: + new_node: an AST with a single return value + + Raises: + ValueError: if the AST is structured so that we can't perform the + transform. + """ + # Make sure that the function has at least one return statement + # TODO(alexbw): turning off this assertion for now -- + # we need to not require this in e.g. class constructors. + # DetectReturnInFunctionDef().visit(node) + + # Make sure there's no returns in unsupported locations (loops, try/except) + DetectReturnInUnsupportedControlFlow().visit(node) + + while True: + + # Try to lift all returns out of if statements and with blocks + lr = LiftReturn(context) + node = lr.visit(node) + changes_made = lr.changes_made + fe = FoldElse(context) + node = fe.visit(node) + changes_made = changes_made or fe.changes_made + + if not changes_made: + break + + # Make sure we've scrubbed all returns from conditionals + DetectReturnInConditional().visit(node) + + return node diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d483005a09537ea8227814f65aa7e6402c853f60 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -0,0 +1,189 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for single_return module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.python.framework.ops import name_scope +from tensorflow.python.platform import test + + +class SingleReturnTest(converter_test_base.TestCase): + + def compiled_fn(self, test_fn, *args): + node = self.parse_and_analyze(test_fn, {}) + node = single_return.transform(node, self.ctx) + module = self.compiled(node, *args) + return module + + def test_noop(self): + # Noop + def test_fn(x): + return x + + with self.compiled_fn(test_fn) as result: + self.assertEqual(test_fn(2.0), result.test_fn(2.0)) + + def test_return_expression(self): + # ANF + def test_fn(x): + return x * x + + with self.compiled_fn(test_fn) as result: + x = 2 + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_merge(self): + # Simple merge + def test_fn(x): + if x > 0: + return x + else: + return x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_orphan_branch(self): + + def test_fn(x): + if x > 0: + return x + + with self.assertRaises(ValueError): + self.compiled_fn(test_fn) + + def test_lift_body_into_false_branch(self): + + def test_fn(x): + if x > 0: + return x + return x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_lift_body_into_true_branch(self): + + def test_fn(x): + if x < 0: + x *= x + else: + # TODO(alexbw): linter bug here that requires us suppress this warning. + return x # pylint: disable=undefined-loop-variable + return x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_nested_if(self): + + def test_fn(x): + if x > 0: + if x < 5: + return x + else: + return x * x + else: + return x * x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_context_manager(self): + + def test_fn(x): + + with name_scope(''): + return x * x + + with self.compiled_fn(test_fn) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_context_manager_in_conditional(self): + + def test_fn(x): + if x > 0: + with name_scope(''): + return x * x + else: + return x + + with self.compiled_fn(test_fn, name_scope) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def text_conditional_in_context_manager(self): + + def test_fn(x): + with name_scope(''): + if x > 0: + return x * x + else: + return x + + with self.compiled_fn(test_fn) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_no_return(self): + + def test_fn(x): + x *= x + + with self.compiled_fn(test_fn) as result: + self.assertEqual(test_fn(2), result.test_fn(2)) + + def test_nested_functiondefs(self): + + def test_fn(x): + + def inner_fn(y): + if y > 0: + return y * y + else: + return y + + return inner_fn(x) + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_loop(self): + + def test_fn(x): + for _ in range(10): + return x + return x + + with self.assertRaises(ValueError): + self.compiled_fn(test_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD similarity index 80% rename from tensorflow/contrib/py2tf/impl/BUILD rename to tensorflow/contrib/autograph/impl/BUILD index 90ffabbc9bf4524ec2ebf54b6dd847bd8768a486..e468176da1724d8a7ce62647dc3c4b656c71affb 100644 --- a/tensorflow/contrib/py2tf/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -25,10 +25,10 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/contrib/py2tf/converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -40,8 +40,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":impl", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/py2tf/impl/api.py b/tensorflow/contrib/autograph/impl/api.py similarity index 53% rename from tensorflow/contrib/py2tf/impl/api.py rename to tensorflow/contrib/autograph/impl/api.py index 29d2e038a73c8cac89c121ec65c32f0d4f68aff6..1c4fcaa62228232e8dddf9b6c0e845e13fa3ae8b 100644 --- a/tensorflow/contrib/py2tf/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -20,13 +20,20 @@ from __future__ import print_function from functools import wraps +from enum import Enum + +# pylint:disable=g-bad-import-order import gast import six - -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.impl import conversion -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect @@ -35,55 +42,6 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) -def graph_ready(f): - """No-op decorator that explicitly marks a function as graph-ready. - - Graph-ready functions are assumed to not need any conversion. - - Args: - f: Any callable. - Returns: - f itself. - """ - setattr(f, '__pyct_is_compile_decorator', True) - return f - - -def convert_inline(f, *args, **kwargs): - """Shorthand to convert and call a function. - - For example, the following two statements are equivalent: - - @convert() - def foo(): - ... - foo(bar) - - def foo(): - ... - convert_inline(foo, bar) - - Args: - f: Function to convert. Only this call will be converted. - *args: Passed through to f. - **kwargs: Passed through to f, with the following exceptions: - * arg_value_hints: A dict mapping parameter names to objects that can - hint at the type of those parameters. - - Returns: - The result of the converted f applied to args and kwargs. - """ - if 'arg_value_hints' in kwargs: - arg_value_hints = kwargs['arg_value_hints'] - del kwargs['arg_value_hints'] - else: - arg_value_hints = None - if tf_inspect.ismethod(f): - # When converting methods, the result is still an unbound function. - args = (f.__self__,) + args - return convert(arg_value_hints)(f)(*args, **kwargs) - - def convert(recursive=False, verbose=False, arg_types=None): """Decorator that compiles a function to graph mode. @@ -110,28 +68,56 @@ def convert(recursive=False, verbose=False, arg_types=None): @wraps(f) def wrapper(*args, **kwargs): - """Wrapper that calls the compiled version of the wrapped function.""" - partial_types = () - arg_values = {} - arg_names = tf_inspect.getargspec(f)[0] - for name, arg in zip(arg_names, args): - arg_values[name] = arg - arg_class = arg.__class__ - # If arg_value_hints specifies any name, use that instead. - if name not in arg_types: - arg_types[name] = (arg_class.__name__, arg_class) - if name == 'self' and tf_inspect.isclass(arg_class): - # Annotated methods need to specify that their owner type is partial, - # otherwise other members they call will not be converted. - partial_types = (arg_class,) - wrapped = to_graph( - f, - recursive=recursive, - verbose=verbose, - arg_values=arg_values, - arg_types=arg_types, - partial_types=partial_types) - return wrapped(*args, **kwargs) + return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +class RunMode(Enum): + GRAPH = 1 + PY_FUNC = 2 + + +def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): + """Decorator that suppresses compilation of a function. + + Args: + run_as: RunMode value. Whether to run the function as-is, or wrap it into + a py_func. + return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or + empty list or tuple will create a dummy return value that can be used + to set control dependencies. + + Returns: + A decorator that wraps the original function. + """ + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def graph_wrapper(*args, **kwargs): + return f(*args, **kwargs) + + @wraps(f) + def py_func_wrapper(*args, **kwargs): + if kwargs: + raise NotImplementedError( + 'RunMode.PY_FUNC does not yet support kwargs') + # TODO(mdan): Add support for kwargs. + return py_func.wrap_py_func( + f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) + + if run_as == RunMode.GRAPH: + wrapper = graph_wrapper + elif run_as == RunMode.PY_FUNC: + wrapper = py_func_wrapper + else: + raise ValueError('unknown value for run_as: %s' % run_as) # Sometimes the decorator is just desugared, making it impossible to detect. # This attribute makes detection easier. @@ -141,6 +127,78 @@ def convert(recursive=False, verbose=False, arg_types=None): return decorator +def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): + """Compiles a function call inline.""" + # TODO(mdan): This needs cleanup. + # In particular, we may want to avoid renaming functions altogether. + + if conversion.is_whitelisted_for_graph(f): + return f(*args, **kwargs) + + unknown_arg_value = object() # Sentinel for arguments of unknown value + + if tf_inspect.isbuiltin(f): + return builtins.dynamic_builtin(f, *args, **kwargs) + + if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): + # Regular functions + target_entity = f + arg_map_target = f + effective_args = args + f_class = inspect_utils.getmethodclass(f) + + if f_class is not None: + partial_types = (f_class,) + else: + partial_types = () + + elif tf_inspect.isclass(f): + # Constructors + target_entity = f + arg_map_target = f.__init__ + effective_args = (unknown_arg_value,) + args + partial_types = () + + elif hasattr(f, '__call__') and hasattr(f, '__class__'): + # Callable objects + target_entity = f.__call__ + arg_map_target = f.__call__ + effective_args = (f,) + args + partial_types = (f.__class__,) + + else: + NotImplementedError('unknown callable type "%s"' % type(f)) + + arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) + for name, arg in arg_values.items(): + if arg is unknown_arg_value: + continue + arg_class = arg.__class__ + # If arg_value_hints specifies any name, use that instead. + if name not in arg_types: + arg_types[name] = (arg_class.__name__, arg_class) + + # When called from within a decorator, this is the only indication that + # the function is a method - it appears that the decorator is applied + # before the method is bound. + if not partial_types: + if 'self' in arg_values: + if tf_inspect.isclass(arg_values['self'].__class__): + partial_types = (arg_values['self'].__class__,) + elif 'cls' in arg_values: + if tf_inspect.isclass(arg_values['cls']): + partial_types = (arg_values['cls'],) + + converted_f = to_graph( + target_entity, + recursive=recursive, + verbose=verbose, + arg_values=arg_values, + arg_types=arg_types, + partial_types=partial_types) + return converted_f(*effective_args, **kwargs) + + def to_graph(e, recursive=True, verbose=False, @@ -174,14 +232,14 @@ def to_graph(e, """ conversion_map = conversion.ConversionMap( recursive=recursive, - nocompile_decorators=(convert, graph_ready, convert_inline), + nocompile_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, api_module=tf_inspect.getmodule(to_graph)) _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) module = gast.Module([]) for import_line in config.COMPILED_IMPORT_STATEMENTS: - module.body.append(parser.parse_str(import_line)) + module.body.extend(parser.parse_str(import_line).body) for dep in conversion_map.dependency_cache.values(): module.body.append(dep) compiled_node, compiled_src = compiler.ast_to_object(module) @@ -189,7 +247,7 @@ def to_graph(e, # The compiled code should see everything the entry function saw. # TODO(mdan): This might not work well if the call tree spans modules? if tf_inspect.isfunction(e): - compiled_node.__dict__.update(six.get_function_globals(e)) + compiled_node.__dict__.update(inspect_utils.getnamespace(e)) compiled_fn = getattr(compiled_node, name) if verbose: @@ -221,7 +279,7 @@ def to_code(e, """ conversion_map = conversion.ConversionMap( recursive=recursive, - nocompile_decorators=(convert, graph_ready, convert_inline), + nocompile_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, api_module=tf_inspect.getmodule(to_graph)) conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) diff --git a/tensorflow/contrib/py2tf/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py similarity index 74% rename from tensorflow/contrib/py2tf/impl/api_test.py rename to tensorflow/contrib/autograph/impl/api_test.py index 51e99864adeba9c928b6e74eb759054ef1d1d78c..ee2d301d7562ef5ba6bc7ca6d013b99dec78d4c3 100644 --- a/tensorflow/contrib/py2tf/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -18,23 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import api -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.pyct import parser +import numpy as np + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +tf = utils.fake_tf() + + class ApiTest(test.TestCase): def setUp(self): - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) config.COMPILED_IMPORT_STATEMENTS = ( - 'from tensorflow.python.framework ' - 'import ops as tf', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') + 'from __future__ import print_function', + 'from tensorflow.contrib.autograph import utils as ' + 'autograph_utils', 'tf = autograph_utils.fake_tf()') def test_decorator_recurses(self): @@ -47,7 +51,7 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -63,11 +67,11 @@ class ApiTest(test.TestCase): class TestClass(object): def called_member(self, a): - return math_ops.negative(a) + return tf.negative(a) @api.convert(recursive=False) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -78,17 +82,17 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_converted(self): + def test_decorator_calls_unconverted_graph(self): class TestClass(object): - @api.graph_ready + @api.do_not_convert(api.RunMode.GRAPH) def called_member(self, a): - return math_ops.negative(a) + return tf.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -99,20 +103,23 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_decorated(self): + def test_decorator_calls_unconverted_py_func(self): class TestClass(object): - @api.convert() + @api.do_not_convert( + api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1)) def called_member(self, a): - if a < 0: - a = -a - return a + return np.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= self.called_member(a) + while tf.reduce_sum(x) > s: + y = self.called_member(a) + # set_shape works around while_loop's limitations. + # TODO(mdan): Allow specifying shapes (or ShapeLike) instead. + y.set_shape(a.shape) + x //= y return x tc = TestClass() @@ -122,10 +129,11 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_convert_call_site_decorator(self): + def test_decorator_calls_decorated(self): class TestClass(object): + @api.convert() def called_member(self, a): if a < 0: a = -a @@ -133,8 +141,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= api.convert_inline(self.called_member, a) + while tf.reduce_sum(x) > s: + x //= self.called_member(a) return x tc = TestClass() @@ -144,17 +152,20 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_graph_ready_call_site_decorator(self): + def test_convert_call_site_decorator(self): class TestClass(object): def called_member(self, a): - return math_ops.negative(a) + if a < 0: + a = -a + return a @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= api.graph_ready(self.called_member(a)) + while tf.reduce_sum(x) > s: + x //= api.converted_call(self.called_member, False, False, {}, self, + a) return x tc = TestClass() @@ -165,8 +176,9 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_to_graph_basic(self): + def test_fn(x, s): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= 2 return x @@ -177,15 +189,16 @@ class ApiTest(test.TestCase): self.assertListEqual([1, 2], sess.run(x).tolist()) def test_to_code_basic(self): + def test_fn(x, s): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x /= 2 return x compiled_code = api.to_code(test_fn) # Just check for some key words and that it is parseable Python code. - self.assertRegexpMatches(compiled_code, 'py2tf_utils\\.run_while') + self.assertRegexpMatches(compiled_code, 'autograph_utils\\.run_while') self.assertIsNotNone(parser.parse_str(compiled_code)) diff --git a/tensorflow/contrib/py2tf/impl/config.py b/tensorflow/contrib/autograph/impl/config.py similarity index 69% rename from tensorflow/contrib/py2tf/impl/config.py rename to tensorflow/contrib/autograph/impl/config.py index c90e85c96b690b7781358b173e5d83fe60e29c00..543c1486e657f4e7b16e5723cc294c09ebbcec00 100644 --- a/tensorflow/contrib/py2tf/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.autograph import utils PYTHON_LITERALS = { @@ -31,16 +31,20 @@ PYTHON_LITERALS = { DEFAULT_UNCOMPILED_MODULES = set(( ('tensorflow',), (utils.__name__,), + + # All of tensorflow's subpackages. Unlike the root tf module, they don't + # have well-known names. Not refering to the module directly to avoid + # circular imports. + ( + utils.__name__[:-len('.contrib.autograph.utils')],), )) NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) # TODO(mdan): Also allow controlling the generated names (for testability). -# TODO(mdan): Make sure copybara renames the reference below. COMPILED_IMPORT_STATEMENTS = ( - 'from __future__ import print_function', - 'import tensorflow as tf', - 'from tensorflow.contrib.py2tf.impl import api as ' - 'py2tf_api', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') + 'from __future__ import print_function', 'import tensorflow as tf', + 'from tensorflow.contrib.autograph.impl import api as ' + 'autograph_api', + 'from tensorflow.contrib.autograph import utils as ' + 'autograph_utils') diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py similarity index 69% rename from tensorflow/contrib/py2tf/impl/conversion.py rename to tensorflow/contrib/autograph/impl/conversion.py index 4bf698f2072256f9b87d4b36159fd372536ec7a1..62a49cd92d835fb942f48354041cb0ab03d02c97 100644 --- a/tensorflow/contrib/py2tf/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -19,28 +19,32 @@ from __future__ import division from __future__ import print_function import gast -import six - -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import break_statements -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import control_flow -from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.converters import for_loops -from tensorflow.contrib.py2tf.converters import logical_expressions -from tensorflow.contrib.py2tf.converters import name_scopes -from tensorflow.contrib.py2tf.converters import side_effect_guards -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.impl import naming -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.converters import for_loops +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info +from tensorflow.contrib.autograph.utils import type_hints from tensorflow.python.util import tf_inspect @@ -48,7 +52,9 @@ from tensorflow.python.util import tf_inspect class ConversionMap(object): - """ConversionMaps keep track of converting function hierarchies. + """ConversionMap keeps track of converting function hierarchies. + + This object is mutable, and is updated as functions are converted. Attributes: recursive: Whether to recusrively convert any functions that the decorator @@ -97,6 +103,24 @@ class ConversionMap(object): self.dependency_cache[original_entity] = converted_ast +def is_whitelisted_for_graph(o): + """Check whether an entity is whitelisted for use in graph mode. + + Examples of whitelisted entities include all members of the tensorflow + package. + + Args: + o: A Python entity. + Returns: + Boolean + """ + m = tf_inspect.getmodule(o) + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + return True + return False + + def entity_to_graph(o, conversion_map, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. @@ -136,14 +160,20 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): conversion_map.add_to_cache(o, node) if conversion_map.recursive: - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: - if (hasattr(obj, 'im_class') and - getattr(obj, 'im_class') not in conversion_map.partial_types): - # Class members are converted with their objects, unless they're - # only converted partially. - continue - entity_to_graph(obj, conversion_map, {}, {}) + while True: + candidate = None + for obj in conversion_map.name_map.keys(): + if obj not in conversion_map.dependency_cache: + candidate = obj + break + if candidate is None: + break + if (hasattr(candidate, 'im_class') and + getattr(candidate, 'im_class') not in conversion_map.partial_types): + # Class members are converted with their objects, unless they're + # only converted partially. + continue + entity_to_graph(candidate, conversion_map, {}, {}) return node, new_name @@ -151,11 +181,12 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): def class_to_graph(c, conversion_map): """Specialization of `entity_to_graph` for classes.""" converted_members = {} - members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod) + method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) + members = tf_inspect.getmembers(c, predicate=method_filter) if not members: - raise ValueError('Cannot convert %s: it has no member methods.') + raise ValueError('Cannot convert %s: it has no member methods.' % c) - class_globals = None + class_namespace = None for _, m in members: node, _ = function_to_graph( m, @@ -164,16 +195,16 @@ def class_to_graph(c, conversion_map): arg_types={'self': (c.__name__, c)}, owner_type=c) # TODO(mdan): Do not assume all members have the same view of globals. - if class_globals is None: - class_globals = six.get_function_globals(m) + if class_namespace is None: + class_namespace = inspect_utils.getnamespace(m) converted_members[m] = node - namer = conversion_map.new_namer(class_globals) + namer = conversion_map.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) node = gast.ClassDef( class_name, bases=[], keywords=[], - body=converted_members.values(), + body=list(converted_members.values()), decorator_list=[]) return node, class_name @@ -182,19 +213,19 @@ def class_to_graph(c, conversion_map): def _add_self_references(namespace, api_module): """Self refs are only required for analysis and are not used directly.""" # Manually add the utils namespace which may be used from generated code. - if 'py2tf_util' not in namespace: - namespace['py2tf_utils'] = utils - elif namespace['py2tf_utils'] != utils: + if 'autograph_util' not in namespace: + namespace['autograph_utils'] = utils + elif namespace['autograph_utils'] != utils: raise ValueError( - 'The module name "py2tf_utils" is reserved and may not be used.') + 'The module name "autograph_utils" is reserved and may not be used.') # We also make reference to the api module for dynamic conversion, but # to avoid circular references we don't import it here. - if 'py2tf_api' not in namespace: - namespace['py2tf_api'] = api_module - elif namespace['py2tf_api'] != api_module: + if 'autograph_api' not in namespace: + namespace['autograph_api'] = api_module + elif namespace['autograph_api'] != api_module: raise ValueError( - 'The module name "py2tf_api" is reserved and may not be used.') + 'The module name "autograph_api" is reserved and may not be used.') def function_to_graph(f, conversion_map, arg_values, arg_types, @@ -202,19 +233,11 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] - namespace = six.get_function_globals(f) - - # This is needed for non-global functions. - closure = six.get_function_closure(f) - if closure: - for e in closure: - if callable(e.cell_contents): - fn = e.cell_contents - namespace[fn.__name__] = fn + namespace = inspect_utils.getnamespace(f) _add_self_references(namespace, conversion_map.api_module) - namer = conversion_map.new_namer(namespace) + ctx = context.EntityContext( namer=namer, source_code=source, @@ -223,7 +246,8 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, arg_values=arg_values, arg_types=arg_types, owner_type=owner_type, - recursive=conversion_map.recursive) + recursive=conversion_map.recursive, + type_annotation_func=type_hints.set_element_type) node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py @@ -276,10 +300,15 @@ def node_to_graph(node, ctx, nocompile_decorators): # to re-run the analysis. node = _static_analysis_pass(node, ctx) + + # TODO(mdan): Clean this up. + # Some intermediate analyses are not required, and some comments got orphaned. + # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? ctx.source_code = None + node = ifexp.transform(node, ctx) node, deps = decorators.transform(node, nocompile_decorators) node = break_statements.transform(node, ctx) node = asserts.transform(node, ctx) @@ -291,6 +320,10 @@ def node_to_graph(node, ctx, nocompile_decorators): ctx.namespace['len'] = len node = _static_analysis_pass(node, ctx) + node = single_return.transform(node, ctx) + + node = _static_analysis_pass(node, ctx) + node = lists.transform(node, ctx) node = for_loops.transform(node, ctx) # for_loops may insert new global references. node = builtin_functions.transform(node, ctx) @@ -302,7 +335,7 @@ def node_to_graph(node, ctx, nocompile_decorators): # control_flow may create new symbols and change scopes. node = _static_analysis_pass(node, ctx) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, ctx) node = side_effect_guards.transform(node, ctx) node = name_scopes.transform(node, ctx) diff --git a/tensorflow/contrib/py2tf/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py similarity index 82% rename from tensorflow/contrib/py2tf/impl/conversion_test.py rename to tensorflow/contrib/autograph/impl/conversion_test.py index 7816f958575d58236007fc7f0f1f3d1f3a99c4cf..7066739eb87f89ab98e906b10dab62baeaa2de8e 100644 --- a/tensorflow/contrib/py2tf/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -20,12 +20,23 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.impl import conversion +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ConversionTest(test.TestCase): + def test_is_whitelisted_for_graph(self): + + def test_fn(): + return constant_op.constant(1) + + self.assertFalse(conversion.is_whitelisted_for_graph(test_fn)) + self.assertTrue(conversion.is_whitelisted_for_graph(utils)) + self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) + def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): conversion_map = conversion.ConversionMap(True, (), (), None) diff --git a/tensorflow/contrib/py2tf/impl/naming.py b/tensorflow/contrib/autograph/impl/naming.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming.py rename to tensorflow/contrib/autograph/impl/naming.py index 51326091de13715c32d0a79279f1d3274e48ad10..1facaa0ca0ebcc6d4281e7c92a462ceeb00b453a 100644 --- a/tensorflow/contrib/py2tf/impl/naming.py +++ b/tensorflow/contrib/autograph/impl/naming.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import qual_names class Namer(object): diff --git a/tensorflow/contrib/py2tf/impl/naming_test.py b/tensorflow/contrib/autograph/impl/naming_test.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming_test.py rename to tensorflow/contrib/autograph/impl/naming_test.py index beb4e54937bbb91b19157c9b9e3c528353206c62..73fc0894655cb49e4f61bf8ca51995b06feb3072 100644 --- a/tensorflow/contrib/py2tf/impl/naming_test.py +++ b/tensorflow/contrib/autograph/impl/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import naming +from tensorflow.contrib.autograph.impl import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD similarity index 100% rename from tensorflow/contrib/py2tf/pyct/BUILD rename to tensorflow/contrib/autograph/pyct/BUILD diff --git a/tensorflow/contrib/py2tf/pyct/__init__.py b/tensorflow/contrib/autograph/pyct/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/__init__.py rename to tensorflow/contrib/autograph/pyct/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py similarity index 92% rename from tensorflow/contrib/py2tf/pyct/anno.py rename to tensorflow/contrib/autograph/pyct/anno.py index 7a0528b6d0b65b6604930b7a13d8493af9d61f02..cc4a7edf02ed7556c9a552d8730e4c7875038c83 100644 --- a/tensorflow/contrib/py2tf/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -70,3 +70,8 @@ def delanno(node, key, field_name='___pyct_anno'): if not annotations: delattr(node, field_name) node._fields = tuple(f for f in node._fields if f != field_name) + + +def copyanno(from_node, to_node, key, field_name='___pyct_anno'): + if hasanno(from_node, key, field_name): + setanno(to_node, key, getanno(from_node, key, field_name), field_name) diff --git a/tensorflow/contrib/py2tf/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py similarity index 77% rename from tensorflow/contrib/py2tf/pyct/anno_test.py rename to tensorflow/contrib/autograph/pyct/anno_test.py index ff40bfe1f50ae731648afdf509c26c3a70d3f6cb..1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d 100644 --- a/tensorflow/contrib/py2tf/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -20,10 +20,13 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno from tensorflow.python.platform import test +# TODO(mdan): Consider strong types instead of primitives. + + class AnnoTest(test.TestCase): def test_basic(self): @@ -42,6 +45,17 @@ class AnnoTest(test.TestCase): with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + def test_copyanno(self): + node_1 = ast.Name() + anno.setanno(node_1, 'foo', 3) + + node_2 = ast.Name() + anno.copyanno(node_1, node_2, 'foo') + anno.copyanno(node_1, node_2, 'bar') + + self.assertTrue(anno.hasanno(node_2, 'foo')) + self.assertFalse(anno.hasanno(node_2, 'bar')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py similarity index 87% rename from tensorflow/contrib/py2tf/pyct/ast_util.py rename to tensorflow/contrib/autograph/pyct/ast_util.py index f916775b9cf3cec960ec2896c334f1d737862205..4f76a695228f7d84b80b2e4b03801e15e94b8f11 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -22,7 +22,7 @@ import ast import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno class CleanCopier(gast.NodeVisitor): @@ -84,7 +84,10 @@ class SymbolRenamer(gast.NodeTransformer): return self._process(node) def visit_Attribute(self, node): - return self._process(node) + if anno.hasanno(node, anno.Basic.QN): + return self._process(node) + # Attributes of dynamic objects will not have a QN. + return self.generic_visit(node) def rename_symbols(node, name_map): @@ -94,3 +97,12 @@ def rename_symbols(node, name_map): elif isinstance(node, tuple): return tuple(renamer.visit(n) for n in node) return renamer.visit(node) + + +def keywords_to_dict(keywords): + keys = [] + values = [] + for kw in keywords: + keys.append(gast.Str(kw.arg)) + values.append(kw.value) + return gast.Dict(keys=keys, values=values) diff --git a/tensorflow/contrib/py2tf/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py similarity index 72% rename from tensorflow/contrib/py2tf/pyct/ast_util_test.py rename to tensorflow/contrib/autograph/pyct/ast_util_test.py index e0b00c178168f96e656c57cc75a76e6da8af1d8a..8faf92c705d997db298dbb1115981fd9da26372d 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.python.platform import test @@ -33,15 +35,15 @@ class AstUtilTest(test.TestCase): ast.Name('b', ast.Load()), ast.Attribute(ast.Name('b', None), 'c', ast.Store()), ast.Attribute( - ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', - None) + ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None) ], None) node = qual_names.resolve(node) node = ast_util.rename_symbols( - node, - { - qual_names.QN('a'): qual_names.QN('renamed_a'), - qual_names.QN('b.c'): qual_names.QN('renamed_b_c'), + node, { + qual_names.QN('a'): + qual_names.QN('renamed_a'), + qual_names.QN(qual_names.QN('b'), attr='c'): + qual_names.QN('renamed_b_c'), }) self.assertEqual(node.elts[0].id, 'renamed_a') @@ -74,6 +76,17 @@ class AstUtilTest(test.TestCase): self.assertFalse(ret is new_node.body[0]) self.assertFalse(hasattr(new_node.body[0], '__foo')) + def test_keywords_to_dict(self): + keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords + d = ast_util.keywords_to_dict(keywords) + # Make sure we generate a usable dict node by attaching it to a variable and + # compiling everything. + output = parser.parse_str('b = 3') + output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) + result, _ = compiler.ast_to_object(output) + self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + print(d) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py similarity index 90% rename from tensorflow/contrib/py2tf/pyct/compiler.py rename to tensorflow/contrib/autograph/pyct/compiler.py index 51cf6930e8bcb3728ee55bf5d4781f01a5ef73bd..24c4517afa89147101f80af3ef60237132c1144c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -31,7 +31,7 @@ import astor import gast -def ast_to_source(node, indentation): +def ast_to_source(node, indentation=' '): """Return the source code of given AST.""" if isinstance(node, gast.AST): node = gast.gast_to_ast(node) @@ -39,7 +39,10 @@ def ast_to_source(node, indentation): astor.string_repr.pretty_string) generator.visit(node) generator.result.append('\n') - return astor.source_repr.pretty_source(generator.result).lstrip() + # In some versions of Python, literals may appear as actual values. This + # ensures everything is string. + code = map(str, generator.result) + return astor.source_repr.pretty_source(code).lstrip() def ast_to_object( diff --git a/tensorflow/contrib/py2tf/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py similarity index 83% rename from tensorflow/contrib/py2tf/pyct/compiler_test.py rename to tensorflow/contrib/autograph/pyct/compiler_test.py index c1f84238efa7dd6fc0748748a2cb4f074572b4c6..98cdc1506b6aced603df99662f1468687a55f92c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -22,12 +22,29 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect class CompilerTest(test.TestCase): + def test_parser_compile_idempotent(self): + + def test_fn(x): + a = True + b = '' + if a: + b = x + 1 + return b + + self.assertEqual( + textwrap.dedent(tf_inspect.getsource(test_fn)), + tf_inspect.getsource( + compiler.ast_to_object( + parser.parse_entity(test_fn)[0].body[0])[0].test_fn)) + def test_ast_to_source(self): node = gast.If( test=gast.Num(1), diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py similarity index 87% rename from tensorflow/contrib/py2tf/pyct/context.py rename to tensorflow/contrib/autograph/pyct/context.py index 4fcf2a687d58af951adfc0dcf52ff7303d2b17f5..b34015cfd2888f0dbeb6492b9e7335d561bf4763 100644 --- a/tensorflow/contrib/py2tf/pyct/context.py +++ b/tensorflow/contrib/autograph/pyct/context.py @@ -22,6 +22,8 @@ from __future__ import print_function class EntityContext(object): """Contains information about an entity, like source code. + In general, objects of this class should be considered immutable. + Attributes: namer: Namer that matches the contract of all converters. source_code: The entity's source code. @@ -33,8 +35,9 @@ class EntityContext(object): owner_type: The surrounding class type of the function, if present. """ + # TODO(mdan): Remove the default and update tests. def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, owner_type, recursive): + arg_types, owner_type, recursive, type_annotation_func=None): self.namer = namer self.source_code = source_code self.source_file = source_file @@ -43,3 +46,4 @@ class EntityContext(object): self.arg_types = {} if arg_types is None else arg_types self.owner_type = owner_type self.recursive = recursive + self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d19c6ed75e0f0651781d6e1ed80f7be11fb8a5a4 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Live entity inspection utilities. + +This module contains whatever inspect doesn't offer out of the box. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import six + +from tensorflow.python.util import tf_inspect + + +def getnamespace(f): + """Returns the complete namespace of a function. + + Namespace is defined here as the mapping of all non-local variables to values. + This includes the globals and the closure variables. Note that this captures + the entire globals collection of the function, and may contain extra symbols + that it does not actually use. + + Args: + f: User defined function. + Returns: + A dict mapping symbol names to values. + """ + namespace = dict(six.get_function_globals(f)) + closure = six.get_function_closure(f) + freevars = six.get_function_code(f).co_freevars + if freevars and closure: + for name, cell in zip(freevars, closure): + namespace[name] = cell.cell_contents + return namespace + + +def getmethodclass(m): + """Resolves a function's owner, e.g. a method's class. + + Note that this returns the object that the function was retrieved from, not + necessarily the class where it was defined. + + This function relies on Python stack frame support in the interpreter, and + has the same limitations that inspect.currentframe. + + Limitations. This function will only work correctly if the owned class is + visible in the caller's global or local variables. + + Args: + m: A user defined function + + Returns: + The class that this function was retrieved from, or None if the function + is not an object or class method, or the class that owns the object or + method is not visible to m. + + Raises: + ValueError: if the class could not be resolved for any unexpected reason. + """ + + # Instance method and class methods: should be bound to a non-null "self". + # If self is a class, then it's a class method. + if hasattr(m, '__self__'): + if m.__self__: + if tf_inspect.isclass(m.__self__): + return m.__self__ + return type(m.__self__) + + # Class, static and unbound methods: search all defined classes in any + # namespace. This is inefficient but more robust method. + owners = [] + caller_frame = tf_inspect.currentframe().f_back + try: + # TODO(mdan): This doesn't consider cell variables. + # TODO(mdan): This won't work if the owner is hidden inside a container. + # Cell variables may be pulled using co_freevars and the closure. + for v in itertools.chain(caller_frame.f_locals.values(), + caller_frame.f_globals.values()): + if hasattr(v, m.__name__): + candidate = getattr(v, m.__name__) + # Py2 methods may be bound or unbound, extract im_func to get the + # underlying function. + if hasattr(candidate, 'im_func'): + candidate = candidate.im_func + if hasattr(m, 'im_func'): + m = m.im_func + if candidate is m: + owners.append(v) + finally: + del caller_frame + + if owners: + if len(owners) == 1: + return owners[0] + + # If multiple owners are found, and are not subclasses, raise an error. + owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners) + for o in owner_types: + if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)): + return o + raise ValueError('Found too many owners of %s: %s' % (m, owners)) + + return None diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py similarity index 63% rename from tensorflow/contrib/py2tf/pyct/inspect_utils_test.py rename to tensorflow/contrib/autograph/pyct/inspect_utils_test.py index 5d92e75b1899e99c983325c9a474df43feb17d55..ddca6f963b8abadd621c544a79935c69326bf65e 100644 --- a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -20,7 +20,9 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.py2tf.pyct import inspect_utils +import six + +from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.python.platform import test @@ -76,6 +78,10 @@ def free_function(): pass +def factory(): + return free_function + + def free_factory(): def local_function(): pass @@ -84,87 +90,87 @@ def free_factory(): class InspectUtilsTest(test.TestCase): - def test_getcallargs_constructor(self): - - class TestSuperclass(object): + def test_getnamespace_globals(self): + ns = inspect_utils.getnamespace(factory) + self.assertEqual(ns['free_function'], free_function) - def __init__(self, x): - pass - - class TestCallable(TestSuperclass): - pass + def test_getnamespace_hermetic(self): - self.assertDictEqual({ - 'x': 1 - }, inspect_utils.getcallargs(TestCallable, 1)) + # Intentionally hiding the global function to make sure we don't overwrite + # it in the global namespace. + free_function = object() # pylint:disable=redefined-outer-name - def test_getcallargs_object(self): + def test_fn(): + return free_function - class TestCallable(object): + ns = inspect_utils.getnamespace(test_fn) + globs = six.get_function_globals(test_fn) + self.assertTrue(ns['free_function'] is free_function) + self.assertFalse(globs['free_function'] is free_function) - def __call__(self, x): - pass + def test_getnamespace_locals(self): - obj = TestCallable() - self.assertDictEqual({ - 'self': obj, - 'x': 1 - }, inspect_utils.getcallargs(obj, 1)) + def called_fn(): + return 0 - def test_getcallargs_function(self): + closed_over_list = [] + closed_over_primitive = 1 - def test_fn(x): - return x + 1 + def local_fn(): + closed_over_list.append(1) + local_var = 1 + return called_fn() + local_var + closed_over_primitive - self.assertDictEqual({ - 'x': 1 - }, inspect_utils.getcallargs(test_fn, 1)) + ns = inspect_utils.getnamespace(local_fn) + self.assertEqual(ns['called_fn'], called_fn) + self.assertEqual(ns['closed_over_list'], closed_over_list) + self.assertEqual(ns['closed_over_primitive'], closed_over_primitive) + self.assertTrue('local_var' not in ns) def test_getmethodclass(self): self.assertEqual( - inspect_utils.getmethodclass(free_function, {}), None) + inspect_utils.getmethodclass(free_function), None) self.assertEqual( - inspect_utils.getmethodclass(free_factory(), {}), None) + inspect_utils.getmethodclass(free_factory()), None) - ns = {'TestClass': TestClass} self.assertEqual( - inspect_utils.getmethodclass(TestClass.member_function, ns), + inspect_utils.getmethodclass(TestClass.member_function), TestClass) self.assertEqual( - inspect_utils.getmethodclass(TestClass.decorated_member, ns), + inspect_utils.getmethodclass(TestClass.decorated_member), TestClass) self.assertEqual( - inspect_utils.getmethodclass(TestClass.fn_decorated_member, ns), + inspect_utils.getmethodclass(TestClass.fn_decorated_member), TestClass) self.assertEqual( - inspect_utils.getmethodclass(TestClass.wrap_decorated_member, ns), + inspect_utils.getmethodclass(TestClass.wrap_decorated_member), TestClass) self.assertEqual( - inspect_utils.getmethodclass(TestClass.static_method, ns), + inspect_utils.getmethodclass(TestClass.static_method), TestClass) self.assertEqual( - inspect_utils.getmethodclass(TestClass.class_method, ns), + inspect_utils.getmethodclass(TestClass.class_method), TestClass) test_obj = TestClass() self.assertEqual( - inspect_utils.getmethodclass(test_obj.member_function, ns), + inspect_utils.getmethodclass(test_obj.member_function), TestClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.decorated_member, ns), + inspect_utils.getmethodclass(test_obj.decorated_member), TestClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.fn_decorated_member, ns), + inspect_utils.getmethodclass(test_obj.fn_decorated_member), TestClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.wrap_decorated_member, ns), + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), TestClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.static_method, ns), + inspect_utils.getmethodclass(test_obj.static_method), TestClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.class_method, ns), + inspect_utils.getmethodclass(test_obj.class_method), TestClass) def test_getmethodclass_locals(self): @@ -190,34 +196,33 @@ class InspectUtilsTest(test.TestCase): pass self.assertEqual( - inspect_utils.getmethodclass(local_function, {}), None) + inspect_utils.getmethodclass(local_function), None) - ns = {'LocalClass': LocalClass} self.assertEqual( - inspect_utils.getmethodclass(LocalClass.member_function, ns), + inspect_utils.getmethodclass(LocalClass.member_function), LocalClass) self.assertEqual( - inspect_utils.getmethodclass(LocalClass.decorated_member, ns), + inspect_utils.getmethodclass(LocalClass.decorated_member), LocalClass) self.assertEqual( - inspect_utils.getmethodclass(LocalClass.fn_decorated_member, ns), + inspect_utils.getmethodclass(LocalClass.fn_decorated_member), LocalClass) self.assertEqual( - inspect_utils.getmethodclass(LocalClass.wrap_decorated_member, ns), + inspect_utils.getmethodclass(LocalClass.wrap_decorated_member), LocalClass) test_obj = LocalClass() self.assertEqual( - inspect_utils.getmethodclass(test_obj.member_function, ns), + inspect_utils.getmethodclass(test_obj.member_function), LocalClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.decorated_member, ns), + inspect_utils.getmethodclass(test_obj.decorated_member), LocalClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.fn_decorated_member, ns), + inspect_utils.getmethodclass(test_obj.fn_decorated_member), LocalClass) self.assertEqual( - inspect_utils.getmethodclass(test_obj.wrap_decorated_member, ns), + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), LocalClass) diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c961efa892df6a21804dae8f52ef64bf99cd409e --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converting code to AST. + +Adapted from Tangent. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +import gast + +from tensorflow.python.util import tf_inspect + + +def parse_entity(entity): + """Returns the AST of given entity.""" + source = tf_inspect.getsource(entity) + source = textwrap.dedent(source) + return parse_str(source), source + + +def parse_str(src): + """Returns the AST of given piece of code.""" + return gast.parse(src) + + +def parse_expression(src): + """Returns the AST of given identifier. + + Args: + src: A piece of code that represents a single Python expression + Returns: + A gast.AST object. + Raises: + ValueError: if src does not consist of a single Expression. + """ + node = parse_str(src) + assert isinstance(node, gast.Module) + if len(node.body) != 1 and not isinstance(node.body[0], gast.Expr): + raise ValueError( + 'Expected a single expression, found instead %s' % node.body) + return node.body[0].value diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/autograph/pyct/parser_test.py similarity index 80% rename from tensorflow/contrib/py2tf/pyct/parser_test.py rename to tensorflow/contrib/autograph/pyct/parser_test.py index f35dfa04c70dc191078248c32f9a04d28133129a..007a4c6fb0393b7235808478d55b3ffa469f85d0 100644 --- a/tensorflow/contrib/py2tf/pyct/parser_test.py +++ b/tensorflow/contrib/autograph/pyct/parser_test.py @@ -20,28 +20,33 @@ from __future__ import print_function import textwrap -from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test -def f(x): - return x + 1 - - class ParserTest(test.TestCase): def test_parse_entity(self): + + def f(x): + return x + 1 + mod, _ = parser.parse_entity(f) self.assertEqual('f', mod.body[0].name) def test_parse_str(self): mod = parser.parse_str( textwrap.dedent(""" - def f(x): - return x + 1 + def f(x): + return x + 1 """)) self.assertEqual('f', mod.body[0].name) + def test_parse_expression(self): + node = parser.parse_expression('a.b') + self.assertEqual('a', node.value.id) + self.assertEqual('b', node.attr) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer.py b/tensorflow/contrib/autograph/pyct/pretty_printer.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/pretty_printer.py rename to tensorflow/contrib/autograph/pyct/pretty_printer.py diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py similarity index 96% rename from tensorflow/contrib/py2tf/pyct/pretty_printer_test.py rename to tensorflow/contrib/autograph/pyct/pretty_printer_test.py index 81e3f47b80b6cb3bb7ba9f4a1787d03df4151a99..0cb48f35760b7b2655eb5cf73017b70e28dae219 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py +++ b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5764a974aac542ddf4a54a9acd36f1afcb0464 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for manipulating qualified names. + +A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite +(e.g. 'foo.bar') syntactic symbols. + +This is *not* related to the __qualname__ attribute used by inspect, which +refers to scopes. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import gast + +from tensorflow.contrib.autograph.pyct import anno + + +class Symbol(collections.namedtuple('Symbol', ['name'])): + """Represents a Python symbol.""" + + +class StringLiteral(collections.namedtuple('StringLiteral', ['value'])): + """Represents a Python string literal.""" + + def __str__(self): + return '\'%s\'' % self.value + + def __repr__(self): + return str(self) + + +class NumberLiteral(collections.namedtuple('NumberLiteral', ['value'])): + """Represents a Python numeric literal.""" + + def __str__(self): + return '%s' % self.value + + def __repr__(self): + return str(self) + + +# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans. +class QN(object): + """Represents a qualified name.""" + + def __init__(self, base, attr=None, subscript=None): + if attr is not None and subscript is not None: + raise ValueError('A QN can only be either an attr or a subscript, not ' + 'both: attr={}, subscript={}.'.format(attr, subscript)) + self._has_attr = False + self._has_subscript = False + + if attr is not None: + if not isinstance(base, QN): + raise ValueError( + 'for attribute QNs, base must be a QN; got instead "%s"' % base) + if not isinstance(attr, str): + raise ValueError('attr may only be a string; got instead "%s"' % attr) + self._parent = base + # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now. + self.qn = (base, attr) + self._has_attr = True + + elif subscript is not None: + if not isinstance(base, QN): + raise ValueError('For subscript QNs, base must be a QN.') + self._parent = base + self.qn = (base, subscript) + self._has_subscript = True + + else: + if not isinstance(base, (str, StringLiteral, NumberLiteral)): + # TODO(mdan): Require Symbol instead of string. + raise ValueError( + 'For simple QNs, base must be a string or a Literal object.') + assert '.' not in base and '[' not in base and ']' not in base + self._parent = None + self.qn = (base,) + + def is_symbol(self): + return isinstance(self.qn[0], str) + + def is_composite(self): + return len(self.qn) > 1 + + def has_subscript(self): + return self._has_subscript + + def has_attr(self): + return self._has_attr + + @property + def parent(self): + if self._parent is None: + raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) + return self._parent + + def __hash__(self): + return hash(self.qn + (self._has_attr, self._has_subscript)) + + def __eq__(self, other): + return (isinstance(other, QN) and self.qn == other.qn and + self.has_subscript() == other.has_subscript() and + self.has_attr() == other.has_attr()) + + def __str__(self): + if self.has_subscript(): + return str(self.qn[0]) + '[' + str(self.qn[1]) + ']' + if self.has_attr(): + return '.'.join(map(str, self.qn)) + else: + return str(self.qn[0]) + + def __repr__(self): + return str(self) + + def ssf(self): + """Simple symbol form.""" + ssfs = [n.ssf() if isinstance(n, QN) else n for n in self.qn] + ssf_string = '' + for i in range(0, len(self.qn) - 1): + if self.has_subscript(): + delimiter = '_sub_' + else: + delimiter = '_' + ssf_string += ssfs[i] + delimiter + return ssf_string + ssfs[-1] + + def ast(self): + # The caller must adjust the context appropriately. + if self.has_subscript(): + return gast.Subscript(self.parent.ast(), gast.Index(self.qn[-1].ast()), + None) + if self.has_attr(): + return gast.Attribute(self.parent.ast(), self.qn[-1], None) + + base = self.qn[0] + if isinstance(base, str): + return gast.Name(base, None, None) + elif isinstance(base, StringLiteral): + return gast.Str(base.value) + elif isinstance(base, NumberLiteral): + return gast.Num(base.value) + else: + assert False, ('the constructor should prevent types other than ' + 'str, StringLiteral and NumberLiteral') + + +class QnResolver(gast.NodeTransformer): + """Annotates nodes with QN information. + + Note: Not using NodeAnnos to avoid circular dependencies. + """ + + def visit_Name(self, node): + node = self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, QN(node.id)) + return node + + def visit_Attribute(self, node): + node = self.generic_visit(node) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr)) + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + s = node.slice + if not isinstance(s, gast.Index): + # TODO(mdan): Support range and multi-dimensional indices. + # Continuing silently because some demos use these. + return node + if isinstance(s.value, gast.Num): + subscript = QN(NumberLiteral(s.value.n)) + elif isinstance(s.value, gast.Str): + subscript = QN(StringLiteral(s.value.s)) + else: + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), + subscript=subscript)) + return node + + +def resolve(node): + return QnResolver().visit(node) diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py new file mode 100644 index 0000000000000000000000000000000000000000..103bd25aa380e9f61ecea9c5298f34df5157d629 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py @@ -0,0 +1,231 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for qual_names module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.qual_names import resolve +from tensorflow.python.platform import test + + +class QNTest(test.TestCase): + + def test_basic(self): + a = QN('a') + self.assertEqual(a.qn, ('a',)) + self.assertEqual(str(a), 'a') + self.assertEqual(a.ssf(), 'a') + self.assertEqual(a.ast().id, 'a') + self.assertFalse(a.is_composite()) + with self.assertRaises(ValueError): + _ = a.parent + + a_b = QN(a, attr='b') + self.assertEqual(a_b.qn, (a, 'b')) + self.assertEqual(str(a_b), 'a.b') + self.assertEqual(a_b.ssf(), 'a_b') + self.assertEqual(a_b.ast().value.id, 'a') + self.assertEqual(a_b.ast().attr, 'b') + self.assertTrue(a_b.is_composite()) + self.assertEqual(a_b.parent.qn, ('a',)) + + def test_subscripts(self): + a = QN('a') + b = QN('b') + a_sub_b = QN(a, subscript=b) + self.assertEqual(a_sub_b.qn, (a, b)) + self.assertEqual(str(a_sub_b), 'a[b]') + self.assertEqual(a_sub_b.ssf(), 'a_sub_b') + self.assertEqual(a_sub_b.ast().value.id, 'a') + self.assertEqual(a_sub_b.ast().slice.value.id, 'b') + self.assertTrue(a_sub_b.is_composite()) + self.assertTrue(a_sub_b.has_subscript()) + self.assertEqual(a_sub_b.parent.qn, ('a',)) + + c = QN('c') + b_sub_c = QN(b, subscript=c) + a_sub_b_sub_c = QN(a, subscript=b_sub_c) + self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c)) + self.assertTrue(a_sub_b.is_composite()) + self.assertTrue(a_sub_b_sub_c.is_composite()) + self.assertTrue(a_sub_b.has_subscript()) + self.assertTrue(a_sub_b_sub_c.has_subscript()) + self.assertEqual(b_sub_c.qn, (b, c)) + self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') + self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c') + self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a') + self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b') + self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c') + self.assertEqual(b_sub_c.ast().slice.value.id, 'c') + self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',)) + with self.assertRaises(ValueError): + QN('a', 'b') + + def test_equality(self): + a = QN('a') + a2 = QN('a') + a_b = QN(a, attr='b') + self.assertEqual(a2.qn, ('a',)) + with self.assertRaises(ValueError): + _ = a.parent + + a_b2 = QN(a, attr='b') + self.assertEqual(a_b2.qn, (a, 'b')) + self.assertEqual(a_b2.parent.qn, ('a',)) + + self.assertTrue(a2 == a) + self.assertFalse(a2 is a) + + self.assertTrue(a_b.parent == a) + self.assertTrue(a_b2.parent == a) + + self.assertTrue(a_b2 == a_b) + self.assertFalse(a_b2 is a_b) + self.assertFalse(a_b2 == a) + a_sub_b = QN(a, subscript='b') + a_sub_b2 = QN(a, subscript='b') + self.assertTrue(a_sub_b == a_sub_b2) + self.assertFalse(a_sub_b == a_b) + + def test_nested_attrs_subscripts(self): + a = QN('a') + b = QN('b') + c = QN('c') + b_sub_c = QN(b, subscript=c) + a_sub_b_sub_c = QN(a, subscript=b_sub_c) + + b_dot_c = QN(b, attr='c') + a_sub__b_dot_c = QN(a, subscript=b_dot_c) + + a_sub_b = QN(a, subscript=b) + a_sub_b__dot_c = QN(a_sub_b, attr='c') + + a_dot_b = QN(a, attr='b') + a_dot_b_sub_c = QN(a_dot_b, subscript=c) + + self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') + self.assertEqual(str(a_sub__b_dot_c), 'a[b.c]') + self.assertEqual(str(a_sub_b__dot_c), 'a[b].c') + self.assertEqual(str(a_dot_b_sub_c), 'a.b[c]') + + self.assertNotEqual(a_sub_b_sub_c, a_sub__b_dot_c) + self.assertNotEqual(a_sub_b_sub_c, a_sub_b__dot_c) + self.assertNotEqual(a_sub_b_sub_c, a_dot_b_sub_c) + + self.assertNotEqual(a_sub__b_dot_c, a_sub_b__dot_c) + self.assertNotEqual(a_sub__b_dot_c, a_dot_b_sub_c) + + self.assertNotEqual(a_sub_b__dot_c, a_dot_b_sub_c) + + def test_hashable(self): + d = {QN('a'): 'a', QN('b'): 'b'} + self.assertEqual(d[QN('a')], 'a') + self.assertEqual(d[QN('b')], 'b') + self.assertTrue(QN('c') not in d) + + def test_literals(self): + a = QN('a') + a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b'))) + a_sub_b = QN(a, subscript=QN('b')) + + self.assertNotEqual(a_sub_str_b, a_sub_b) + self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b)) + + a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3))) + self.assertEqual(a_sub_three.ast().slice.value.n, 3) + + +class QNResolverTest(test.TestCase): + + def assertQNStringIs(self, node, qn_str): + self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) + + def test_resolve(self): + samples = """ + a + a.b + (c, d.e) + [f, (g.h.i)] + j(k, l) + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'a') + self.assertQNStringIs(nodes[1], 'a.b') + self.assertQNStringIs(nodes[2].elts[0], 'c') + self.assertQNStringIs(nodes[2].elts[1], 'd.e') + self.assertQNStringIs(nodes[3].elts[0], 'f') + self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') + self.assertQNStringIs(nodes[4].func, 'j') + self.assertQNStringIs(nodes[4].args[0], 'k') + self.assertQNStringIs(nodes[4].args[1], 'l') + + def test_subscript_resolve(self): + samples = """ + x[i] + x[i.b] + a.b[c] + a.b[x.y] + a[z[c]] + a[b[c[d]]] + a[b].c + a.b.c[d].e.f + a.b[c[d]].e.f + a.b[c[d.e.f].g].h + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'x[i]') + self.assertQNStringIs(nodes[1], 'x[i.b]') + self.assertQNStringIs(nodes[2], 'a.b[c]') + self.assertQNStringIs(nodes[3], 'a.b[x.y]') + self.assertQNStringIs(nodes[4], 'a[z[c]]') + self.assertQNStringIs(nodes[5], 'a[b[c[d]]]') + self.assertQNStringIs(nodes[6], 'a[b].c') + self.assertQNStringIs(nodes[7], 'a.b.c[d].e.f') + self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f') + self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h') + + def test_function_calls(self): + samples = """ + a.b + a.b() + a().b + z[i] + z[i]() + z()[i] + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + self.assertQNStringIs(nodes[0], 'a.b') + self.assertQNStringIs(nodes[1].func, 'a.b') + self.assertQNStringIs(nodes[2].value.func, 'a') + self.assertQNStringIs(nodes[3], 'z[i]') + self.assertQNStringIs(nodes[4].func, 'z[i]') + self.assertQNStringIs(nodes[5].value.func, 'z') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD similarity index 83% rename from tensorflow/contrib/py2tf/pyct/static_analysis/BUILD rename to tensorflow/contrib/autograph/pyct/static_analysis/BUILD index fbfce18c60cca4b105e7de3c3ea7b9c3438f6b2a..d192bc7aabf6ea36d616ff6f2cef60fddd5973b4 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -25,7 +25,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", ], ) @@ -36,7 +36,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", "@gast_archive//:gast", ], @@ -48,7 +48,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -59,7 +59,8 @@ py_test( srcs_version = "PY2AND3", deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py rename to tensorflow/contrib/autograph/pyct/static_analysis/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py similarity index 85% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 02ea6fdeaf78152b6bc48983f79b36f43d4f665d..da6a2f6f0500ebba41b85d06dcc912aae9d68f97 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -22,10 +22,10 @@ import copy import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.qual_names import QN -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). @@ -71,13 +71,33 @@ class Scope(object): tuple(self.modified)) def copy_from(self, other): + """Recursively copies the contents of this scope from another scope.""" + if (self.parent is None) != (other.parent is None): + raise ValueError('cannot copy scopes of different structures') + if other.parent is not None: + self.parent.copy_from(other.parent) + self.isolated = other.isolated self.modified = copy.copy(other.modified) self.created = copy.copy(other.created) self.used = copy.copy(other.used) self.params = copy.copy(other.params) self.returned = copy.copy(other.returned) + @classmethod + def copy_of(cls, other): + if other.parent is not None: + parent = cls.copy_of(other.parent) + else: + parent = None + new_copy = cls(parent) + new_copy.copy_from(other) + return new_copy + def merge_from(self, other): + if (self.parent is None) != (other.parent is None): + raise ValueError('cannot merge scopes of different structures') + if other.parent is not None: + self.parent.merge_from(other.parent) self.modified |= other.modified self.created |= other.created self.used |= other.used @@ -151,6 +171,10 @@ class ActivityAnalizer(transformer.Base): self._in_return_statement = False def _track_symbol(self, node): + # This can happen when we have an attribute (or subscript) on a function + # call. Example: a().b + if not anno.hasanno(node, anno.Basic.QN): + return qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Store): @@ -225,14 +249,12 @@ class ActivityAnalizer(transformer.Base): # modifies the parent state causing the other child blocks to be # processed incorrectly. So we need to checkpoint the parent scope so that # each child sees the same context. - before_parent = Scope(None) - before_parent.copy_from(self.scope) + before_parent = Scope.copy_of(self.scope) after_children = [] for child, scope_name in children: self.scope.copy_from(before_parent) parent = self._process_block_node(parent, child, scope_name) - after_child = Scope(None) - after_child.copy_from(self.scope) + after_child = Scope.copy_of(self.scope) after_children.append(after_child) for after_child in after_children: self.scope.merge_from(after_child) @@ -250,6 +272,15 @@ class ActivityAnalizer(transformer.Base): self.scope = current_scope return node + def visit_With(self, node): + current_scope = self.scope + with_scope = Scope(current_scope, isolated=False) + self.scope = with_scope + self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, with_scope) + self.scope = current_scope + return node + def visit_If(self, node): self.visit(node.test) node = self._process_parallel_blocks(node, diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py similarity index 81% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index 69f5f4fc582f159e46c8b8929a90ca95b724794d..37c28872bb9fc4f0c6f95eec8145101b7a6c83de 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.qual_names import QN -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno from tensorflow.python.platform import test @@ -45,7 +45,7 @@ class ScopeTest(test.TestCase): scope.mark_read(QN('bar')) self.assertFalse(scope.has(QN('bar'))) - def test_copy(self): + def test_copy_from(self): scope = activity.Scope(None) scope.mark_write(QN('foo')) @@ -65,6 +65,17 @@ class ScopeTest(test.TestCase): self.assertTrue(QN('bar') in scope.created) self.assertFalse(QN('bar') in other.created) + def test_copy_of(self): + scope = activity.Scope(None) + scope.mark_read(QN('foo')) + + self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used) + + child_scope = activity.Scope(scope) + child_scope.mark_read(QN('bar')) + + self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used) + def test_nesting(self): scope = activity.Scope(None) scope.mark_write(QN('foo')) @@ -133,7 +144,7 @@ class ActivityAnalizerTest(test.TestCase): anno.getanno(node.body[0].body[2].value, NodeAnno.IS_LOCAL)) # b in return b - def assertScopeIs(self, scope, used, modified, created): + def assertScopeIsRmc(self, scope, used, modified, created): self.assertItemsEqual(used, tuple(str(s) for s in scope.used)) self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified)) self.assertItemsEqual(created, tuple(str(s) for s in scope.created)) @@ -159,7 +170,7 @@ class ActivityAnalizerTest(test.TestCase): print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) # We basically need to detect which variables are captured by the call # arguments. - self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) + self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ()) def test_call(self): @@ -173,7 +184,7 @@ class ActivityAnalizerTest(test.TestCase): call_node = node.body[0].body[2].value # We basically need to detect which variables are captured by the call # arguments. - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) def test_while(self): @@ -187,10 +198,10 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) while_node = node.body[0].body[1] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), ('b', 'c'), ('a', 'b', 'c')) @@ -205,9 +216,9 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) for_node = node.body[0].body[1] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), ('b', 'c', '_'), ('a', 'b', 'c', '_')) @@ -226,21 +237,40 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), ('y', 'z')) # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'), ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'), ('x', 'y', 'u'), ('y', 'u')) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - def test_functiondef(self): + def test_nested_if_else_creation(self): + + def test_fn(b): + if b > 0: + if b < 5: + a = b + else: + a = b * b + return a + + node = self._parse_and_analyze(test_fn) + inner_if_node = node.body[0].body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), + ('a',)) + self.assertScopeIsRmc( + anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',), + ('a',)) + + def test_function_def(self): def test_fn(a): @@ -257,11 +287,11 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) fndef_node = node.body[0].body[0] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(fndef_node, NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(fndef_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( 'x', 'y', @@ -284,13 +314,13 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) call_node = node.body[0].body[0].value - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (), ()) if_node = node.body[0].body[1] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ()) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f')) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py similarity index 81% rename from tensorflow/contrib/py2tf/pyct/static_analysis/annos.py rename to tensorflow/contrib/autograph/pyct/static_analysis/annos.py index 2d8e49442364fdd4a4752c8a83a5f3b76117fe57..5254b83ca7c775867fc2ad5ef0a0ad93ac483ba0 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py @@ -34,13 +34,14 @@ class NodeAnno(NoValue): """ # Symbols - + # These flags are boolean. IS_LOCAL = 'Symbol is local to the function scope being analized.' IS_PARAM = 'Symbol is a parameter to the function being analized.' IS_MODIFIED_SINCE_ENTRY = ( 'Symbol has been explicitly replaced in the current function scope.') # Scopes + # Scopes are represented by objects of type activity.Scope. ARGS_SCOPE = 'The scope for the argument list of a function call.' BODY_SCOPE = ( 'The scope for the main body of a statement (True branch for if ' @@ -48,3 +49,10 @@ class NodeAnno(NoValue): ORELSE_SCOPE = ( 'The scope for the orelse body of a statement (False branch for if ' 'statements, orelse body for loops).') + + # Type and Value annotations + # Type annotations are represented by objects of type type_info.Type. + STATIC_INFO = ( + 'The type or value information that should be asserted about the entity ' + 'referenced by the symbol holding this annotation, irrespective of the ' + 'execution context.') diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py similarity index 85% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 9c0a9a9e74eccb3d22840032e8f0c2b81e051e7e..53ae15459097baff918432a493edd7360ebf209d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -25,9 +25,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class LiveValueResolver(transformer.Base): @@ -55,11 +55,19 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - # TODO(mdan): Could live values have FQNs? i.e. 'a'.join() elif node.id in self.context.namespace: obj = self.context.namespace[node.id] anno.setanno(node, 'live_val', obj) - anno.setanno(node, 'fqn', (obj.__name__,)) + if hasattr(obj, '__name__'): + anno.setanno(node, 'fqn', (obj.__name__,)) + elif hasattr(obj, '__class__'): + obj_class = obj.__class__ + anno.setanno(node, 'fqn', + (obj_class.__module__, obj_class.__name__)) + else: + # If the symbol value is for example a primitive, then it will not + # have a name. + pass else: pass # TODO(mdan): Should we raise an error here? @@ -86,6 +94,7 @@ class LiveValueResolver(transformer.Base): if not hasattr(parent_object, node.attr): raise AttributeError('%s has no attribute %s' % (parent_object, node.attr)) + anno.setanno(node, 'parent_type', type(parent_object)) anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) # TODO(mdan): Investigate the role built-in annotations can play here. @@ -96,6 +105,7 @@ class LiveValueResolver(transformer.Base): # This would not hold for dynamic members like function attributes. # For the dynamic case, we simply leave the node without an annotation, # and let downstream consumers figure out what to do. + anno.setanno(node, 'parent_type', parent_type) anno.setanno(node, 'live_val', getattr(parent_type, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'type_fqn') + (node.attr,)) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py similarity index 76% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 1e81bc70a85b9e68b955d23554c09004284449ea..69e428bde109ed43c3cdda1a94970a832dc47852 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -18,13 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -57,13 +59,30 @@ class LiveValuesResolverTest(test.TestCase): def test_literals(self): + a = None + def test_fn(): - return Foo # pylint: disable=undefined-variable + return a - node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'}) + node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'}) retval_node = node.body[0].body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) + def test_primitive_values(self): + + a = None + + def test_fn(): + return a + + node = self._parse_and_analyze(test_fn, {'a': True}) + retval_node = node.body[0].body[0].value + if six.PY2: + self.assertEqual( + anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool')) + else: + self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool')) + def test_namespace(self): def foo(): @@ -103,6 +122,7 @@ class LiveValuesResolverTest(test.TestCase): arg_types={'self': (TestClass.__name__, TestClass)}) func_node = node.body[0].body[0].value.func self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type')) self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py similarity index 60% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index 8203bda0f9a792a5b24b9abb25d8f39b61625748..203aa3c3d18ab15300bbf424adeece6e74d9c994 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -14,22 +14,42 @@ # ============================================================================== """Type resolution. +This analyzer uses known live values to further infer object types. This +may include for instance constructed objects and object member functions. + +In addition, the analyzer will also process annotations for TF (staged) type +annotations. + Requires annotations generated by LiveValuesResolver. """ +# TODO(mdan): This would be more robust with a CFG. +# Situations with multiple reaching modifications (e.g. modified inside and +# outside a control flow statement) should be more robustly detected and +# analyzed. + +# TODO(mdan): Look into using Python AST's type annotation fields instead. +# It would be desirable to use that mechanism if we can. +# Some caveats to consider: We may need to annotate other nodes like +# Attribute. It may also not be feasible for us to faithfully to replicate +# PY3's type annotations where it isn't available. It would also require us +# to design rigorous type definitions that can accommodate Python types +# as well as TensorFLow dtypes and shapes. + + from __future__ import absolute_import from __future__ import division from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect class Scope(object): - """Encloses symbol value references. + """Tracks symbol value references. Attributes: values: A dict mapping string to gast.Node, containing the value that was @@ -138,13 +158,25 @@ class TypeInfoResolver(transformer.Base): elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): # E.g. if we had # a = b - # then for future references to `a` we should have traced_source = `b` - traced_source = self.scope.getval(qn) - if anno.hasanno(traced_source, 'type'): - anno.setanno(node, 'type', anno.getanno(traced_source, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn')) + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + if anno.hasanno(definition, 'type'): + anno.setanno(node, 'type', anno.getanno(definition, 'type')) + anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) + if anno.hasanno(definition, 'element_type'): + anno.setanno(node, 'element_type', + anno.getanno(definition, 'element_type')) return node + def _process_tuple_assignment(self, source, t): + for i, e in enumerate(t.elts): + if isinstance(e, gast.Tuple): + self._process_tuple_assignment(source, e) + else: + self.scope.setval( + anno.getanno(e, anno.Basic.QN), + gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + def _process_variable_assignment(self, source, targets): if isinstance(source, gast.Call): func = source.func @@ -160,10 +192,9 @@ class TypeInfoResolver(transformer.Base): for t in targets: if isinstance(t, gast.Tuple): - for i, e in enumerate(t.elts): - self.scope.setval( - anno.getanno(e, anno.Basic.QN), - gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + # need to recurse on the case of assigning nested tuples, + # ex. a, (b, c) = f() + self._process_tuple_assignment(source, t) elif isinstance(t, (gast.Name, gast.Attribute)): self.scope.setval(anno.getanno(t, anno.Basic.QN), source) else: @@ -181,6 +212,34 @@ class TypeInfoResolver(transformer.Base): self._process_variable_assignment(node.value, node.targets) return node + def visit_Call(self, node): + if anno.hasanno(node.func, 'live_val'): + # Symbols targeted by the "set_type" marker function are assigned the data + # type that it specified. + if (anno.getanno(node.func, 'live_val') is + self.context.type_annotation_func): + # Expecting the actual type to be the second argument. + if len(node.args) != 2: + raise ValueError('"%s" must have exactly two parameters' + % self.context.type_annotation_func) + if not anno.hasanno(node.args[0], anno.Basic.QN): + raise ValueError('the first argument of "%s" must by a symbol' + % self.context.type_annotation_func) + if not anno.hasanno(node.args[1], 'live_val'): + raise ValueError( + 'the second argument of "%s" must be statically resolvable' % + self.context.type_annotation_func) + target_symbol = anno.getanno(node.args[0], anno.Basic.QN) + element_type = anno.getanno(node.args[1], 'live_val') + # Find the definition of this symbol and annotate it with the given + # data type. That in turn will cause future uses of the symbol + # to receive the same type annotation. + definition = self.scope.getval(target_symbol) + anno.setanno(node, 'element_type', element_type) + anno.setanno(definition, 'element_type', element_type) + # TODO(mdan): Should we update references between definition and here? + return self.generic_visit(node) + def resolve(node, context): return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py similarity index 76% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index a3e78202c80e45552c038a6a1da763eb30aff52f..c0de4a604301b6e9f80ee83e4797b9ac7e558a48 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.client import session from tensorflow.python.platform import test from tensorflow.python.training import training @@ -56,7 +57,10 @@ class ScopeTest(test.TestCase): class TypeInfoResolverTest(test.TestCase): - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + def _parse_and_analyze(self, + test_fn, + namespace, + arg_types=None): node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( namer=None, @@ -66,7 +70,8 @@ class TypeInfoResolverTest(test.TestCase): arg_values=None, arg_types=arg_types, owner_type=None, - recursive=True) + recursive=True, + type_annotation_func=utils.set_element_type) node = qual_names.resolve(node) node = activity.resolve(node, ctx) node = live_values.resolve(node, ctx, {}) @@ -175,6 +180,39 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) + def test_type_annotation(self): + + class Foo(object): + pass + + def test_fn(): + f = [] + f = utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_def = node.body[0].body[0].value + self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + + def test_nested_assignment(self): + + def test_fn(foo): + a, (b, c) = foo + return a, b, c + + node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)}) + lhs = node.body[0].body[1].value.elts + a = lhs[0] + b = lhs[1] + c = lhs[2] + # TODO(mdan): change these once we have the live values propagating + # correctly + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + self.assertFalse(anno.hasanno(c, 'live_val')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py similarity index 56% rename from tensorflow/contrib/py2tf/pyct/templates.py rename to tensorflow/contrib/autograph/pyct/templates.py index 6ee6c0c5ceb70d87779ee313670135cadc5214b5..baf7923fff7c786c1abd05e11fa6ffdb8c8f0912 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -26,9 +26,9 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names class ReplaceTransformer(gast.NodeTransformer): @@ -44,8 +44,6 @@ class ReplaceTransformer(gast.NodeTransformer): self.replacements = replacements self.in_replacements = False - # TODO(mdan): Make a more detailed pass and clean up if needed. - def visit_Expr(self, node): if (isinstance(node.value, gast.Name) and node.value.id in self.replacements): @@ -53,17 +51,66 @@ class ReplaceTransformer(gast.NodeTransformer): self.generic_visit(node) return node + def visit_keyword(self, node): + if node.arg in self.replacements: + repl = self.replacements[node.arg] + if isinstance(repl, gast.keyword): + return repl + elif (isinstance(repl, (list, tuple)) and repl and + all(isinstance(r, gast.keyword) for r in repl)): + return repl + # TODO(mdan): We may allow replacing with a string as well. + # For example, if one wanted to replace foo with bar in foo=baz, then + # we could allow changing just node arg, so that we end up with bar=baz. + raise ValueError( + 'a keyword argument may only be replaced by another keyword or a ' + 'non-empty list of keywords. Found: %s' % repl) + return self.generic_visit(node) + def visit_FunctionDef(self, node): node = self.generic_visit(node) if node.name in self.replacements: repl = self.replacements[node.name] if not isinstance(repl, (gast.Name, ast.Name)): raise ValueError( - 'A function name can only be replaced by a Name node. Found: %s' % + 'a function name can only be replaced by a Name node. Found: %s' % repl) node.name = repl.id return node + def _check_has_context(self, node): + if not node.ctx: + raise ValueError('node %s is missing ctx value' % node) + + def _check_inner_children_have_context(self, node): + if isinstance(node, gast.Attribute): + self._check_inner_children_have_context(node.value) + self._check_has_context(node) + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._check_inner_children_have_context(e) + self._check_has_context(node) + elif isinstance(node, gast.Dict): + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._check_inner_children_have_context(node.value) + self._check_inner_children_have_context(node.slice) + elif isinstance(node, gast.Slice): + self._check_inner_children_have_context(node.lower) + if node.upper: + self._check_inner_children_have_context(node.upper) + if node.step: + self._check_inner_children_have_context(node.step) + elif isinstance(node, gast.Name): + self._check_has_context(node) + elif isinstance(node, (gast.Str, gast.Num)): + pass + else: + raise ValueError('unexpected node type "%s"' % node) + def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, ctx) @@ -74,11 +121,40 @@ class ReplaceTransformer(gast.NodeTransformer): node.ctx = ctx elif isinstance(node, gast.Name): node.ctx = ctx + elif isinstance(node, gast.Call): + self._set_inner_child_context(node.func, ctx) + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for a in node.args: + self._check_inner_children_have_context(a) + for k in node.keywords: + self._check_inner_children_have_context(k.value) + elif isinstance(node, gast.Dict): + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._set_inner_child_context(node.value, ctx) + self._check_inner_children_have_context(node.slice) elif isinstance(node, (gast.Str, gast.Num)): pass else: raise ValueError('unexpected node type "%s"' % node) + def visit_Attribute(self, node): + node = self.generic_visit(node) + if node.attr not in self.replacements: + return node + repl = self.replacements[node.attr] + if not isinstance(repl, gast.Name): + raise ValueError( + 'An attribute can only be replaced by a Name node. Found: %s' % repl) + node.attr = repl.id + return node + def visit_Name(self, node): if node.id not in self.replacements: return node @@ -154,3 +230,17 @@ def replace(template, **replacements): if isinstance(results, list): return [qual_names.resolve(r) for r in results] return qual_names.resolve(results) + + +def replace_as_expression(template, **replacements): + """Variant of replace that generates expressions, instead of code blocks.""" + replacement = replace(template, **replacements) + if len(replacement) != 1: + raise ValueError( + 'single expression expected; for more general templates use replace') + node = replacement[0] + if not isinstance(node, gast.Expr): + raise ValueError( + 'the template is expected to generate an expression node; instead ' + 'found %s' % node) + return node.value diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a01f8bf04c4faa6ec1779e0fb306155d99f5bd09 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -0,0 +1,168 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for templates module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import imp + +import gast + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.python.platform import test + + +class TemplatesTest(test.TestCase): + + def test_replace_tuple(self): + template = """ + def test_fn(a, c): + return b, + """ + + node = templates.replace(template, b=('a', 'c'))[0] + result, _ = compiler.ast_to_object(node) + + self.assertEquals((2, 3), result.test_fn(2, 3)) + + def test_replace_variable(self): + template = """ + def test_fn(a): + a += 1 + a = 2 * a + 1 + return b + """ + + node = templates.replace(template, a='b')[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_function_name(self): + template = """ + def fname(a): + a += 1 + a = 2 * a + 1 + return a + """ + + node = templates.replace(template, fname='test_fn')[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_code_block(self): + template = """ + def test_fn(a): + block + return a + """ + + node = templates.replace( + template, + block=[ + gast.Assign([ + gast.Name('a', None, None) + ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), + ] * 2)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn(1)) + + def test_replace_attribute(self): + template = """ + def test_fn(a): + return a.foo + """ + + node = templates.replace(template, foo='b')[0] + result, _ = compiler.ast_to_object(node) + mod = imp.new_module('test') + mod.b = 3 + self.assertEquals(3, result.test_fn(mod)) + + with self.assertRaises(ValueError): + templates.replace(template, foo=1) + + def test_replace_call_keyword(self): + template = """ + def test_fn(): + def f(a, d, f): + return a + d + f + return f(1, kws=None) + """ + + source = parser.parse_expression('f(d=3, f=5)') + node = templates.replace(template, kws=source.keywords)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(9, result.test_fn()) + + with self.assertRaises(ValueError): + templates.replace(template, kws=[]) + templates.replace(template, kws=1) + + def test_replace_name_with_call(self): + template = """ + def test_fn(): + b = 5 + def g(a): + return 3 * a + def f(): + return g + return foo + """ + + source = parser.parse_expression('f()(b)') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(15, result.test_fn()) + + def test_replace_name_with_dict(self): + template = """ + def test_fn(): + return foo['bar'] + """ + + source = parser.parse_expression('{\'bar\': 3}') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn()) + + def replace_as_expression(self): + template = """ + foo(a) + """ + + node = templates.replace(template, foo='bar', a='baz') + self.assertTrue(node is gast.Call) + self.assertEqual(node.func.id, 'bar') + self.assertEqual(node.func.args[0].id, 'baz') + + def replace_as_expression_restrictions(self): + template = """ + foo(a) + bar(b) + """ + with self.assertRaises(ValueError): + templates.replace_as_expression(template) + with self.assertRaises(ValueError): + templates.replace('') + with self.assertRaises(ValueError): + templates.replace('a = b') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py similarity index 71% rename from tensorflow/contrib/py2tf/pyct/transformer.py rename to tensorflow/contrib/autograph/pyct/transformer.py index 877d52af016af720424c8a56257fec9ab64611cb..35f114b6e11901a854c1d631061ae42285c0e261 100644 --- a/tensorflow/contrib/py2tf/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -23,14 +23,22 @@ import sys import gast import six -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import pretty_printer -class PyFlowParseError(SyntaxError): +class AutographParseError(SyntaxError): pass +def try_ast_to_source(node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + + class Base(gast.NodeTransformer): """Base class for specialized transformers.""" @@ -44,6 +52,12 @@ class Base(gast.NodeTransformer): self._col_offset = 0 self.context = context + def debug_print(self, node): + """Helper method useful for debugging.""" + if __debug__: + print(pretty_printer.fmt(node)) + return node + def visit(self, node): source_code = self.context.source_code source_file = self.context.source_file @@ -56,14 +70,15 @@ class Base(gast.NodeTransformer): return super(Base, self).visit(node) except (ValueError, AttributeError, KeyError, NotImplementedError, AssertionError) as e: - msg = '%s: %s\nOccurred at node:\n%s' % ( - e.__class__.__name__, str(e), pretty_printer.fmt(node, color=False)) + msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( + e.__class__.__name__, str(e), try_ast_to_source(node), + pretty_printer.fmt(node, color=False)) if source_code: line = source_code.splitlines()[self._lineno - 1] else: line = '' - six.reraise(PyFlowParseError, - PyFlowParseError( + six.reraise(AutographParseError, + AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD similarity index 90% rename from tensorflow/contrib/py2tf/utils/BUILD rename to tensorflow/contrib/autograph/utils/BUILD index c2fdd40707775783140390e4b5c0186c9c3e562e..b53fbb5c18f27aa4681347d965dc7322c849ec91 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -20,25 +20,29 @@ py_library( name = "utils", srcs = [ "__init__.py", + "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", - "printing.py", "py_func.py", "tensor_list.py", + "testing.py", "type_check.py", + "type_hints.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", + "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], ) py_test( - name = "context_managers_test", - srcs = ["context_managers_test.py"], + name = "builtins_test", + srcs = ["builtins_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -47,8 +51,8 @@ py_test( ) py_test( - name = "misc_test", - srcs = ["misc_test.py"], + name = "context_managers_test", + srcs = ["context_managers_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -57,8 +61,8 @@ py_test( ) py_test( - name = "multiple_dispatch_test", - srcs = ["multiple_dispatch_test.py"], + name = "misc_test", + srcs = ["misc_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -67,8 +71,8 @@ py_test( ) py_test( - name = "py_func_test", - srcs = ["py_func_test.py"], + name = "multiple_dispatch_test", + srcs = ["multiple_dispatch_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -77,8 +81,8 @@ py_test( ) py_test( - name = "printing_test", - srcs = ["printing_test.py"], + name = "py_func_test", + srcs = ["py_func_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22898b17e98bb004b4d2aa529b58cc99fc64dbb2 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility module that contains APIs usable in the generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin +from tensorflow.contrib.autograph.utils.builtins import dynamic_dataset +from tensorflow.contrib.autograph.utils.builtins import dynamic_for_cond +from tensorflow.contrib.autograph.utils.builtins import dynamic_print +from tensorflow.contrib.autograph.utils.builtins import dynamic_range +from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns +from tensorflow.contrib.autograph.utils.misc import alias_tensors +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not +from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond +from tensorflow.contrib.autograph.utils.multiple_dispatch import run_while +from tensorflow.contrib.autograph.utils.py_func import wrap_py_func +from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append +from tensorflow.contrib.autograph.utils.testing import fake_tf +from tensorflow.contrib.autograph.utils.type_check import is_tensor +from tensorflow.contrib.autograph.utils.type_hints import set_element_type diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab32ee47de5c0b3b6ab18c731da7626887b67a5 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -0,0 +1,166 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Builtin conversion utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.util import tf_inspect + + +def dynamic_builtin(f, *args, **kwargs): + """Converts a builtin function call inline.""" + # Some built-ins may be objects. + if not tf_inspect.isbuiltin(f) and f not in (range,): + return f(*args, **kwargs) + + if f is len: + return dynamic_len(*args, **kwargs) + if six.PY2 and f is xrange: + return dynamic_range(*args, **kwargs) + if f is range: + return dynamic_range(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' % f.__name__) + + +def dynamic_len(list_or_tensor): + """Implementation of len using dynamic dispatch.""" + if tensor_util.is_tensor(list_or_tensor): + shape = list_or_tensor.shape + if not shape: + raise ValueError( + 'len requires non-zero rank for tensor "%s"' % list_or_tensor) + return array_ops.shape(list_or_tensor)[0] + return len(list_or_tensor) + + +def dynamic_range(start_or_stop, stop=None, step=None): + """Implementation of range using dynamic dispatch.""" + if type_check.is_tensor(start_or_stop, stop, step): + if step is not None: + return math_ops.range(start_or_stop, stop, step) + if stop is not None: + return math_ops.range(start_or_stop, stop) + return math_ops.range(start_or_stop) + + if step is not None: + return range(start_or_stop, stop, step) + elif stop is not None: + return range(start_or_stop, stop) + return range(start_or_stop) + + +def is_tf_print_compatible(value): + # TODO(mdan): Enable once we can reliably test this. + # This is currently disabled because we can't capture the output of + # op kernels from Python. + del value + return False + + +def dynamic_print(*values): + """Implementartion of print using dynamic dispatch. + + The function attempts to use tf.Print if all the values are compatible. + Otherwise, it will fall back to py_func. + + Args: + *values: values to print + Returns: + A dummy value indicating the print completed. If tf. + """ + + if all(map(is_tf_print_compatible, values)): + return logging_ops.Print(1, values) + return py_func.wrap_py_func(print, None, values, use_dummy_return=True) + + +def dynamic_dataset(iterated): + """Implementartion of smart tf.data.Dataset epoch wrapping. + + The function checks if the input is a tf.data.Dataset and if so then wraps it + so that for each element it returns it also returns the current epoch the + dataset iteration is in, for two epochs. If the input is not a + tf.data.Dataset then it just returns the input. + + Args: + iterated: The iterable or tf.data.Dataset that is being iterated over. + Returns: + Either just the untouched input, or in the case of input being a + tf.data.Dataset then it returns a wrapped tf.data.Dataset where for each + element it returns it also returns the current epoch the dataset iteration + is in. + """ + if not isinstance(iterated, dataset_ops.Dataset): + return iterated + + def epoch_dataset_number_helper(i): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(i).repeat(), iterated)) + + epoch_numbers = dataset_ops.Dataset.range(2) + return epoch_numbers.flat_map(epoch_dataset_number_helper) + + +def dynamic_for_cond(iteration, iterated): + """Implementartion of smart while-loop condition using dynamic dispatch. + + The function checks if it is iterating over a tf.data.Dataset or not, and in + the case it is not then it simply returns if we are still in range of the + iterated and the next element. If it is iterating over a dataset then it only + iterates for a single epoch. + + Args: + iteration: The current iteration of the loop. + iterated: The iterable or tf.data.Dataset that is being iterated over. + Returns: + A tuple of a bool that indicates whether the loop should continue, and the + next element in iterated. + """ + # TODO(znado): Clean up. + # TODO(znado): This won't work for unpacked iterates. Fix. + if isinstance(iterated, dataset_ops.Dataset): + curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next() + return math_ops.less(curr_epoch, 1), next_elem + elif tensor_util.is_tensor(iterated): + if iterated.shape.ndims > 1: + elem_shape = array_ops.shape(iterated)[1:] + else: + elem_shape = () + if iterated.shape.ndims == 0 or iterated.shape[0] == 0: + return False, array_ops.zeros(elem_shape, iterated.dtype) + return control_flow_ops.cond( + math_ops.less(iteration, dynamic_len(iterated)), + lambda: (True, iterated[iteration]), + lambda: (False, array_ops.zeros(elem_shape, iterated.dtype))) + elif hasattr(iterated, '__len__'): + if iteration < len(iterated): + return True, iterated[iteration] + return False, None + else: + raise NotImplementedError('Python iterators not yet supported.') diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f7913d89a5471c76eb7ae484674bd7a1853ac9 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -0,0 +1,111 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class BuiltinsTest(test.TestCase): + + def test_dynamic_len_tf_scalar(self): + a = constant_op.constant(1) + + with self.assertRaises(ValueError): + with self.test_session() as sess: + sess.run(builtins.dynamic_builtin(len, a)) + + def test_dynamic_len_tf_array(self): + a = constant_op.constant([1, 2, 3]) + + with self.test_session() as sess: + self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_tf_matrix(self): + a = constant_op.constant([[1, 2], [3, 4]]) + + with self.test_session() as sess: + self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_py_list(self): + a = [3] * 5 + + self.assertEqual(5, builtins.dynamic_builtin(len, a)) + + def test_dynamic_range_all_python(self): + self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) + self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1]) + + def test_dynamic_range_tf(self): + with self.test_session() as sess: + self.assertAllEqual( + sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), + [0, 1, 2]) + self.assertAllEqual( + sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), + [1, 2]) + self.assertAllEqual( + sess.run( + builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), + [2, 1]) + + def test_dynamic_range_detection(self): + def range(x): # pylint:disable=redefined-builtin + return x + + # Functions that just have the names of builtins are ignored. + self.assertEqual(builtins.dynamic_builtin(range, 1), 1) + if six.PY2: + self.assertListEqual( + list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + + def test_dynamic_print_tf(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_dynamic_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/autograph/utils/context_managers.py similarity index 85% rename from tensorflow/contrib/py2tf/utils/context_managers.py rename to tensorflow/contrib/autograph/utils/context_managers.py index 38d9e11fe9069722b9023fee848bf53e1f72de6a..3d150a95817b83c4d7aaa78dc250092dcc4c5a9b 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers.py +++ b/tensorflow/contrib/autograph/utils/context_managers.py @@ -21,6 +21,7 @@ from __future__ import print_function import contextlib from tensorflow.python.framework import ops +from tensorflow.python.ops import tensor_array_ops def control_dependency_on_returns(return_value): @@ -34,9 +35,15 @@ def control_dependency_on_returns(return_value): Returns: A context manager. """ + def control_dependency_handle(t): + if isinstance(t, tensor_array_ops.TensorArray): + return t.flow + return t + if return_value is None: return contextlib.contextmanager(lambda: (yield))() # TODO(mdan): Filter to tensor objects. if not isinstance(return_value, (list, tuple)): return_value = (return_value,) + return_value = tuple(control_dependency_handle(t) for t in return_value) return ops.control_dependencies(return_value) diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/autograph/utils/context_managers_test.py similarity index 82% rename from tensorflow/contrib/py2tf/utils/context_managers_test.py rename to tensorflow/contrib/autograph/utils/context_managers_test.py index 633ba93540e696889a6b2b71b40b999da39d48ff..42e27724b9856f715b524cdd7539897851715638 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers_test.py +++ b/tensorflow/contrib/autograph/utils/context_managers_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import context_managers +from tensorflow.contrib.autograph.utils import context_managers from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -32,6 +34,9 @@ class ContextManagersTest(test.TestCase): with context_managers.control_dependency_on_returns( constant_op.constant(1)): pass + with context_managers.control_dependency_on_returns( + tensor_array_ops.TensorArray(dtypes.int32, size=1)): + pass with context_managers.control_dependency_on_returns( [constant_op.constant(1), constant_op.constant(2)]): diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/autograph/utils/misc.py similarity index 80% rename from tensorflow/contrib/py2tf/utils/misc.py rename to tensorflow/contrib/autograph/utils/misc.py index 7548048388766d0f12a55eecd77fca2706f9734b..1b06caf0bdeb6f4a079e33f2e887d2dca017adc2 100644 --- a/tensorflow/contrib/py2tf/utils/misc.py +++ b/tensorflow/contrib/autograph/utils/misc.py @@ -19,22 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -def dynamic_len(list_or_tensor): - """Implementation of len using dynamic dispatch.""" - if tensor_util.is_tensor(list_or_tensor): - shape = list_or_tensor.shape - if not shape: - raise ValueError( - 'len requires non-zero rank for tensor "%s"' % list_or_tensor) - return array_ops.shape(list_or_tensor)[0] - - return len(list_or_tensor) - - def alias_tensors(*args): """Wrap any Tensor arguments with an identity op. diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py similarity index 67% rename from tensorflow/contrib/py2tf/utils/misc_test.py rename to tensorflow/contrib/autograph/utils/misc_test.py index ec88e7cb74bd40b851fb3d2fe246d37d8c668d82..71e358c33e1ea9887d267c67bc80362bac26c3a6 100644 --- a/tensorflow/contrib/py2tf/utils/misc_test.py +++ b/tensorflow/contrib/autograph/utils/misc_test.py @@ -18,38 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils.misc import alias_tensors -from tensorflow.contrib.py2tf.utils.misc import dynamic_len +from tensorflow.contrib.autograph.utils.misc import alias_tensors from tensorflow.python.framework.constant_op import constant from tensorflow.python.ops.variables import Variable from tensorflow.python.platform import test -class ContextManagersTest(test.TestCase): - - def test_dynamic_len_tf_scalar(self): - a = constant(1) - - with self.assertRaises(ValueError): - with self.test_session() as sess: - sess.run(dynamic_len(a)) - - def test_dynamic_len_tf_array(self): - a = constant([1, 2, 3]) - - with self.test_session() as sess: - self.assertEqual(3, sess.run(dynamic_len(a))) - - def test_dynamic_len_tf_matrix(self): - a = constant([[1, 2], [3, 4]]) - - with self.test_session() as sess: - self.assertEqual(2, sess.run(dynamic_len(a))) - - def test_dynamic_len_py_list(self): - a = [3] * 5 - - self.assertEqual(5, dynamic_len(a)) +class MiscTest(test.TestCase): def test_alias_single_tensor(self): a = constant(1) diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch.py b/tensorflow/contrib/autograph/utils/multiple_dispatch.py similarity index 80% rename from tensorflow/contrib/py2tf/utils/multiple_dispatch.py rename to tensorflow/contrib/autograph/utils/multiple_dispatch.py index a855fdc075941915035d1e3380846ff912803494..47049255f31113a0c7b2f5a1269593afdbbc9b19 100644 --- a/tensorflow/contrib/py2tf/utils/multiple_dispatch.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for type-dependent behavior used in py2tf-generated code.""" +"""Utilities for type-dependent behavior used in autograph-generated code.""" from __future__ import absolute_import from __future__ import division @@ -20,10 +20,20 @@ from __future__ import print_function import six -from tensorflow.contrib.py2tf.utils.type_check import is_tensor +from tensorflow.contrib.autograph.utils.type_check import is_tensor from tensorflow.python.ops import control_flow_ops +def dynamic_is(left, right): + # TODO(alexbw) if we're sure we should leave 'is' in place, + # then change the semantics in converters/logical_expressions.py + return left is right + + +def dynamic_is_not(left, right): + return left is not right + + def run_cond(condition, true_fn, false_fn): """Type-dependent functional conditional. @@ -45,10 +55,17 @@ def run_cond(condition, true_fn, false_fn): def py_cond(condition, true_fn, false_fn): + """Functional version of Python's conditional.""" if condition: - return true_fn() + results = true_fn() else: - return false_fn() + results = false_fn() + + # The contract for the branch functions is to return tuples, but they should + # be collapsed to a single element when there is only one output. + if len(results) == 1: + return results[0] + return results def run_while(cond_fn, body_fn, init_args): diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py similarity index 61% rename from tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py rename to tensorflow/contrib/autograph/utils/multiple_dispatch_test.py index 5bb4d4086b002211eebb86783bb7212c707a1418..e6a41bb4166e8cfc8c703685f56eb90a1b5f63b4 100644 --- a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import multiple_dispatch + +import numpy as np + +from tensorflow.contrib.autograph.utils import multiple_dispatch from tensorflow.python.client.session import Session from tensorflow.python.framework.constant_op import constant from tensorflow.python.platform import test @@ -25,21 +28,47 @@ from tensorflow.python.platform import test class MultipleDispatchTest(test.TestCase): + def test_dynamic_is_python(self): + a = np.eye(3) + also_a = a + not_actually_a = np.eye(3) + should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) + should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) + should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) + should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) + + def test_dynamic_is_tf(self): + with Session().as_default(): + a = constant([2.0]) + also_a = a + not_actually_a = constant([2.0]) + should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) + should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) + should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) + should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) + def test_run_cond_python(self): - true_fn = lambda: 2.0 - false_fn = lambda: 3.0 - self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) - self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + true_fn = lambda: (2,) + false_fn = lambda: (3,) + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3) def test_run_cond_tf(self): - - true_fn = lambda: constant([2.0]) - false_fn = lambda: constant([3.0]) + true_fn = lambda: (constant(2),) + false_fn = lambda: (constant(3),) with Session() as sess: out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) - self.assertEqual(sess.run(out), 2.0) + self.assertEqual(sess.run(out), 2) out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) - self.assertEqual(sess.run(out), 3.0) + self.assertEqual(sess.run(out), 3) def test_run_while_python(self): cond_fn = lambda x, t, s: x > t diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/contrib/autograph/utils/py_func.py new file mode 100644 index 0000000000000000000000000000000000000000..11ebfb2e49f0e762b56ae2cde2b76d2e24032d72 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pyfunc creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import script_ops + + +class MatchDType(namedtuple('MatchDType', ('arg_number',))): + """Allows matching the dtype of an argument. + + Used in conjunction with function calls. For example, MatchDType(0) will + match the DType of the first argument. + """ + + pass + + +def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False): + """Helper that wraps a callable to py_func. + + The helper passes tensor arguments through the py_func interface. Non-tensor + arguments are allowed, and will be passed to f directly. Note that non-tensor + arguments are captured by f will not update every time the wrapper is + called (this is consistent with its argument list, which only includes + the tensor arguments). In general, it's safest not to reuse this wrapper. + + Args: + f: Callable + return_dtypes: None, individual of tuple/list of DType or MatchDType, the + data type for each of f's return value(s). Set to None if f has no + return values or use_dummy_return is True. Use MatchDType to define a + dtype identical to that of `i`th argument (argument 0 is the first); + an argument must of Tensor type if it is to be used with MatchDType. + args: Positional arguments for f, as list or tuple. + kwargs: Keyword arguments for f, as dict with string keys. May be None. + use_dummy_return: If True, the function will return a dummy value of 1 + and discard its actual return value. + Returns: + The return values of f converted to tensor. + Raises: + ValueError: if any of the arguments are incorrect. + """ + + if return_dtypes and use_dummy_return: + raise ValueError('if use_dummy_return is True, return_dtypes must be empty') + + tensor_args = [] + tensor_args_idx = {} + + # Of the positional arguments, only grab the tensor ones to be passed through + # the py_func. + n_args = len(args) + arg_is_tensor = tuple(map(tensor_util.is_tensor, args)) + for i in range(n_args): + if arg_is_tensor[i]: + tensor_args_idx[i] = len(tensor_args) + tensor_args.append(args[i]) + + # We essentially take the tensor kwargs, if any, and add them to the list of + # positional arguments. The kwargs are then reconstructed inside the py_func. + # + # For example, if + # + # args = [Tensor(1), 'foo'] + # kwargs = {'a': Tensor(2), 'b': 'bar'} + # + # Then + # + # tensor_args = (Tensor(1), Tensor(2)) + # kwarg_keys = ('a', 'b') + if kwargs: + kwarg_keys = tuple(kwargs.keys()) + kwarg_is_tensor = {k: tensor_util.is_tensor(kwargs[k]) for k in kwarg_keys} + for k in kwarg_keys: + if kwarg_is_tensor[k]: + tensor_args_idx[k] = len(tensor_args) + tensor_args.append(kwargs[k]) + else: + kwarg_keys = () + + # Set up return dtypes. + def match_arg_dtype(arg_number): + arg = args[arg_number] + if not arg_is_tensor[arg_number]: + raise ValueError( + 'argument %d was used with MatchDType and must be a tf.Tensor, but ' + 'was %s instead' % (arg_number, type(arg))) + return arg.dtype + + if return_dtypes: + if isinstance(return_dtypes, MatchDType): + return_dtypes = match_arg_dtype(return_dtypes.arg_number) + elif isinstance(return_dtypes, (list, tuple)): + return_dtypes = tuple( + match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a + for a in return_dtypes) + else: + assert isinstance(return_dtypes, dtypes.DType) + + def f_wrapper(*tensor_args): + f_args = tuple(tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a + for i, a in enumerate(args)) + f_kwargs = { + k: tensor_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k] + for i, k in enumerate(kwarg_keys) + } + retval = f(*f_args, **f_kwargs) + return 1 if use_dummy_return else retval + + return script_ops.py_func(f_wrapper, tensor_args, dtypes.int64 + if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2468263142f14332e86db99d198ba0f5c633dc69 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func_test.py @@ -0,0 +1,103 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrap_py_func module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class PyFuncTest(test.TestCase): + + def test_wrap_py_func_simple(self): + + def test_fn(a, b, c): + return a + b + c + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (1, constant_op.constant(1), 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func( + test_fn, dtypes.int64, + (constant_op.constant(1), 1, constant_op.constant(1))) + self.assertEqual(3, sess.run(result)) + + def test_wrap_py_func_complex_args(self): + + class TestClass(object): + + def __init__(self): + self.foo = 5 + + def test_fn(a, b): + return a * b.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass())) + self.assertEqual(35, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass())) + self.assertEqual(35, sess.run(result)) + + def test_wrap_py_func_kwargs(self): + + class TestClass(object): + + def __init__(self, foo): + self.foo = foo + + def test_fn(a, b, c, d): + return a * b.foo + c * d.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), { + 'c': 11, + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass(5)), { + 'c': constant_op.constant(11), + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + + def test_wrap_py_func_dummy_return(self): + + side_counter = [0] + + def test_fn(_): + side_counter[0] += 1 + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([1], side_counter) + result = py_func.wrap_py_func( + test_fn, None, (constant_op.constant(5),), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([2], side_counter) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/tensor_list.py b/tensorflow/contrib/autograph/utils/tensor_list.py similarity index 67% rename from tensorflow/contrib/py2tf/utils/tensor_list.py rename to tensorflow/contrib/autograph/utils/tensor_list.py index b6ff49e2a0eff384f10903e12212ab929e267804..2556f412891b4f0b954af5a6f0193341a6a5020a 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list.py +++ b/tensorflow/contrib/autograph/utils/tensor_list.py @@ -18,7 +18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +def dynamic_list_append(target, element): + """Converts a list append call inline.""" + if isinstance(target, tensor_array_ops.TensorArray): + return target.write(target.size(), element) + # TODO(mdan): What's the right way to check this? + # TODO(mdan): We may not need this branch. + # It may be possible to use TensorList alone if the loop body will not + # require wrapping it, although we'd have to think about an autoboxing + # mechanism for lists received as parameter. + if isinstance(target, ops.Tensor): + return list_ops.tensor_list_push_back(target, element) + + # Python targets (including TensorList): fallback to their original append. + target.append(element) + return target class TensorList(object): diff --git a/tensorflow/contrib/py2tf/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py similarity index 71% rename from tensorflow/contrib/py2tf/utils/tensor_list_test.py rename to tensorflow/contrib/autograph/utils/tensor_list_test.py index b5e554a162674e08da21785dcbe193c54647f128..d58489eb68b6b949a4276520605c62b7c2825558 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list_test.py +++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py @@ -12,22 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for PyFlow list.""" +"""Tests for Autograph lists.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import tensor_list as tl +from tensorflow.contrib.autograph.utils import tensor_list as tl from tensorflow.python.client.session import Session from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.constant_op import constant +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test class TensorListTest(test.TestCase): + def _shape(self, shape_tuple): + return constant(shape_tuple, dtypes.int32) + + def test_dynamic_list_append(self): + l = [] + l = tl.dynamic_list_append(l, 1) + self.assertListEqual(l, [1]) + + l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32) + l = tl.dynamic_list_append(l, 1) + s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(s), [1]) + + l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) + l = tl.dynamic_list_append(l, 1) + s = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(s), [1]) + + l = tl.TensorList(self._shape(()), dtypes.int32) + l = tl.dynamic_list_append(l, 1) + with self.test_session() as sess: + self.assertAllEqual(sess.run(l[0]), 1) + def test_list_append_python(self): with context.eager_mode(): a = constant(3.0) diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/autograph/utils/testing.py similarity index 65% rename from tensorflow/contrib/py2tf/pyct/parser.py rename to tensorflow/contrib/autograph/utils/testing.py index dc7df883b349becd860bb0dbceab22cb39c750b5..cb4785d0dc0f4674b3560418daeb6733364b21e7 100644 --- a/tensorflow/contrib/py2tf/pyct/parser.py +++ b/tensorflow/contrib/autograph/utils/testing.py @@ -12,29 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Converting code to AST. - -Adapted from Tangent. -""" +"""Testing utilities.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import textwrap - -import gast - -from tensorflow.python.util import tf_inspect - +import imp -def parse_entity(entity): - """Return the AST of given entity.""" - source = tf_inspect.getsource(entity) - source = textwrap.dedent(source) - return parse_str(source), source +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops -def parse_str(src): - """Return the AST of given piece of code.""" - return gast.parse(src) +def fake_tf(): + """Creates a fake module that looks like TensorFlow, for testing.""" + mod = imp.new_module('tensorflow') + mod_contents = dict() + mod_contents.update(math_ops.__dict__) + mod_contents.update(ops.__dict__) + mod_contents.update(mod.__dict__) + mod.__dict__.update(mod_contents) + return mod diff --git a/tensorflow/contrib/py2tf/utils/type_check.py b/tensorflow/contrib/autograph/utils/type_check.py similarity index 86% rename from tensorflow/contrib/py2tf/utils/type_check.py rename to tensorflow/contrib/autograph/utils/type_check.py index 9ca2dec872c8a9ca7bedaa8603f70e3214a3e24a..8748abc47bcfb55b4d0b11178a46816249732da9 100644 --- a/tensorflow/contrib/py2tf/utils/type_check.py +++ b/tensorflow/contrib/autograph/utils/type_check.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities used in py2tf-generated code.""" +"""Utilities used in autograph-generated code.""" from __future__ import absolute_import from __future__ import division @@ -22,12 +22,12 @@ from tensorflow.python.framework import tensor_util def is_tensor(*args): - """Check if all arguments are tensors. + """Check if any arguments are tensors. Args: *args: Python objects that may or may not be tensors. Returns: - True if all *args are TensorFlow types, False if one or more are not. + True if any *args are TensorFlow types, False if none are. """ return any([tensor_util.is_tensor(a) for a in args]) diff --git a/tensorflow/contrib/py2tf/utils/type_check_test.py b/tensorflow/contrib/autograph/utils/type_check_test.py similarity index 96% rename from tensorflow/contrib/py2tf/utils/type_check_test.py rename to tensorflow/contrib/autograph/utils/type_check_test.py index 7d0428e9cccecdc67511e236bc00655a055aea29..3b67b7194c5656b193d47860f93986a985cb1aef 100644 --- a/tensorflow/contrib/py2tf/utils/type_check_test.py +++ b/tensorflow/contrib/autograph/utils/type_check_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy -from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py b/tensorflow/contrib/autograph/utils/type_hints.py similarity index 54% rename from tensorflow/contrib/bayesflow/python/ops/custom_grad.py rename to tensorflow/contrib/autograph/utils/type_hints.py index ca1ecb9c40204c3c723fa3423cfe148e823adc28..aeb9e545610460afbe364dfcfc7a54b9aede29fe 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py +++ b/tensorflow/contrib/autograph/utils/type_hints.py @@ -12,23 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for specifying custom gradients. +"""No-op utilities that provide static type hints. -See ${python/contrib.bayesflow.custom_gradient}. +These are used when the data type is not known at creation, for instance in the +case of empty lists. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.custom_grad_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ - 'custom_gradient', -] +def set_element_type(entity, dtype, shape=None): + """Indicates that the entity is expected hold items of specified type. -remove_undocumented(__name__, _allowed_symbols) + This function is a no-op. Its presence merely marks the data type of its + argument. The staged TensorFlow ops will reflect and assert this data type. + + Args: + entity: A Tensor or TensorArray. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + Returns: + The value of entity, unchanged. + """ + del dtype + del shape + return entity diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index d7beb26e1b77c44d45a21a8c2bc752c93bf5b313..a55029b314e67571519d96607ff1fe36070c50ef 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -37,136 +37,6 @@ py_library( ], ) -cuda_py_test( - name = "metropolis_hastings_test", - size = "medium", - srcs = ["python/kernel_tests/metropolis_hastings_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "csiszar_divergence_test", - size = "medium", - srcs = ["python/kernel_tests/csiszar_divergence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], - tags = [ - "manual", # b/64490288 - "notap", - ], -) - -cuda_py_test( - name = "custom_grad_test", - size = "small", - srcs = ["python/kernel_tests/custom_grad_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "docstring_util_test", - size = "small", - srcs = ["python/kernel_tests/docstring_util_test.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow/python:client_testlib", - ], -) - -cuda_py_test( - name = "layers_conv_variational_test", - size = "small", - srcs = ["python/kernel_tests/layers_conv_variational_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], -) - -cuda_py_test( - name = "layers_dense_variational_test", - size = "small", - srcs = ["python/kernel_tests/layers_dense_variational_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], -) - -cuda_py_test( - name = "mcmc_diagnostics_test", - size = "small", - srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:spectral_ops_test_util", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], -) - cuda_py_test( name = "monte_carlo_test", size = "small", @@ -188,108 +58,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "halton_sequence_test", - size = "small", - srcs = ["python/kernel_tests/halton_sequence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - tags = ["no_mac"], # b/73192243 -) - -cuda_py_test( - name = "hmc_test", - size = "medium", - srcs = ["python/kernel_tests/hmc_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], -) - -cuda_py_test( - name = "sgld_optimizer_test", - size = "small", - srcs = ["python/kernel_tests/sgld_optimizer_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["notsan"], -) - -cuda_py_test( - name = "variable_utils_test", - size = "small", - srcs = ["python/kernel_tests/variable_utils_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "variational_sgd_optimizer_test", - size = "small", - srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["notsan"], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/bayesflow/README.md b/tensorflow/contrib/bayesflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..10323dc6d59918a9f8cf1840d06dcd219dfe3568 --- /dev/null +++ b/tensorflow/contrib/bayesflow/README.md @@ -0,0 +1,17 @@ +# Notice + +`tf.contrib.bayesflow` has moved! + +See new code at [github.com/tensorflow/probability]( +https://github.com/tensorflow/probability). + +Switch imports with: + +```python +# old +import tensorflow as tf +tfp = tf.contrib.bayesflow + +# new +import tensorflow_probability as tfp +``` diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 528c4fbacd06c7b0defa0e32bd24a98b2bc07b64..41a8c920fc4e81af90f4c94a149d8c404c58b747 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -21,36 +21,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence -from tensorflow.contrib.bayesflow.python.ops import custom_grad -from tensorflow.contrib.bayesflow.python.ops import halton_sequence -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops import layers -from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo -from tensorflow.contrib.bayesflow.python.ops import optimizers -from tensorflow.contrib.bayesflow.python.ops import variable_utils # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'csiszar_divergence', - 'custom_grad', - 'entropy', - 'halton_sequence', - 'hmc', - 'layers', - 'metropolis_hastings', - 'mcmc_diagnostics', 'monte_carlo', - 'optimizers', - 'special_math', - 'stochastic_variables', - 'variable_utils', - 'variational_inference', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py deleted file mode 100644 index 2e94b7206de4f7c40c89f083f3bfa2a22bb7b917..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py +++ /dev/null @@ -1,1004 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Csiszar Divergence Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl -from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -cd = csiszar_divergence_impl - - -def tridiag(d, diag_value, offdiag_value): - """d x d matrix with given value on diag, and one super/sub diag.""" - diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value) - three_bands = array_ops.matrix_band_part( - array_ops.fill([d, d], offdiag_value), 1, 1) - return diag_mat + three_bands - - -class AmariAlphaTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for alpha in [-1., 0., 1., 2.]: - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(0., alpha=alpha, - self_normalized=normalized).eval(), - 0.) - - def test_correct_when_alpha0(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0.).eval(), - -self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0., self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - def test_correct_when_alpha1(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1.).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1., self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - def test_correct_when_alpha_not_01(self): - for alpha in [-2, -1., -0.5, 0.5, 2.]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=False).eval(), - ((self._u**alpha - 1)) / (alpha * (alpha - 1.))) - - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=True).eval(), - ((self._u**alpha - 1.) - - alpha * (self._u - 1)) / (alpha * (alpha - 1.))) - - -class KLReverseTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_reverse(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_reverse(self._logu).eval(), - -self._logu) - - self.assertAllClose( - cd.kl_reverse(self._logu, self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - -class KLForwardTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_forward(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_forward(self._logu).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.kl_forward(self._logu, self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - -class JensenShannonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jensen_shannon(0.).eval(), np.log(0.25)) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jensen_shannon).eval()) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - cd.symmetrized_csiszar_function( - self._logu, - lambda x: cd.jensen_shannon(x, self_normalized=True)).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - (self._u * self._logu - - (1 + self._u) * np.log1p(self._u))) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - (self._u * self._logu - - (1 + self._u) * np.log((1 + self._u) / 2))) - - -class ArithmeticGeometricMeanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.arithmetic_geometric(0.).eval(), np.log(4)) - self.assertAllClose( - cd.arithmetic_geometric(0., self_normalized=True).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.arithmetic_geometric).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - (1. + self._u) * np.log((1. + self._u) / np.sqrt(self._u))) - - self.assertAllClose( - cd.arithmetic_geometric(self._logu, self_normalized=True).eval(), - (1. + self._u) * np.log(0.5 * (1. + self._u) / np.sqrt(self._u))) - - -class TotalVariationTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.total_variation(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.total_variation(self._logu).eval(), - 0.5 * np.abs(self._u - 1)) - - -class PearsonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.pearson(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.pearson(self._logu).eval(), - np.square(self._u - 1)) - - -class SquaredHellingerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.squared_hellinger(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.squared_hellinger).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - np.square(np.sqrt(self._u) - 1)) - - -class TriangularTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.triangular(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.triangular).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - np.square(self._u - 1) / (1 + self._u)) - - -class TPowerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.t_power(0., t=-0.1).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=0.5).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=1.1).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=-0.1, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=0.5, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=1.1, self_normalized=True).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1)).eval(), - self._u ** -0.1 - 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5)).eval(), - -self._u ** 0.5 + 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1)).eval(), - self._u ** 1.1 - 1.) - - def test_correct_self_normalized(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1), - self_normalized=True).eval(), - self._u ** -0.1 - 1. + 0.1 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5), - self_normalized=True).eval(), - -self._u ** 0.5 + 1. + 0.5 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1), - self_normalized=True).eval(), - self._u ** 1.1 - 1. - 1.1 * (self._u - 1.)) - - -class Log1pAbsTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.log1p_abs(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.log1p_abs(self._logu).eval(), - self._u**(np.sign(self._u - 1)) - 1) - - -class JeffreysTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jeffreys(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jeffreys).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - 0.5 * (self._u * self._logu - self._logu)) - - -class ChiSquareTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.chi_square(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.chi_square(self._logu).eval(), - self._u**2 - 1) - - -class ModifiedGanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(0.).eval(), np.log(2)) - self.assertAllClose( - cd.modified_gan(0., self_normalized=True).eval(), np.log(2)) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(self._logu).eval(), - np.log1p(self._u) - self._logu) - - self.assertAllClose( - cd.modified_gan(self._logu, self_normalized=True).eval(), - np.log1p(self._u) - self._logu + 0.5 * (self._u - 1)) - - -class SymmetrizedCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_jensen_shannon(self): - with self.test_session(): - - # The following functions come from the claim made in the - # symmetrized_csiszar_function docstring. - def js1(logu): - return (-logu - - (1. + math_ops.exp(logu)) * ( - nn_ops.softplus(logu))) - - def js2(logu): - return 2. * (math_ops.exp(logu) * ( - logu - nn_ops.softplus(logu))) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js1).eval(), - cd.jensen_shannon(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js2).eval(), - cd.jensen_shannon(self._logu).eval()) - - def test_jeffreys(self): - with self.test_session(): - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.jeffreys(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.jeffreys(self._logu).eval()) - - -class DualCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_kl_forward(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.kl_reverse(self._logu).eval()) - - def test_kl_reverse(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.kl_forward(self._logu).eval()) - - -class MonteCarloCsiszarFDivergenceTest(test.TestCase): - - def test_kl_forward(self): - with self.test_session() as sess: - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse(self): - with self.test_session() as sess: - - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.07, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - def test_kl_forward_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.06, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.05, atol=0.) - - def test_score_trick(self): - - with self.test_session() as sess: - d = 5 # Dimension - num_draws = int(1e5) - seed = 1 - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [d])) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed) - - approx_kl_self_normalized_score_trick = ( - cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed)) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - [ - approx_kl_grad_, - approx_kl_self_normalized_grad_, - approx_kl_score_trick_grad_, - approx_kl_self_normalized_score_trick_grad_, - exact_kl_grad_, - approx_kl_, - approx_kl_self_normalized_, - approx_kl_score_trick_, - approx_kl_self_normalized_score_trick_, - exact_kl_, - ] = sess.run([ - grad_sum(approx_kl), - grad_sum(approx_kl_self_normalized), - grad_sum(approx_kl_score_trick), - grad_sum(approx_kl_self_normalized_score_trick), - grad_sum(exact_kl), - approx_kl, - approx_kl_self_normalized, - approx_kl_score_trick, - approx_kl_self_normalized_score_trick, - exact_kl, - ]) - - # Test average divergence. - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_score_trick_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_, - rtol=0.08, atol=0.) - - # Test average gradient-divergence. - self.assertAllClose(approx_kl_grad_, exact_kl_grad_, - rtol=0.007, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_, - rtol=0.011, atol=0.) - - self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_, - rtol=0.018, atol=0.) - - self.assertAllClose( - approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_, - rtol=0.017, atol=0.) - - -class CsiszarVIMCOTest(test.TestCase): - - def _csiszar_vimco_helper(self, logu): - """Numpy implementation of `csiszar_vimco_helper`.""" - - # Since this is a naive/intuitive implementation, we compensate by using the - # highest precision we can. - logu = np.float128(logu) - n = logu.shape[0] - u = np.exp(logu) - loogeoavg_u = [] # Leave-one-out geometric-average of exp(logu). - for j in range(n): - loogeoavg_u.append(np.exp(np.mean( - [logu[i, ...] for i in range(n) if i != j], - axis=0))) - loogeoavg_u = np.array(loogeoavg_u) - - loosum_u = [] # Leave-one-out sum of exp(logu). - for j in range(n): - loosum_u.append(np.sum( - [u[i, ...] for i in range(n) if i != j], - axis=0)) - loosum_u = np.array(loosum_u) - - # Natural log of the average u except each is swapped-out for its - # leave-`i`-th-out Geometric average. - log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n) - - log_avg_u = np.log(np.mean(u, axis=0)) - return log_avg_u, log_sooavg_u - - def _csiszar_vimco_helper_grad(self, logu, delta): - """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`.""" - - # This code actually estimates the sum of the Jacobiab because that's what - # TF's `gradients` does. - np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper( - logu[..., None] + np.diag([delta]*len(logu))) - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper( - logu[..., None]) - return [ - (np_log_avg_u1 - np_log_avg_u) / delta, - np.sum(np_log_sooavg_u1 - np_log_sooavg_u, axis=0) / delta, - ] - - def test_vimco_helper_1(self): - """Tests that function calculation correctly handles batches.""" - - logu = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-8, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-8, atol=0.) - - def test_vimco_helper_2(self): - """Tests that function calculation correctly handles overflow.""" - - # Using 700 (rather than 1e3) since naive numpy version can't handle higher. - logu = np.float32([0., 700, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-6, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-5, atol=0.) - - def test_vimco_helper_3(self): - """Tests that function calculation correctly handles underlow.""" - - logu = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-5, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-4, atol=1e-15) - - def test_vimco_helper_gradient_using_finite_difference_1(self): - """Tests that gradient calculation correctly handles batches.""" - - logu_ = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - # We skip checking against finite-difference approximation since it - # doesn't support batches. - - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_2(self): - """Tests that gradient calculation correctly handles overflow.""" - - delta = 1e-3 - logu_ = np.float32([0., 1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_3(self): - """Tests that gradient calculation correctly handles underlow.""" - - delta = 1e-3 - logu_ = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_and_gradient(self): - - with self.test_session() as sess: - dims = 5 # Dimension - num_draws = int(20) - num_batch_draws = int(3) - seed = 1 - - f = lambda logu: cd.kl_reverse(logu, self_normalized=False) - np_f = lambda logu: -logu - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [dims])) - - vimco = cd.csiszar_vimco( - f=f, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - num_batch_draws=num_batch_draws, - seed=seed) - - x = q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed) - x = array_ops.stop_gradient(x) - logu = p.log_prob(x) - q.log_prob(x) - f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0]) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - def jacobian(x): - # Warning: this function is slow and may not even finish if prod(shape) - # is larger than, say, 100. - shape = x.shape.as_list() - assert all(s is not None for s in shape) - x = array_ops.reshape(x, shape=[-1]) - r = [grad_sum(x[i]) for i in range(np.prod(shape))] - return array_ops.reshape(array_ops.stack(r), shape=shape) - - [ - logu_, - jacobian_logqx_, - vimco_, - grad_vimco_, - f_log_sum_u_, - grad_mean_f_log_sum_u_, - ] = sess.run([ - logu, - jacobian(q.log_prob(x)), - vimco, - grad_sum(vimco), - f_log_sum_u, - grad_sum(f_log_sum_u) / num_batch_draws, - ]) - - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu_) - - # Test VIMCO loss is correct. - self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_, - rtol=1e-5, atol=0.) - - # Test gradient of VIMCO loss is correct. - # - # To make this computation we'll inject two gradients from TF: - # - grad[mean(f(log(sum(p(x)/q(x)))))] - # - jacobian[log(q(x))]. - # - # We now justify why using these (and only these) TF values for - # ground-truth does not undermine the completeness of this test. - # - # Regarding `grad_mean_f_log_sum_u_`, note that we validate the - # correctness of the zero-th order derivative (for each batch member). - # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient - # information, we can safely rely on TF. - self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.) - # - # Regarding `jacobian_logqx_`, note that testing the gradient of - # `q.log_prob` is outside the scope of this unit-test thus we may safely - # use TF to find it. - - # The `mean` is across batches and the `sum` is across iid samples. - np_grad_vimco = ( - grad_mean_f_log_sum_u_ - + np.mean( - np.sum( - jacobian_logqx_ * (np_f(np_log_avg_u) - - np_f(np_log_sooavg_u)), - axis=0), - axis=0)) - - self.assertAllClose(np_grad_vimco, grad_vimco_, - rtol=1e-5, atol=0.) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py deleted file mode 100644 index a95df31ac1fd9f5038abe779391ccba5f7fe408d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Custom Gradient Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -cg = custom_grad_impl - - -class CustomGradientTest(test.TestCase): - - def test_works_correctly(self): - with self.test_session() as sess: - f = lambda x: x**2 / 2 - g = lambda x: (x - 1)**3 / 3 - x_ = np.linspace(-100, 100, int(1e4)) + [0.] - - x = constant_op.constant(x_) - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, x)[0] - [fx_, gx_] = sess.run([fx, gx]) - - self.assertAllClose(f(x_), fx_) - self.assertAllClose(g(x_), gx_) - - def test_works_correctly_both_f_g_zero(self): - with self.test_session() as sess: - f = lambda x: x**2 / 2 - g = lambda x: x**3 / 3 - x_ = np.linspace(-100, 100, int(1e4)) + [0.] - - x = constant_op.constant(x_) - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, x)[0] - [fx_, gx_] = sess.run([fx, gx]) - - self.assertAllClose(f(x_), fx_) - self.assertAllClose(g(x_), gx_) - - def test_works_correctly_vector_of_vars(self): - with self.test_session() as sess: - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(2)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(3)) - sess.run([variables.global_variables_initializer()]) - - f = lambda z: z[0] * z[1] - g = lambda z: z[0]**2 * z[1]**2 / 2 - - z = array_ops.stack([x, y]) - fz = cg.custom_gradient(f(z), g(z), z, axis=0) - gz = gradients_impl.gradients(fz, variables.trainable_variables()) - [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]]) - - self.assertEqual(f(z_), fz_) - self.assertEqual(g(z_), gx_) - self.assertEqual(g(z_), gy_) - - def test_works_correctly_side_vars(self): - with self.test_session() as sess: - x_ = np.float32(2.1) # Adding extra tenth to force imprecision. - y_ = np.float32(3.1) - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(x_)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(y_)) - sess.run([variables.global_variables_initializer()]) - - f = lambda x: x * y - g = lambda z: math_ops.square(x) * y - - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, variables.trainable_variables()) - [x_, fx_, gx_] = sess.run([x, fx, gx[0]]) - gy_ = gx[1] - - self.assertEqual(x_ * y_, fx_) - self.assertEqual(np.square(x_) * y_, gx_) - self.assertEqual(None, gy_) - - def test_works_correctly_fx_gx_manually_stopped(self): - with self.test_session() as sess: - x_ = np.float32(2.1) # Adding extra tenth to force imprecision. - y_ = np.float32(3.1) - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(x_)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(y_)) - sess.run([variables.global_variables_initializer()]) - - stop = array_ops.stop_gradient # For readability. - - # Basically we need to stop the `x` portion of `f`. And when we supply the - # arg to `custom_gradient` we need to stop the complement, i.e., the `y` - # part. - f = lambda x: stop(x) * y - g = lambda x: stop(math_ops.square(x)) * y - fx = cg.custom_gradient(f(x), g(x), x + stop(y), - fx_gx_manually_stopped=True) - - gx = gradients_impl.gradients(fx, variables.trainable_variables()) - [x_, fx_, gx_, gy_] = sess.run([x, fx, gx[0], gx[1]]) - - self.assertEqual(x_ * y_, fx_) - self.assertEqual(np.square(x_) * y_, gx_) - self.assertEqual(x_, gy_) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py deleted file mode 100644 index 8ed500b19d8dd72795758a2920119e3680576697..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for docstring utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import docstring_util -from tensorflow.python.platform import test - - -class DocstringUtil(test.TestCase): - - def _testFunction(self): - doc_args = """x: Input to return as output. - y: Baz.""" - @docstring_util.expand_docstring(args=doc_args) - def foo(x): - # pylint: disable=g-doc-args - """Hello world. - - Args: - @{args} - - Returns: - x. - """ - # pylint: enable=g-doc-args - return x - - true_docstring = """Hello world. - - Args: - x: Input to return as output. - y: Baz. - - Returns: - x. - """ - self.assertEqual(foo.__doc__, true_docstring) - - def _testClassInit(self): - doc_args = """x: Input to return as output. - y: Baz.""" - - class Foo(object): - - @docstring_util.expand_docstring(args=doc_args) - def __init__(self, x, y): - # pylint: disable=g-doc-args - """Hello world. - - Args: - @{args} - - Bar. - """ - # pylint: enable=g-doc-args - pass - - true_docstring = """Hello world. - - Args: - x: Input to return as output. - y: Baz. - - Bar. - """ - self.assertEqual(Foo.__doc__, true_docstring) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py deleted file mode 100644 index 0a85862abfd744a86b9a38e10dbb5b985d0a0e94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for halton_sequence.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -mc = monte_carlo_lib - - -class HaltonSequenceTest(test.TestCase): - - def test_known_values_small_bases(self): - with self.test_session(): - # The first five elements of the Halton sequence with base 2 and 3 - expected = np.array(((1. / 2, 1. / 3), - (1. / 4, 2. / 3), - (3. / 4, 1. / 9), - (1. / 8, 4. / 9), - (5. / 8, 7. / 9)), dtype=np.float32) - sample = halton.sample(2, num_samples=5) - self.assertAllClose(expected, sample.eval(), rtol=1e-6) - - def test_sample_indices(self): - with self.test_session(): - dim = 5 - indices = math_ops.range(10, dtype=dtypes.int32) - sample_direct = halton.sample(dim, num_samples=10) - sample_from_indices = halton.sample(dim, sample_indices=indices) - self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), - rtol=1e-6) - - def test_dtypes_works_correctly(self): - with self.test_session(): - dim = 3 - sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) - sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) - self.assertEqual(sample_float32.eval().dtype, np.float32) - self.assertEqual(sample_float64.eval().dtype, np.float64) - - def test_normal_integral_mean_and_var_correctly_estimated(self): - n = int(1000) - # This test is almost identical to the similarly named test in - # monte_carlo_test.py. The only difference is that we use the Halton - # samples instead of the random samples to evaluate the expectations. - # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N) - # (N=number of samples). For QMC in low dimensions, the expected convergence - # rate is ~ 1/N. Hence we should only need 1e3 samples as compared to the - # 1e6 samples used in the pseudo-random monte carlo. - with self.test_session(): - mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64) - mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64) - sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64) - sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64) - p = normal_lib.Normal(loc=mu_p, scale=sigma_p) - q = normal_lib.Normal(loc=mu_q, scale=sigma_q) - - cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) - q_sample = q.quantile(cdf_sample) - - # Compute E_p[X]. - e_x = mc.expectation_importance_sampler( - f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) - - # Compute E_p[X^2]. - e_x2 = mc.expectation_importance_sampler( - f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) - - stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) - # Keep the tolerance levels the same as in monte_carlo_test.py. - self.assertEqual(p.batch_shape, e_x.get_shape()) - self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01) - self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02) - - def test_docstring_example(self): - # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 - dim = 3 - with self.test_session(): - sample = halton.sample(dim, num_samples=num_samples) - - # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional - # hypercube. - powers = math_ops.range(1.0, limit=dim + 1) - integral = math_ops.reduce_mean( - math_ops.reduce_prod(sample ** powers, axis=-1)) - true_value = 1.0 / math_ops.reduce_prod(powers + 1.0) - - # Produces a relative absolute error of 1.7%. - self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) - - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. - - sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, - dtype=dtypes.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) - - integral_leaped = math_ops.reduce_mean( - math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) - self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py deleted file mode 100644 index 819095a060b5f4cf18df6e7e4e4556e50ae44dd3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ /dev/null @@ -1,869 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Hamiltonian Monte Carlo.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -import numpy as np -from scipy import stats - -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator - -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_linalg_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging_ops - - -def _reduce_variance(x, axis=None, keepdims=False): - sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) - return math_ops.reduce_mean( - math_ops.squared_difference(x, sample_mean), axis, keepdims) - - -class HMCTest(test.TestCase): - - def setUp(self): - self._shape_param = 5. - self._rate_param = 10. - - random_seed.set_random_seed(10003) - np.random.seed(10003) - - def assertAllFinite(self, x): - self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) - - def _log_gamma_log_prob(self, x, event_dims=()): - """Computes log-pdf of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - """ - return math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims) - - def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, - feed_dict=None): - step_size = array_ops.placeholder(np.float32, [], name="step_size") - hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict[hmc_lf_steps] = 1000 - - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - - m = random_ops.random_normal(array_ops.shape(x)) - log_prob_0 = self._log_gamma_log_prob(x, event_dims) - grad_0 = gradients_ops.gradients(log_prob_0, x) - old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) - - new_m, _, log_prob_1, _ = _leapfrog_integrator( - current_momentums=[m], - target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), - current_state_parts=[x], - step_sizes=[step_size], - num_leapfrog_steps=hmc_lf_steps, - current_target_log_prob=log_prob_0, - current_grads_target_log_prob=grad_0) - new_m = new_m[0] - - new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, - event_dims) - - x_shape = sess.run(x, feed_dict).shape - event_size = np.prod(x_shape[independent_chain_ndims:]) - feed_dict[step_size] = 0.1 / event_size - old_energy_, new_energy_ = sess.run([old_energy, new_energy], - feed_dict) - logging_ops.vlog(1, "average energy relative change: {}".format( - (1. - new_energy_ / old_energy_).mean())) - self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) - - def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): - """Tests the long-term energy conservation of the leapfrog integrator. - - The leapfrog integrator is symplectic, so for sufficiently small step - sizes it should be possible to run it more or less indefinitely without - the energy of the system blowing up or collapsing. - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._integrator_conserves_energy(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper(0) - - def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper(1) - - def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper(2) - - def testIntegratorEnergyConservation3(self): - self._integrator_conserves_energy_wrapper(3) - - def testSampleChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - num_results = 10 - independent_chain_ndims = 1 - - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - kwargs = dict( - target_log_prob_fn=log_gamma_log_prob, - current_state=np.random.rand(4, 3, 2), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=150, - seed=52, - ) - - samples0, kernel_results0 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=2 * num_results, - num_steps_between_results=0).items()))) - - samples1, kernel_results1 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=num_results, - num_steps_between_results=1).items()))) - - [ - samples0_, - samples1_, - target_log_prob0_, - target_log_prob1_, - ] = sess.run([ - samples0, - samples1, - kernel_results0.current_target_log_prob, - kernel_results1.current_target_log_prob, - ]) - self.assertAllClose(samples0_[::2], samples1_, - atol=1e-5, rtol=1e-5) - self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, - atol=1e-5, rtol=1e-5) - - def _chain_gets_correct_expectations(self, x, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - def log_gamma_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - num_results = array_ops.placeholder( - np.int32, [], name="num_results") - step_size = array_ops.placeholder( - np.float32, [], name="step_size") - num_leapfrog_steps = array_ops.placeholder( - np.int32, [], name="num_leapfrog_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict.update({num_results: 150, - step_size: 0.05, - num_leapfrog_steps: 2}) - - samples, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_gamma_log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - num_burnin_steps=150, - seed=42) - - self.assertAllEqual(dict(target_calls=2), counter) - - expected_x = (math_ops.digamma(self._shape_param) - - np.log(self._rate_param)) - - expected_exp_x = self._shape_param / self._rate_param - - log_accept_ratio_, samples_, expected_x_ = sess.run( - [kernel_results.log_accept_ratio, samples, expected_x], - feed_dict) - - actual_x = samples_.mean() - actual_exp_x = np.exp(samples_).mean() - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - - logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( - expected_x_, expected_exp_x)) - logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( - actual_x, actual_exp_x)) - self.assertNear(actual_x, expected_x_, 2e-2) - self.assertNear(actual_exp_x, expected_exp_x, 2e-2) - self.assertAllEqual(np.ones_like(acceptance_probs, np.bool), - acceptance_probs > 0.5) - self.assertAllEqual(np.ones_like(acceptance_probs, np.bool), - acceptance_probs <= 1.) - - def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper(0) - - def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper(1) - - def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper(2) - - def testKernelResultsUsingTruncatedDistribution(self): - def log_prob(x): - return array_ops.where( - x >= 0., - -x - x**2, # Non-constant gradient. - array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) - # This log_prob has the property that it is likely to attract - # the flow toward, and below, zero...but for x <=0, - # log_prob(x) = -inf, which should result in rejection, as well - # as a non-finite log_prob. Thus, this distribution gives us an opportunity - # to test out the kernel results ability to correctly capture rejections due - # to finite AND non-finite reasons. - # Why use a non-constant gradient? This ensures the leapfrog integrator - # will not be exact. - - num_results = 1000 - # Large step size, will give rejections due to integration error in addition - # to rejection due to going into a region of log_prob = -inf. - step_size = 0.1 - num_leapfrog_steps = 5 - num_chains = 2 - - with self.test_session(graph=ops.Graph()) as sess: - - # Start multiple independent chains. - initial_state = ops.convert_to_tensor([0.1] * num_chains) - - states, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_prob, - current_state=initial_state, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - seed=42) - - states_, kernel_results_ = sess.run([states, kernel_results]) - pstates_ = kernel_results_.proposed_state - - neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) - - # First: Test that the mathematical properties of the above log prob - # function in conjunction with HMC show up as expected in kernel_results_. - - # We better have log_prob = -inf some of the time. - self.assertLess(0, neg_inf_mask.sum()) - # We better have some rejections due to something other than -inf. - self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) - # We better have accepted a decent amount, even near end of the chain. - self.assertLess( - 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) - # We better not have any NaNs in states or log_prob. - # We may have some NaN in grads, which involve multiplication/addition due - # to gradient rules. This is the known "NaN grad issue with tf.where." - self.assertAllEqual(np.zeros_like(states_), - np.isnan(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual(np.zeros_like(states_), - np.isnan(states_)) - # We better not have any +inf in states, grads, or log_prob. - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual( - np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(states_)) - - # Second: Test that kernel_results is congruent with itself and - # acceptance/rejection of states. - - # Proposed state is negative iff proposed target log prob is -inf. - np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) - np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) - - # Acceptance probs are zero whenever proposed state is negative. - acceptance_probs = np.exp(np.minimum( - kernel_results_.log_accept_ratio, 0.)) - self.assertAllEqual( - np.zeros_like(pstates_[neg_inf_mask]), - acceptance_probs[neg_inf_mask]) - - # The move is accepted ==> state = proposed state. - self.assertAllEqual( - states_[kernel_results_.is_accepted], - pstates_[kernel_results_.is_accepted], - ) - # The move was rejected <==> state[t] == state[t - 1]. - for t in range(1, num_results): - for i in range(num_chains): - if kernel_results_.is_accepted[t, i]: - self.assertNotEqual(states_[t, i], states_[t - 1, i]) - else: - self.assertEqual(states_[t, i], states_[t - 1, i]) - - def _kernel_leaves_target_invariant(self, initial_draws, - independent_chain_ndims, - sess, feed_dict=None): - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - def fake_log_prob(x): - """Cooled version of the target distribution.""" - return 1.1 * log_gamma_log_prob(x) - - step_size = array_ops.placeholder(np.float32, [], name="step_size") - - if feed_dict is None: - feed_dict = {} - - feed_dict[step_size] = 0.4 - - sample, kernel_results = hmc.kernel( - target_log_prob_fn=log_gamma_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=43) - - bad_sample, bad_kernel_results = hmc.kernel( - target_log_prob_fn=fake_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=44) - - [ - log_accept_ratio_, - bad_log_accept_ratio_, - initial_draws_, - updated_draws_, - fake_draws_, - ] = sess.run([ - kernel_results.log_accept_ratio, - bad_kernel_results.log_accept_ratio, - initial_draws, - sample, - bad_sample, - ], feed_dict) - - # Confirm step size is small enough that we usually accept. - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - bad_acceptance_probs = np.exp(np.minimum(bad_log_accept_ratio_, 0.)) - self.assertGreater(acceptance_probs.mean(), 0.5) - self.assertGreater(bad_acceptance_probs.mean(), 0.5) - - # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs.mean(), 0.99) - self.assertLess(bad_acceptance_probs.mean(), 0.99) - - _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), - updated_draws_.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), - fake_draws_.flatten()) - - logging_ops.vlog(1, "acceptance rate for true target: {}".format( - acceptance_probs.mean())) - logging_ops.vlog(1, "acceptance rate for fake target: {}".format( - bad_acceptance_probs.mean())) - logging_ops.vlog(1, "K-S p-value for true target: {}".format( - ks_p_value_true)) - logging_ops.vlog(1, "K-S p-value for fake target: {}".format( - ks_p_value_fake)) - # Make sure that the MCMC update hasn't changed the empirical CDF much. - self.assertGreater(ks_p_value_true, 1e-3) - # Confirm that targeting the wrong distribution does - # significantly change the empirical CDF. - self.assertLess(ks_p_value_fake, 1e-6) - - def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): - """Tests that the kernel leaves the target distribution invariant. - - Draws some independent samples from the target distribution, - applies an iteration of the MCMC kernel, then runs a - Kolmogorov-Smirnov test to determine if the distribution of the - MCMC-updated samples has changed. - - We also confirm that running the kernel with a different log-pdf - does change the target distribution. (And that we can detect that.) - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - initial_draws = np.log(np.random.gamma(self._shape_param, - size=[50000, 2, 2])) - initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name="x_ph") - - feed_dict = {x_ph: initial_draws} - - self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper(1) - - def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper(2) - - def testKernelLeavesTargetInvariant3(self): - self._kernel_leaves_target_invariant_wrapper(3) - - def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - - def proposal_log_prob(x): - counter["proposal_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - if feed_dict is None: - feed_dict = {} - - num_steps = 200 - - _, ais_weights, _ = hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal_log_prob, - num_steps=num_steps, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=init, - num_leapfrog_steps=2, - seed=45) - - # We have three calls because the calculation of `ais_weights` entails - # another call to the `convex_combined_log_prob_fn`. We could refactor - # things to avoid this, if needed (eg, b/72994218). - self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) - - event_shape = array_ops.shape(init)[independent_chain_ndims:] - event_size = math_ops.reduce_prod(event_shape) - - log_true_normalizer = ( - -self._shape_param * math_ops.log(self._rate_param) - + math_ops.lgamma(self._shape_param)) - log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) - - log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) - - np.log(num_steps)) - - ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) - ais_weights_size = array_ops.size(ais_weights) - standard_error = math_ops.sqrt( - _reduce_variance(ratio_estimate_true) - / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) - - [ - ratio_estimate_true_, - log_true_normalizer_, - log_estimated_normalizer_, - standard_error_, - ais_weights_size_, - event_size_, - ] = sess.run([ - ratio_estimate_true, - log_true_normalizer, - log_estimated_normalizer, - standard_error, - ais_weights_size, - event_size, - ], feed_dict) - - logging_ops.vlog(1, " log_true_normalizer: {}\n" - " log_estimated_normalizer: {}\n" - " ais_weights_size: {}\n" - " event_size: {}\n".format( - log_true_normalizer_, - log_estimated_normalizer_, - ais_weights_size_, - event_size_)) - self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_) - - def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): - """Tests that AIS yields reasonable estimates of normalizers.""" - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - initial_draws = np.random.normal(size=[30, 2, 1]) - self._ais_gets_correct_log_normalizer( - x_ph, - independent_chain_ndims, - sess, - feed_dict={x_ph: initial_draws}) - - def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper(1) - - def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper(2) - - def testAIS3(self): - self._ais_gets_correct_log_normalizer_wrapper(3) - - def testSampleAIChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - independent_chain_ndims = 1 - x = np.random.rand(4, 3, 2) - - def proposal_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - ais_kwargs = dict( - proposal_log_prob_fn=proposal_log_prob, - num_steps=200, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=x, - num_leapfrog_steps=2, - seed=53) - - _, ais_weights0, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - _, ais_weights1, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - [ais_weights0_, ais_weights1_] = sess.run([ - ais_weights0, ais_weights1]) - - self.assertAllClose(ais_weights0_, ais_weights1_, - atol=1e-5, rtol=1e-5) - - def testNanRejection(self): - """Tests that an update that yields NaN potentials gets rejected. - - We run HMC with a target distribution that returns NaN - log-likelihoods if any element of x < 0, and unit-scale - exponential log-likelihoods otherwise. The exponential potential - pushes x towards 0, ensuring that any reasonably large update will - push us over the edge into NaN territory. - """ - def _unbounded_exponential_log_prob(x): - """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where( - x < 0., - array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), - -x) - return math_ops.reduce_sum(per_element_potentials) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_unbounded_exponential_log_prob, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=46) - initial_x_, updated_x_, log_accept_ratio_ = sess.run( - [initial_x, updated_x, kernel_results.log_accept_ratio]) - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs, 0.) - - def testNanFromGradsDontPropagate(self): - """Test that update with NaN gradients does not cause NaN in results.""" - def _nan_log_prob_with_nan_gradient(x): - return np.nan * math_ops.reduce_sum(x) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_nan_log_prob_with_nan_gradient, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=47) - initial_x_, updated_x_, log_accept_ratio_ = sess.run( - [initial_x, updated_x, kernel_results.log_accept_ratio]) - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs, 0.) - - self.assertAllFinite( - gradients_ops.gradients(updated_x, initial_x)[0].eval()) - self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, initial_x)]) - self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, - kernel_results.proposed_state)]) - - # Gradients of the acceptance probs and new log prob are not finite. - # self.assertAllFinite( - # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) - # self.assertAllFinite( - # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) - - def _testChainWorksDtype(self, dtype): - with self.test_session(graph=ops.Graph()) as sess: - states, kernel_results = hmc.sample_chain( - num_results=10, - target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), - current_state=np.zeros(5).astype(dtype), - step_size=0.01, - num_leapfrog_steps=10, - seed=48) - states_, log_accept_ratio_ = sess.run( - [states, kernel_results.log_accept_ratio]) - self.assertEqual(dtype, states_.dtype) - self.assertEqual(dtype, log_accept_ratio_.dtype) - - def testChainWorksIn64Bit(self): - self._testChainWorksDtype(np.float64) - - def testChainWorksIn16Bit(self): - self._testChainWorksDtype(np.float16) - - def testChainWorksCorrelatedMultivariate(self): - dtype = np.float32 - true_mean = dtype([0, 0]) - true_cov = dtype([[1, 0.5], - [0.5, 1]]) - num_results = 2000 - counter = collections.Counter() - with self.test_session(graph=ops.Graph()) as sess: - def target_log_prob(x, y): - counter["target_calls"] += 1 - # Corresponds to unnormalized MVN. - # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) - z = array_ops.stack([x, y], axis=-1) - true_mean - z = array_ops.squeeze( - gen_linalg_ops.matrix_triangular_solve( - np.linalg.cholesky(true_cov), - z[..., array_ops.newaxis]), - axis=-1) - return -0.5 * math_ops.reduce_sum(z**2., axis=-1) - states, _ = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=target_log_prob, - current_state=[dtype(-2), dtype(2)], - step_size=[0.5, 0.5], - num_leapfrog_steps=2, - num_burnin_steps=200, - num_steps_between_results=1, - seed=54) - self.assertAllEqual(dict(target_calls=2), counter) - states = array_ops.stack(states, axis=-1) - self.assertEqual(num_results, states.shape[0].value) - sample_mean = math_ops.reduce_mean(states, axis=0) - x = states - sample_mean - sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) - [sample_mean_, sample_cov_] = sess.run([ - sample_mean, sample_cov]) - self.assertAllClose(true_mean, sample_mean_, - atol=0.05, rtol=0.) - self.assertAllClose(true_cov, sample_cov_, - atol=0., rtol=0.1) - - -class _EnergyComputationTest(object): - - def testHandlesNanFromPotential(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - target_log_prob, proposed_target_log_prob = [ - self.dtype(x.flatten()) for x in np.meshgrid(x, x)] - num_chains = len(target_log_prob) - dummy_momentums = [-1, 1] - momentums = [self.dtype([dummy_momentums] * num_chains)] - proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - self.assertAllEqual(np.ones_like(grads_).astype(np.bool), - np.isfinite(grads_)) - - def testHandlesNanFromKinetic(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - momentums, proposed_momentums = [ - [np.reshape(self.dtype(x), [-1, 1])] - for x in np.meshgrid(x, x)] - num_chains = len(momentums[0]) - target_log_prob = np.ones(num_chains, self.dtype) - proposed_target_log_prob = np.ones(num_chains, self.dtype) - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - g = grads_[0].reshape([len(x), len(x)])[:, 0] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) - - # The remaining gradients are nan because the momentum was itself nan or - # inf. - g = grads_[0].reshape([len(x), len(x)])[:, 1:] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) - - -class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): - dtype = np.float16 - - -class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): - dtype = np.float32 - - -class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): - dtype = np.float64 - - -class _HMCHandlesLists(object): - - def testStateParts(self): - with self.test_session(graph=ops.Graph()) as sess: - dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) - dist_y = independent_lib.Independent( - gamma_lib.Gamma(concentration=self.dtype([1, 2]), - rate=self.dtype([0.5, 0.75])), - reinterpreted_batch_ndims=1) - def target_log_prob(x, y): - return dist_x.log_prob(x) + dist_y.log_prob(y) - x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] - samples, _ = hmc.sample_chain( - num_results=int(2e3), - target_log_prob_fn=target_log_prob, - current_state=x0, - step_size=0.85, - num_leapfrog_steps=3, - num_burnin_steps=int(250), - seed=49) - actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] - actual_vars = [_reduce_variance(s, axis=0) for s in samples] - expected_means = [dist_x.mean(), dist_y.mean()] - expected_vars = [dist_x.variance(), dist_y.variance()] - [ - actual_means_, - actual_vars_, - expected_means_, - expected_vars_, - ] = sess.run([ - actual_means, - actual_vars, - expected_means, - expected_vars, - ]) - self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) - self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) - - -class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): - dtype = np.float32 - - -class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py deleted file mode 100644 index 750afb6654311fea30a1dc6b31b20aa3b4160ae2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for convolutional Bayesian layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib -from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - - -class Counter(object): - """Helper class to manage incrementing a counting `int`.""" - - def __init__(self): - self._value = -1 - - @property - def value(self): - return self._value - - def __call__(self): - self._value += 1 - return self._value - - -class MockDistribution(independent_lib.Independent): - """Monitors layer calls to the underlying distribution.""" - - def __init__(self, result_sample, result_log_prob, loc=None, scale=None): - self.result_sample = result_sample - self.result_log_prob = result_log_prob - self.result_loc = loc - self.result_scale = scale - self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) - if loc is not None and scale is not None: - self.result_distribution = normal_lib.Normal(loc=self.result_loc, - scale=self.result_scale) - self.called_log_prob = Counter() - self.called_sample = Counter() - self.called_loc = Counter() - self.called_scale = Counter() - - def log_prob(self, *args, **kwargs): - self.called_log_prob() - return self.result_log_prob - - def sample(self, *args, **kwargs): - self.called_sample() - return self.result_sample - - @property - def distribution(self): # for dummy check on Independent(Normal) - return self.result_distribution - - @property - def loc(self): - self.called_loc() - return self.result_loc - - @property - def scale(self): - self.called_scale() - return self.result_scale - - -class MockKLDivergence(object): - """Monitors layer calls to the divergence implementation.""" - - def __init__(self, result): - self.result = result - self.args = [] - self.called = Counter() - - def __call__(self, *args, **kwargs): - self.called() - self.args.append(args) - return self.result - - -class ConvVariational(test.TestCase): - - def _testKLPenaltyKernel(self, layer_class): - with self.test_session(): - layer = layer_class(filters=2, kernel_size=3) - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform([2, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 1) - self.assertListEqual(layer.losses, losses) - - def _testKLPenaltyBoth(self, layer_class): - def _make_normal(dtype, *args): # pylint: disable=unused-argument - return normal_lib.Normal( - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) - with self.test_session(): - layer = layer_class( - filters=2, - kernel_size=3, - bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), - bias_prior_fn=_make_normal) - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform([2, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 2) - self.assertListEqual(layer.losses, losses) - - def _testConvSetUp(self, layer_class, batch_size, depth=None, - height=None, width=None, channels=None, filters=None, - **kwargs): - seed = Counter() - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform( - [batch_size, width, channels], seed=seed()) - kernel_size = (2,) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform( - [batch_size, height, width, channels], seed=seed()) - kernel_size = (2, 2) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform( - [batch_size, depth, height, width, channels], seed=seed()) - kernel_size = (2, 2, 2) - - kernel_shape = kernel_size + (channels, filters) - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_shape, seed=seed()), - scale=random_ops.random_uniform(kernel_shape, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_shape, seed=seed())) - - bias_size = (filters,) - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) - - layer = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence, - **kwargs) - - outputs = layer(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - return (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, - layer, inputs, outputs, kl_penalty, kernel_shape) - - def _testConvReparameterization(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty, kernel_shape) = self._testConvSetUp( - layer_class, batch_size, - depth=depth, height=height, width=width, channels=channels, - filters=filters) - - convolution_op = nn_ops.Convolution( - tensor_shape.TensorShape(inputs.shape), - filter_shape=tensor_shape.TensorShape(kernel_shape), - padding="SAME") - expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) - expected_outputs = nn.bias_add(expected_outputs, - bias_posterior.result_sample, - data_format="NHWC") - - [ - expected_outputs_, actual_outputs_, - expected_kernel_, actual_kernel_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_posterior.result_sample, layer.kernel_posterior_tensor, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_kernel_, actual_kernel_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - kernel_posterior.result_sample]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def _testConvFlipout(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty, kernel_shape) = self._testConvSetUp( - layer_class, batch_size, - depth=depth, height=height, width=width, channels=channels, - filters=filters, seed=44) - - convolution_op = nn_ops.Convolution( - tensor_shape.TensorShape(inputs.shape), - filter_shape=tensor_shape.TensorShape(kernel_shape), - padding="SAME") - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(kernel_posterior.result_loc), - scale=kernel_posterior.result_scale) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - - expected_outputs = convolution_op( - inputs, kernel_posterior.distribution.loc) - - input_shape = array_ops.shape(inputs) - output_shape = array_ops.shape(expected_outputs) - batch_shape = array_ops.expand_dims(input_shape[0], 0) - channels = input_shape[-1] - rank = len(inputs.get_shape()) - 2 - - sign_input = random_ops.random_uniform( - array_ops.concat([batch_shape, - array_ops.expand_dims(channels, 0)], 0), - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=layer.seed) - sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) - sign_output = random_ops.random_uniform( - array_ops.concat([batch_shape, - array_ops.expand_dims(filters, 0)], 0), - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=distribution_util.gen_new_seed( - layer.seed, salt="conv_flipout")) - sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) - for _ in range(rank): - sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) - sign_output = array_ops.expand_dims(sign_output, 1) - - sign_input = array_ops.tile( # tile for element-wise op broadcasting - sign_input, - [1] + [input_shape[i + 1] for i in range(rank)] + [1]) - sign_output = array_ops.tile( - sign_output, - [1] + [output_shape[i + 1] for i in range(rank)] + [1]) - - perturbed_inputs = convolution_op( - inputs * sign_input, expected_kernel_posterior_affine_tensor) - perturbed_inputs *= sign_output - - expected_outputs += perturbed_inputs - expected_outputs = nn.bias_add(expected_outputs, - bias_posterior.result_sample, - data_format="NHWC") - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, kernel_prior.distribution, None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def _testRandomConvFlipout(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - seed = Counter() - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform( - [batch_size, width, channels], seed=seed()) - kernel_size = (2,) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform( - [batch_size, height, width, channels], seed=seed()) - kernel_size = (2, 2) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform( - [batch_size, depth, height, width, channels], seed=seed()) - kernel_size = (2, 2, 2) - - kernel_shape = kernel_size + (channels, filters) - bias_size = (filters,) - - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform( - kernel_shape, seed=seed()), - scale=random_ops.random_uniform( - kernel_shape, seed=seed()), - result_log_prob=random_ops.random_uniform( - kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform( - kernel_shape, seed=seed())) - bias_posterior = MockDistribution( - loc=random_ops.random_uniform( - bias_size, seed=seed()), - scale=random_ops.random_uniform( - bias_size, seed=seed()), - result_log_prob=random_ops.random_uniform( - bias_size, seed=seed()), - result_sample=random_ops.random_uniform( - bias_size, seed=seed())) - layer_one = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=44) - layer_two = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=45) - - outputs_one = layer_one(inputs) - outputs_two = layer_two(inputs) - - outputs_one_, outputs_two_ = sess.run([ - outputs_one, outputs_two]) - - self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), - np.prod(outputs_one_.shape)) - - def testKLPenaltyKernelConv1DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization) - - def testKLPenaltyKernelConv2DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization) - - def testKLPenaltyKernelConv3DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization) - - def testKLPenaltyKernelConv1DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout) - - def testKLPenaltyKernelConv2DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout) - - def testKLPenaltyKernelConv3DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout) - - def testKLPenaltyBothConv1DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization) - - def testKLPenaltyBothConv2DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization) - - def testKLPenaltyBothConv3DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization) - - def testKLPenaltyBothConv1DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout) - - def testKLPenaltyBothConv2DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout) - - def testKLPenaltyBothConv3DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout) - - def testConv1DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization) - - def testConv2DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization) - - def testConv3DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization) - - def testConv1DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv1DFlipout) - - def testConv2DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv2DFlipout) - - def testConv3DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv3DFlipout) - - def testRandomConv1DFlipout(self): - self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py deleted file mode 100644 index 342f38ccec7ec74db1b393d6cdc22300205cc547..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for dense Bayesian layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as prob_layers_lib -from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - - -class Counter(object): - """Helper class to manage incrementing a counting `int`.""" - - def __init__(self): - self._value = -1 - - @property - def value(self): - return self._value - - def __call__(self): - self._value += 1 - return self._value - - -class MockDistribution(independent_lib.Independent): - """Monitors layer calls to the underlying distribution.""" - - def __init__(self, result_sample, result_log_prob, loc=None, scale=None): - self.result_sample = result_sample - self.result_log_prob = result_log_prob - self.result_loc = loc - self.result_scale = scale - self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) - if loc is not None and scale is not None: - self.result_distribution = normal_lib.Normal(loc=self.result_loc, - scale=self.result_scale) - self.called_log_prob = Counter() - self.called_sample = Counter() - self.called_loc = Counter() - self.called_scale = Counter() - - def log_prob(self, *args, **kwargs): - self.called_log_prob() - return self.result_log_prob - - def sample(self, *args, **kwargs): - self.called_sample() - return self.result_sample - - @property - def distribution(self): # for dummy check on Independent(Normal) - return self.result_distribution - - @property - def loc(self): - self.called_loc() - return self.result_loc - - @property - def scale(self): - self.called_scale() - return self.result_scale - - -class MockKLDivergence(object): - """Monitors layer calls to the divergence implementation.""" - - def __init__(self, result): - self.result = result - self.args = [] - self.called = Counter() - - def __call__(self, *args, **kwargs): - self.called() - self.args.append(args) - return self.result - - -class DenseVariational(test.TestCase): - - def _testKLPenaltyKernel(self, layer_class): - with self.test_session(): - layer = layer_class(units=2) - inputs = random_ops.random_uniform([2, 3], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 1) - self.assertListEqual(layer.losses, losses) - - def _testKLPenaltyBoth(self, layer_class): - def _make_normal(dtype, *args): # pylint: disable=unused-argument - return normal_lib.Normal( - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) - with self.test_session(): - layer = layer_class( - units=2, - bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), - bias_prior_fn=_make_normal) - inputs = random_ops.random_uniform([2, 3], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 2) - self.assertListEqual(layer.losses, losses) - - def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size, - **kwargs): - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_size, seed=seed()), - scale=random_ops.random_uniform(kernel_size, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) - - layer = layer_class( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence, - **kwargs) - - outputs = layer(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - return (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, - layer, inputs, outputs, kl_penalty) - - def testKLPenaltyKernelReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseReparameterization) - - def testKLPenaltyKernelLocalReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseLocalReparameterization) - - def testKLPenaltyKernelFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseFlipout) - - def testKLPenaltyBothReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseReparameterization) - - def testKLPenaltyBothLocalReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseLocalReparameterization) - - def testKLPenaltyBothFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseFlipout) - - def testDenseReparameterization(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseReparameterization, - batch_size, in_size, out_size) - - expected_outputs = ( - math_ops.matmul(inputs, kernel_posterior.result_sample) + - bias_posterior.result_sample) - - [ - expected_outputs_, actual_outputs_, - expected_kernel_, actual_kernel_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_posterior.result_sample, layer.kernel_posterior_tensor, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_kernel_, actual_kernel_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - kernel_posterior.result_sample]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testDenseLocalReparameterization(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseLocalReparameterization, - batch_size, in_size, out_size) - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=math_ops.matmul(inputs, kernel_posterior.result_loc), - scale=math_ops.matmul( - inputs**2., kernel_posterior.result_scale**2)**0.5) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - expected_outputs = (expected_kernel_posterior_affine_tensor + - bias_posterior.result_sample) - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testDenseFlipout(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseFlipout, - batch_size, in_size, out_size, seed=44) - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(kernel_posterior.result_loc), - scale=kernel_posterior.result_scale) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - - sign_input = random_ops.random_uniform( - [batch_size, in_size], - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=layer.seed) - sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) - sign_output = random_ops.random_uniform( - [batch_size, out_size], - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=distribution_util.gen_new_seed( - layer.seed, salt="dense_flipout")) - sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) - perturbed_inputs = math_ops.matmul( - inputs * sign_input, expected_kernel_posterior_affine_tensor) - perturbed_inputs *= sign_output - - expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc) - expected_outputs += perturbed_inputs - expected_outputs += bias_posterior.result_sample - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, kernel_prior.distribution, None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testRandomDenseFlipout(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - scale=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - result_log_prob=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - result_sample=random_ops.random_uniform( - [in_size, out_size], seed=seed())) - bias_posterior = MockDistribution( - loc=random_ops.random_uniform( - [out_size], seed=seed()), - scale=random_ops.random_uniform( - [out_size], seed=seed()), - result_log_prob=random_ops.random_uniform( - [out_size], seed=seed()), - result_sample=random_ops.random_uniform( - [out_size], seed=seed())) - layer_one = prob_layers_lib.DenseFlipout( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=44) - layer_two = prob_layers_lib.DenseFlipout( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=45) - - outputs_one = layer_one(inputs) - outputs_two = layer_two(inputs) - - outputs_one_, outputs_two_ = sess.run([ - outputs_one, outputs_two]) - - self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), out_size) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py deleted file mode 100644 index 52e36e135d95c1ec919c710f35d59073c2134d05..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MCMC diagnostic utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import spectral_ops_test_util -from tensorflow.python.platform import test - -rng = np.random.RandomState(42) - - -class _EffectiveSampleSizeTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to implement `use_static_shape`.") - - def _check_versus_expected_effective_sample_size(self, - x_, - expected_ess, - sess, - atol=1e-2, - rtol=1e-2, - filter_threshold=None, - filter_beyond_lag=None): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - ess = mcmc_diagnostics.effective_sample_size( - x, - filter_threshold=filter_threshold, - filter_beyond_lag=filter_beyond_lag) - if self.use_static_shape: - self.assertAllEqual(x.shape[1:], ess.shape) - - ess_ = sess.run(ess) - - self.assertAllClose( - np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol) - - def testIidRank1NormalHasFullEssMaxLags10(self): - # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we - # should have a good estimate of ESS, and it should be close to the full - # sequence length of 5000. - # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only - # since we know the correlation length should be zero right away. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=10, - filter_threshold=None, - rtol=0.3) - - def testIidRank2NormalHasFullEssMaxLags10(self): - # See similar test for Rank1Normal for reasoning. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000, 2).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=10, - filter_threshold=None, - rtol=0.3) - - def testIidRank1NormalHasFullEssMaxLagThresholdZero(self): - # With a length 5000 iid normal sequence, and filter_threshold = 0, - # we should have a super-duper estimate of ESS, and it should be very close - # to the full sequence length of 5000. - # The choice of filter_beyond_lag = 0 means we cutoff as soon as the - # auto-corris below zero. This should happen very quickly, due to the fact - # that the theoretical auto-corr is [1, 0, 0,...] - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): - # See similar test for Rank1Normal for reasoning. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000, 2).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): - # Create x_, such that - # x_[i] = iid_x_[0], i = 0,...,9 - # x_[i] = iid_x_[1], i = 10,..., 19, - # and so on. - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=x_, - expected_ess=50000 // 10, - sess=sess, - filter_beyond_lag=50, - filter_threshold=None, - rtol=0.2) - - def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( - self): - # Create x_, such that - # x_[i] = iid_x_[0], i = 0,...,9 - # x_[i] = iid_x_[1], i = 10,..., 19, - # and so on. - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=x_, - expected_ess=50000 // 10, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testListArgs(self): - # x_ has correlation length 10 ==> ESS = N / 10 - # y_ has correlation length 1 ==> ESS = N - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - y_ = rng.randn(50000).astype(np.float32) - states = [x_, x_, y_, y_] - filter_threshold = [0., None, 0., None] - filter_beyond_lag = [None, 5, None, 5] - - # See other tests for reasoning on tolerance. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - ess = mcmc_diagnostics.effective_sample_size( - states, - filter_threshold=filter_threshold, - filter_beyond_lag=filter_beyond_lag) - ess_ = sess.run(ess) - self.assertAllEqual(4, len(ess_)) - - self.assertAllClose(50000 // 10, ess_[0], rtol=0.3) - self.assertAllClose(50000 // 10, ess_[1], rtol=0.3) - self.assertAllClose(50000, ess_[2], rtol=0.1) - self.assertAllClose(50000, ess_[3], rtol=0.1) - - def testMaxLagsThresholdLessThanNeg1SameAsNone(self): - # Setting both means we filter out items R_k from the auto-correlation - # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. - - # x_ has correlation length 10. - iid_x_ = rng.randn(500, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - - ess_none_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=None, filter_beyond_lag=None) - ess_none_200 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=None, filter_beyond_lag=200) - ess_neg2_200 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=-2., filter_beyond_lag=200) - ess_neg2_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=-2., filter_beyond_lag=None) - ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run( - [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none]) - - # filter_threshold=-2 <==> filter_threshold=None. - self.assertAllClose(ess_none_none_, ess_neg2_none_) - self.assertAllClose(ess_none_200_, ess_neg2_200_) - - def testMaxLagsArgsAddInAnOrManner(self): - # Setting both means we filter out items R_k from the auto-correlation - # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. - - # x_ has correlation length 10. - iid_x_ = rng.randn(500, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - - ess_1_9 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=9) - ess_1_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=None) - ess_none_9 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=9) - ess_1_9_, ess_1_none_, ess_none_9_ = sess.run( - [ess_1_9, ess_1_none, ess_none_9]) - - # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10, - # filter_threshold = 1 <==> filter_beyond_lag = 9. - self.assertAllClose(ess_1_9_, ess_1_none_) - self.assertAllClose(ess_1_9_, ess_none_9_) - - -class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest): - - @property - def use_static_shape(self): - return True - - -class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest): - - @property - def use_static_shape(self): - return False - - -class _PotentialScaleReductionTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to impliment `use_static_shape`.") - - def testListOfStatesWhereFirstPassesSecondFails(self): - """Simple test showing API with two states. Read first!.""" - n_samples = 1000 - - # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass. - state_0 = rng.randn(n_samples, 2) - - # state_1 is three 4-variate chains taken from Normal(0, 1) that have been - # shifted. Since every chain is shifted, they are not the same, and the - # test should fail. - offset = np.array([1., -1., 2.]).reshape(3, 1) - state_1 = rng.randn(n_samples, 3, 4) + offset - - rhat = mcmc_diagnostics.potential_scale_reduction( - chains_states=[state_0, state_1], independent_chain_ndims=1) - - self.assertIsInstance(rhat, list) - with self.test_session() as sess: - rhat_0_, rhat_1_ = sess.run(rhat) - - # r_hat_0 should be close to 1, meaning test is passed. - self.assertAllEqual((), rhat_0_.shape) - self.assertAllClose(1., rhat_0_, rtol=0.02) - - # r_hat_1 should be greater than 1.2, meaning test has failed. - self.assertAllEqual((4,), rhat_1_.shape) - self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2) - - def check_results(self, state_, independent_chain_shape, should_pass): - sample_ndims = 1 - independent_chain_ndims = len(independent_chain_shape) - with self.test_session(): - state = array_ops.placeholder_with_default( - input=state_, shape=state_.shape if self.use_static_shape else None) - - rhat = mcmc_diagnostics.potential_scale_reduction( - state, independent_chain_ndims=independent_chain_ndims) - - if self.use_static_shape: - self.assertAllEqual( - state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape) - - rhat_ = rhat.eval() - if should_pass: - self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02) - else: - self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2) - - def iid_normal_chains_should_pass_wrapper(self, - sample_shape, - independent_chain_shape, - other_shape, - dtype=np.float32): - """Check results with iid normal chains.""" - - state_shape = sample_shape + independent_chain_shape + other_shape - state_ = rng.randn(*state_shape).astype(dtype) - - # The "other" dimensions do not have to be identical, just independent, so - # force them to not be identical. - if other_shape: - state_ *= rng.rand(*other_shape).astype(dtype) - - self.check_results(state_, independent_chain_shape, should_pass=True) - - def testPassingIIDNdimsAreIndependentOneOtherZero(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[4], other_shape=[]) - - def testPassingIIDNdimsAreIndependentOneOtherOne(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[3], other_shape=[7]) - - def testPassingIIDNdimsAreIndependentOneOtherTwo(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7]) - - def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], - independent_chain_shape=[2, 3], - other_shape=[5, 7], - dtype=np.float64) - - def offset_normal_chains_should_fail_wrapper( - self, sample_shape, independent_chain_shape, other_shape): - """Check results with normal chains that are offset from each other.""" - - state_shape = sample_shape + independent_chain_shape + other_shape - state_ = rng.randn(*state_shape) - - # Add a significant offset to the different (formerly iid) chains. - offset = np.linspace( - 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len( - sample_shape) + independent_chain_shape + [1] * len(other_shape)) - state_ += offset - - self.check_results(state_, independent_chain_shape, should_pass=False) - - def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self): - self.offset_normal_chains_should_fail_wrapper( - sample_shape=[10000], independent_chain_shape=[2], other_shape=[5]) - - -class PotentialScaleReductionStaticTest(test.TestCase, - _PotentialScaleReductionTest): - - @property - def use_static_shape(self): - return True - - def testIndependentNdimsLessThanOneRaises(self): - with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"): - mcmc_diagnostics.potential_scale_reduction( - rng.rand(2, 3, 4), independent_chain_ndims=0) - - -class PotentialScaleReductionDynamicTest(test.TestCase, - _PotentialScaleReductionTest): - - @property - def use_static_shape(self): - return False - - -class _ReduceVarianceTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to impliment `use_static_shape`.") - - def check_versus_numpy(self, x_, axis, biased, keepdims): - with self.test_session(): - x_ = np.asarray(x_) - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - var = mcmc_diagnostics._reduce_variance( - x, axis=axis, biased=biased, keepdims=keepdims) - np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims) - - if self.use_static_shape: - self.assertAllEqual(np_var.shape, var.shape) - - var_ = var.eval() - # We will mask below, which changes shape, so check shape explicitly here. - self.assertAllEqual(np_var.shape, var_.shape) - - # We get NaN when we divide by zero due to the size being the same as ddof - nan_mask = np.isnan(np_var) - if nan_mask.any(): - self.assertTrue(np.isnan(var_[nan_mask]).all()) - self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02) - - def testScalarBiasedTrue(self): - self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False) - - def testScalarBiasedFalse(self): - # This should result in NaN. - self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False) - - def testShape2x3x4AxisNoneBiasedFalseKeepdimsFalse(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False) - - def testShape2x3x4Axis1BiasedFalseKeepdimsTrue(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True) - - def testShape2x3x4x5Axis13BiasedFalseKeepdimsTrue(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True) - - def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False) - - -class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest): - - @property - def use_static_shape(self): - return True - - -class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest): - - @property - def use_static_shape(self): - return False - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py deleted file mode 100644 index f508e5b114a55fc1aeb07212595fda45fc308c7b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Metropolis-Hastings.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh -from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -class MetropolisHastingsTest(test.TestCase): - - def testKernelStateTensor(self): - """Test that transition kernel works with tensor input to `state`.""" - loc = variable_scope.get_variable("loc", initializer=0.) - - def target_log_prob_fn(loc): - return normal_lib.Normal(loc=0.0, scale=0.1).log_prob(loc) - - new_state, _ = mh.kernel( - target_log_prob_fn=target_log_prob_fn, - proposal_fn=mh.proposal_normal(scale=0.05), - current_state=loc, - seed=231251) - loc_update = loc.assign(new_state) - - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - loc_samples = [] - for _ in range(2500): - loc_sample = sess.run(loc_update) - loc_samples.append(loc_sample) - loc_samples = loc_samples[500:] # drop samples for burn-in - - self.assertAllClose(np.mean(loc_samples), 0.0, rtol=1e-5, atol=1e-1) - self.assertAllClose(np.std(loc_samples), 0.1, rtol=1e-5, atol=1e-1) - - def testKernelStateList(self): - """Test that transition kernel works with list input to `state`.""" - num_chains = 2 - loc_one = variable_scope.get_variable( - "loc_one", [num_chains], - initializer=init_ops.zeros_initializer()) - loc_two = variable_scope.get_variable( - "loc_two", [num_chains], initializer=init_ops.zeros_initializer()) - - def target_log_prob_fn(loc_one, loc_two): - loc = array_ops.stack([loc_one, loc_two]) - log_prob = mvn_tril_lib.MultivariateNormalTriL( - loc=constant_op.constant([0., 0.]), - scale_tril=constant_op.constant([[0.1, 0.1], [0.0, 0.1]])).log_prob( - loc) - return math_ops.reduce_sum(log_prob, 0) - - def proposal_fn(loc_one, loc_two): - loc_one_proposal = mh.proposal_normal(scale=0.05) - loc_two_proposal = mh.proposal_normal(scale=0.05) - loc_one_sample, _ = loc_one_proposal(loc_one) - loc_two_sample, _ = loc_two_proposal(loc_two) - return [loc_one_sample, loc_two_sample], None - - new_state, _ = mh.kernel( - target_log_prob_fn=target_log_prob_fn, - proposal_fn=proposal_fn, - current_state=[loc_one, loc_two], - seed=12415) - loc_one_update = loc_one.assign(new_state[0]) - loc_two_update = loc_two.assign(new_state[1]) - - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - loc_one_samples = [] - loc_two_samples = [] - for _ in range(10000): - loc_one_sample, loc_two_sample = sess.run( - [loc_one_update, loc_two_update]) - loc_one_samples.append(loc_one_sample) - loc_two_samples.append(loc_two_sample) - - loc_one_samples = np.array(loc_one_samples) - loc_two_samples = np.array(loc_two_samples) - loc_one_samples = loc_one_samples[1000:] # drop samples for burn-in - loc_two_samples = loc_two_samples[1000:] # drop samples for burn-in - - self.assertAllClose(np.mean(loc_one_samples, 0), - np.array([0.] * num_chains), - rtol=1e-5, atol=1e-1) - self.assertAllClose(np.mean(loc_two_samples, 0), - np.array([0.] * num_chains), - rtol=1e-5, atol=1e-1) - self.assertAllClose(np.std(loc_one_samples, 0), - np.array([0.1] * num_chains), - rtol=1e-5, atol=1e-1) - self.assertAllClose(np.std(loc_two_samples, 0), - np.array([0.1] * num_chains), - rtol=1e-5, atol=1e-1) - - def testKernelResultsUsingTruncatedDistribution(self): - def log_prob(x): - return array_ops.where( - x >= 0., - -x - x**2, - array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) - # The truncated distribution has the property that it is likely to attract - # the flow toward, and below, zero...but for x <=0, - # log_prob(x) = -inf, which should result in rejection, as well - # as a non-finite log_prob. Thus, this distribution gives us an opportunity - # to test out the kernel results ability to correctly capture rejections due - # to finite AND non-finite reasons. - - num_results = 1000 - # Large step size, will give rejections due to going into a region of - # log_prob = -inf. - step_size = 0.3 - num_chains = 2 - - with self.test_session(graph=ops.Graph()) as sess: - - # Start multiple independent chains. - initial_state = ops.convert_to_tensor([0.1] * num_chains) - - states = [] - is_accepted = [] - proposed_states = [] - current_state = initial_state - for _ in range(num_results): - current_state, kernel_results = mh.kernel( - target_log_prob_fn=log_prob, - proposal_fn=mh.proposal_uniform(step_size=step_size), - current_state=current_state, - seed=42) - states.append(current_state) - proposed_states.append(kernel_results.proposed_state) - is_accepted.append(kernel_results.is_accepted) - - states = array_ops.stack(states) - proposed_states = array_ops.stack(proposed_states) - is_accepted = array_ops.stack(is_accepted) - states_, pstates_, is_accepted_ = sess.run( - [states, proposed_states, is_accepted]) - - # We better have accepted a decent amount, even near end of the chain. - self.assertLess( - 0.1, is_accepted_[int(0.9 * num_results):].mean()) - # We better not have any NaNs in states. - self.assertAllEqual(np.zeros_like(states_), - np.isnan(states_)) - # We better not have any +inf in states. - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(states_)) - - # The move is accepted ==> state = proposed state. - self.assertAllEqual( - states_[is_accepted_], - pstates_[is_accepted_], - ) - - # The move was rejected <==> state[t] == state[t - 1]. - for t in range(1, num_results): - for i in range(num_chains): - if is_accepted_[t, i]: - self.assertNotEqual(states_[t, i], states_[t - 1, i]) - else: - self.assertEqual(states_[t, i], states_[t - 1, i]) - - def testDensityIncreasingStepAccepted(self): - """Tests that if a transition increases density, it is always accepted.""" - target_log_density = lambda x: - x * x - state = variable_scope.get_variable("state", initializer=10.) - state_log_density = variable_scope.get_variable( - "state_log_density", - initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - "log_accept_ratio", initializer=0.) - - get_next_proposal = lambda x: (x - 1., None) - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, get_next_proposal, seed=1234) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - for j in range(9): - sess.run(step) - sample = sess.run(state) - sample_log_density = sess.run(state_log_density) - self.assertAlmostEqual(sample, 9 - j) - self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j)) - - def testSampleProperties(self): - """Tests that the samples converge to the target distribution.""" - - def target_log_density(x): - """Log-density corresponding to a normal distribution with mean = 4.""" - return - (x - 2.0) * (x - 2.0) * 0.5 - - # Use the uniform random walker to generate proposals. - proposal_fn = mh.proposal_uniform( - step_size=1.0, seed=1234) - - state = variable_scope.get_variable("state", initializer=0.0) - state_log_density = variable_scope.get_variable( - "state_log_density", - initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - "log_accept_ratio", initializer=0.) - - # Random walk MCMC converges slowly so need to put in enough iterations. - num_iterations = 5000 - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, proposal_fn, seed=4321) - - init = variables.global_variables_initializer() - - sample_sum, sample_sq_sum = 0.0, 0.0 - with self.test_session() as sess: - sess.run(init) - for _ in np.arange(num_iterations): - # Allow for the mixing of the chain and discard these samples. - sess.run(step) - for _ in np.arange(num_iterations): - sess.run(step) - sample = sess.run(state) - sample_sum += sample - sample_sq_sum += sample * sample - - sample_mean = sample_sum / num_iterations - sample_variance = sample_sq_sum / num_iterations - sample_mean * sample_mean - # The samples have large autocorrelation which reduces the effective sample - # size. - self.assertAlmostEqual(sample_mean, 2.0, delta=0.1) - self.assertAlmostEqual(sample_variance, 1.0, delta=0.1) - - def testProposalNormal(self): - """Tests that the normal proposals are correctly distributed.""" - - initial_points = array_ops.ones([10000], dtype=dtypes.float32) - proposal_fn = mh.proposal_normal( - scale=2.0, seed=1234) - proposal_points, _ = proposal_fn(initial_points) - - with self.test_session() as sess: - sample = sess.run(proposal_points) - - # It is expected that the elements in proposal_points have the same mean as - # initial_points and have the standard deviation that was supplied to the - # proposal scheme. - self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1) - self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1) - - def testDocstringExample(self): - """Tests the simplified docstring example with multiple chains.""" - - n = 2 # dimension of the problem - - # Generate 300 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = variable_scope.get_variable( - "state", initializer=random_ops.random_normal( - [300, n], mean=3.0, dtype=dtypes.float32, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - math_ops.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = variable_scope.get_variable( - "state_log_density", - initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = variable_scope.get_variable( - "log_acceptance_ratio", - initializer=array_ops.zeros([300], dtype=dtypes.float32)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + random_ops.random_uniform( - array_ops.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps. - for _ in range(10): - sess.run(stepper) - samples = sess.run(state) - covariance = np.eye(n) - # Verify that the estimated mean and covariance are close to the true - # values. - self.assertAlmostEqual( - np.max(np.abs(np.mean(samples, 0) - - np.zeros(n))), 0, - delta=0.1) - self.assertAlmostEqual( - np.max(np.abs(np.reshape(np.cov(samples, rowvar=False), [n**2]) - - np.reshape(covariance, [n**2]))), 0, - delta=0.2) - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py deleted file mode 100644 index 756c25683bd4b0c8c77e9e28485ca2a85582999c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional test for GradientDescent.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import math -from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class SGLDOptimizerTest(test.TestCase): - - def testBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.53 - sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - - def testBasicMultiInstance(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - vara = variables.Variable([1.1, 2.1], dtype=dtype) - varb = variables.Variable([3.0, 4.0], dtype=dtype) - gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) - gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.5 - sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - sgd_optimizer2 = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate) - sgd_op2 = sgd_optimizer2.apply_gradients( - zip([gradsa, gradsb], [vara, varb])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) - - # Run 1 step of sgd - sgd_op.run() - sgd_op2.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval()) - - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval()) - self.assertNotEqual(sgd_optimizer.variable_scope, - sgd_optimizer2.variable_scope) - self.assertNotEqual(sgd_optimizer.variable_scope.name, - sgd_optimizer2.variable_scope.name) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer2._counter.eval()) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = constant_op.constant(3.0) - decay_rate = 0.5 - sgd_op = SGLDOptimizer( - lrate, preconditioner_decay_rate=constant_op.constant( - decay_rate)).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - - def testGradWrtRef(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - opt = SGLDOptimizer(3.0) - values = [1.0, 3.0] - vars_ = [variables.Variable([v], dtype=dtype) for v in values] - grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) - variables.global_variables_initializer().run() - for grad, _ in grads_and_vars: - self.assertAllCloseAccordingToType([1.0], grad.eval()) - - def testWithGlobalStep(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - global_step = variables.Variable(0, trainable=False) - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.1 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - - # Validate updated params and global_step - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) - - def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) - var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) - grads0 = ops.IndexedSlices( - constant_op.constant([0.1], shape=[1, 1], dtype=dtype), - constant_op.constant([0]), constant_op.constant([2, 1])) - grads1 = ops.IndexedSlices( - constant_op.constant([0.01], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), constant_op.constant([2, 1])) - decay_rate = 0.9 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) - self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]], - var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py deleted file mode 100644 index f978cf86417dc5ff5412a3eee584330a266e0964..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for utility functions related to managing `tf.Variable`s.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import variable_utils - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.ops import variables as variables_ops -from tensorflow.python.platform import test - - -def test_fn(x): - x = ops.convert_to_tensor(x, name="x") - dtype = x.dtype.as_numpy_dtype - s = x.shape.as_list() - z = varscope_ops.get_variable( - name="z", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) - y = varscope_ops.get_variable( - name="y", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2) - return x + y + z - - -class _WrapCallableTest(object): - - def testDefaultArgsWorkCorrectly(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( - test_fn, [x]) - - varscope_ops.get_variable_scope().reuse_variables() - - result = wrapped_fn(self.dtype(2), [3, 4, 5], 0.5) - - y_actual = varscope_ops.get_variable("y", dtype=self.dtype) - z_actual = varscope_ops.get_variable("z", dtype=self.dtype) - - variables_ops.global_variables_initializer().run() - result_ = result.eval() - - self.assertEqual(self.dtype, result_.dtype) - self.assertAllEqual([5.5, 6.5, 7.5], result_) - self.assertAllEqual([y_actual, z_actual], vars_args) - - def testNonDefaultArgsWorkCorrectly(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - - _ = test_fn(self.dtype([0., 0.])) # Needed to create vars. - varscope_ops.get_variable_scope().reuse_variables() - - y_actual = varscope_ops.get_variable("y", dtype=self.dtype) - - wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( - test_fn, [x], possible_ancestor_vars=[y_actual]) - - result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y - - variables_ops.global_variables_initializer().run() - result_ = result.eval() - - self.assertEqual(self.dtype, result_.dtype) - self.assertAllEqual([2.5, 4.5], result_) - self.assertAllEqual([y_actual], vars_args) - - def testWarnings(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, _ = variable_utils.externalize_variables_as_args( - test_fn, [x], possible_ancestor_vars=[]) - varscope_ops.get_variable_scope().reuse_variables() - with warnings.catch_warnings(record=True) as w: - wrapped_fn(self.dtype(2)) - w = sorted(w, key=lambda w: str(w.message)) - self.assertEqual(2, len(w)) - self.assertRegexpMatches( - str(w[0].message), - r"Variable .* 'y:0' .* not found in bypass dict.") - self.assertRegexpMatches( - str(w[1].message), - r"Variable .* 'z:0' .* not found in bypass dict.") - - def testExceptions(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, _ = variable_utils.externalize_variables_as_args( - test_fn, - [x], - possible_ancestor_vars=[], - assert_variable_override=True) - varscope_ops.get_variable_scope().reuse_variables() - with self.assertRaisesRegexp(ValueError, r"not found"): - wrapped_fn(self.dtype(2)) - - -class WrapCallableTest16(test.TestCase, _WrapCallableTest): - dtype = np.float16 - - -class WrapCallableTest32(test.TestCase, _WrapCallableTest): - dtype = np.float32 - - -class WrapCallableTest64(test.TestCase, _WrapCallableTest): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py deleted file mode 100644 index 83c64dbe0fd586edcb784a5c09a4c133aaa99cff..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional test for GradientDescent.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.contrib.bayesflow.python.ops.optimizers import VariationalSGDOptimizer -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class VariationalSGDOptimizerTest(test.TestCase): - - def testBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.53 - sgd_op = VariationalSGDOptimizer( - 1, - 1, - preconditioner_decay_rate=decay_rate, - max_learning_rate=3.0, - burnin_max_learning_rate=3.0, - use_single_learning_rate=True).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - - def testBasicMultiInstance(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - vara = variables.Variable([1.1, 2.1], dtype=dtype) - varb = variables.Variable([3.0, 4.0], dtype=dtype) - gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) - gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=1.0, - burnin_max_learning_rate=3.0, - preconditioner_decay_rate=decay_rate) - sgd_op = optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - optimizer2 = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=1.0, - burnin_max_learning_rate=10.0, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op2 = optimizer2.apply_gradients( - zip([gradsa, gradsb], [vara, varb])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) - - # Run 1 step of sgd - sgd_op.run() - sgd_op2.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3. * 0.1, 2.1 - 3. * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([1.1 - 0.1, 2.1 - 0.1], vara.eval()) - - self.assertAllCloseAccordingToType([3.0 - 3. * 0.01, 4.0 - 3. * 0.01], - var1.eval()) - self.assertAllCloseAccordingToType([3.0 - 0.01, 4.0 - 0.01], - varb.eval()) - self.assertNotEqual(optimizer.variable_scope, - optimizer2.variable_scope) - self.assertNotEqual(optimizer.variable_scope.name, - optimizer2.variable_scope.name) - self.assertAllCloseAccordingToType(1, optimizer._counter.eval()) - self.assertAllCloseAccordingToType(1, optimizer2._counter.eval()) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = constant_op.constant(3.0) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - sgd_op = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=lrate, - burnin=0, - preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - - def testTensorDecayLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = variables.Variable(3.0) - lrate_decay_op = lrate.assign_add(-3.) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=lrate, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - # Update learning rate to 0 - lrate_decay_op.eval() - sgd_op.run() - # Validate params haven't changed - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - lrate_decay_op.eval() - - with self.assertRaises(errors.InvalidArgumentError): - sgd_op.run() - - def testGradWrtRef(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - opt = VariationalSGDOptimizer(1, 1, max_learning_rate=1.0) - values = [1.0, 3.0] - vars_ = [variables.Variable([v], dtype=dtype) for v in values] - grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) - variables.global_variables_initializer().run() - for grad, _ in grads_and_vars: - self.assertAllCloseAccordingToType([1.0], grad.eval()) - - def testWithGlobalStep(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - global_step = variables.Variable(0, trainable=False) - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.1 - batch_size = 2 - total_num_examples = 10 - sgd_optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=3.0, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - - # Validate updated params and global_step - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - - def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) - var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) - grads0 = ops.IndexedSlices( - constant_op.constant([0.1], shape=[1, 1], dtype=dtype), - constant_op.constant([0]), constant_op.constant([2, 1])) - grads1 = ops.IndexedSlices( - constant_op.constant([0.01], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), constant_op.constant([2, 1])) - decay_rate = 0.1 - batch_size = 2 - total_num_examples = 10 - sgd_op = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=3.0, - burnin=0, - preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) - self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([[1.1 - 3.0 * 0.1], [2.1]], - var0.eval()) - self.assertAllCloseAccordingToType( - [[3.0 - 3.0 * 0], [4.0 - 3.0 * 0.01]], var1.eval()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py deleted file mode 100644 index 9f7a95f138f7fd3e726f095dc16f41abb6182e17..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Csiszar f-Divergence and helpers. - -See ${python/contrib.bayesflow.csiszar_divergence}. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.csiszar_divergence_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'amari_alpha', - 'arithmetic_geometric', - 'chi_square', - 'csiszar_vimco', - 'dual_csiszar_function', - 'jeffreys', - 'jensen_shannon', - 'kl_forward', - 'kl_reverse', - 'log1p_abs', - 'modified_gan', - 'monte_carlo_csiszar_f_divergence', - 'pearson', - 'squared_hellinger', - 'symmetrized_csiszar_function', - 'total_variation', - 't_power', - 'triangular', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py deleted file mode 100644 index 8efd59d6516924bea538717d45bb4ae303583421..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py +++ /dev/null @@ -1,1105 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Csiszar f-Divergence and helpers. - -@@amari_alpha -@@arithmetic_geometric -@@chi_square -@@csiszar_vimco -@@dual_csiszar_function -@@jeffreys -@@jensen_shannon -@@kl_forward -@@kl_reverse -@@log1p_abs -@@modified_gan -@@monte_carlo_csiszar_f_divergence -@@pearson -@@squared_hellinger -@@symmetrized_csiszar_function -@@total_variation -@@triangular - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util - - -def amari_alpha(logu, alpha=1., self_normalized=False, name=None): - """The Amari-alpha Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Amari-alpha Csiszar-function is: - - ```none - f(u) = { -log(u) + (u - 1), alpha = 0 - { u log(u) - (u - 1), alpha = 1 - { [(u**alpha - 1) - alpha (u - 1)] / (alpha (alpha - 1)), otherwise - ``` - - When `self_normalized = False` the `(u - 1)` terms are omitted. - - Warning: when `alpha != 0` and/or `self_normalized = True` this function makes - non-log-space calculations and may therefore be numerically unstable for - `|logu| >> 0`. - - For more information, see: - A. Cichocki and S. Amari. "Families of Alpha-Beta-and GammaDivergences: - Flexible and Robust Measures of Similarities." Entropy, vol. 12, no. 6, pp. - 1532-1568, 2010. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - alpha: `float`-like Python scalar. (See Mathematical Details for meaning.) - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - amari_alpha_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - - Raises: - TypeError: if `alpha` is `None` or a `Tensor`. - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - with ops.name_scope(name, "amari_alpha", [logu]): - if alpha is None or contrib_framework.is_tensor(alpha): - raise TypeError("`alpha` cannot be `None` or `Tensor` type.") - if self_normalized is None or contrib_framework.is_tensor(self_normalized): - raise TypeError("`self_normalized` cannot be `None` or `Tensor` type.") - - logu = ops.convert_to_tensor(logu, name="logu") - - if alpha == 0.: - f = -logu - elif alpha == 1.: - f = math_ops.exp(logu) * logu - else: - f = math_ops.expm1(alpha * logu) / (alpha * (alpha - 1.)) - - if not self_normalized: - return f - - if alpha == 0.: - return f + math_ops.expm1(logu) - elif alpha == 1.: - return f - math_ops.expm1(logu) - else: - return f - math_ops.expm1(logu) / (alpha - 1.) - - -def kl_reverse(logu, self_normalized=False, name=None): - """The reverse Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-reverse Csiszar-function is: - - ```none - f(u) = -log(u) + (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[q, p] - ``` - - The KL is "reverse" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: when self_normalized = True` this function makes non-log-space - calculations and may therefore be numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_reverse_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_reverse", [logu]): - return amari_alpha(logu, alpha=0., self_normalized=self_normalized) - - -def kl_forward(logu, self_normalized=False, name=None): - """The forward Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-forward Csiszar-function is: - - ```none - f(u) = u log(u) - (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, q] - ``` - - The KL is "forward" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_forward_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_forward", [logu]): - return amari_alpha(logu, alpha=1., self_normalized=self_normalized) - - -def jensen_shannon(logu, self_normalized=False, name=None): - """The Jensen-Shannon Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Jensen-Shannon Csiszar-function is: - - ```none - f(u) = u log(u) - (1 + u) log(1 + u) + (u + 1) log(2) - ``` - - When `self_normalized = False` the `(u + 1) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, m] + KL[q, m] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Arithmetic-Geometric - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - For more information, see: - Lin, J. "Divergence measures based on the Shannon entropy." IEEE Trans. - Inf. Th., 37, 145-151, 1991. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jensen_shannon_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jensen_shannon", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - npdt = logu.dtype.as_numpy_dtype - y = nn_ops.softplus(logu) - if self_normalized: - y -= np.log(2).astype(npdt) - return math_ops.exp(logu) * logu - (1. + math_ops.exp(logu)) * y - - -def arithmetic_geometric(logu, self_normalized=False, name=None): - """The Arithmetic-Geometric Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the Arithmetic-Geometric Csiszar-function is: - - ```none - f(u) = (1 + u) log( (1 + u) / sqrt(u) ) - (1 + u) log(2) - ``` - - When `self_normalized = False` the `(1 + u) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[m, p] + KL[m, q] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Jensen-Shannon - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - arithmetic_geometric_of_u: `float`-like `Tensor` of the - Csiszar-function evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "arithmetic_geometric", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - 0.5 * logu - if self_normalized: - y -= np.log(2.).astype(logu.dtype.as_numpy_dtype) - return (1. + math_ops.exp(logu)) * y - - -def total_variation(logu, name=None): - """The Total Variation Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Total-Variation Csiszar-function is: - - ```none - f(u) = 0.5 |u - 1| - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - total_variation_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "total_variation", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.abs(math_ops.expm1(logu)) - - -def pearson(logu, name=None): - """The Pearson Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Pearson Csiszar-function is: - - ```none - f(u) = (u - 1)**2 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - pearson_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - """ - - with ops.name_scope(name, "pearson", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.square(math_ops.expm1(logu)) - - -def squared_hellinger(logu, name=None): - """The Squared-Hellinger Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Squared-Hellinger Csiszar-function is: - - ```none - f(u) = (sqrt(u) - 1)**2 - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - squared_hellinger_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "squared_hellinger", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(0.5 * logu) - - -def triangular(logu, name=None): - """The Triangular Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Triangular Csiszar-function is: - - ```none - f(u) = (u - 1)**2 / (1 + u) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - triangular_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "triangular", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(logu) / (1. + math_ops.exp(logu)) - - -def t_power(logu, t, self_normalized=False, name=None): - """The T-Power Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the T-Power Csiszar-function is: - - ```none - f(u) = s [ u**t - 1 - t(u - 1) ] - s = { -1 0 < t < 1 - { +1 otherwise - ``` - - When `self_normalized = False` the `- t(u - 1)` term is omitted. - - This is similar to the `amari_alpha` Csiszar-function, with the associated - divergence being the same up to factors depending only on `t`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - t: `Tensor` of same `dtype` as `logu` and broadcastable shape. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - t_power_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - with ops.name_scope(name, "t_power", [logu, t]): - logu = ops.convert_to_tensor(logu, name="logu") - t = ops.convert_to_tensor(t, dtype=logu.dtype.base_dtype, name="t") - fu = math_ops.expm1(t * logu) - if self_normalized: - fu -= t * math_ops.expm1(logu) - fu *= array_ops.where(math_ops.logical_and(0. < t, t < 1.), - -array_ops.ones_like(t), - array_ops.ones_like(t)) - return fu - - -def log1p_abs(logu, name=None): - """The log1p-abs Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Log1p-Abs Csiszar-function is: - - ```none - f(u) = u**(sign(u-1)) - 1 - ``` - - This function is so-named because it was invented from the following recipe. - Choose a convex function g such that g(0)=0 and solve for f: - - ```none - log(1 + f(u)) = g(log(u)). - <=> - f(u) = exp(g(log(u))) - 1 - ``` - - That is, the graph is identically `g` when y-axis is `log1p`-domain and x-axis - is `log`-domain. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log1p_abs_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "log1p_abs", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(math_ops.abs(logu)) - - -def jeffreys(logu, name=None): - """The Jeffreys Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Jeffreys Csiszar-function is: - - ```none - f(u) = 0.5 ( u log(u) - log(u) ) - = 0.5 kl_forward + 0.5 kl_reverse - = symmetrized_csiszar_function(kl_reverse) - = symmetrized_csiszar_function(kl_forward) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jeffreys_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jeffreys", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.expm1(logu) * logu - - -def chi_square(logu, name=None): - """The chi-Square Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Chi-square Csiszar-function is: - - ```none - f(u) = u**2 - 1 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(2. * logu) - - -def modified_gan(logu, self_normalized=False, name=None): - """The Modified-GAN Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the modified-GAN (Generative/Adversarial - Network) Csiszar-function is: - - ```none - f(u) = log(1 + u) - log(u) + 0.5 (u - 1) - ``` - - When `self_normalized = False` the `0.5 (u - 1)` is omitted. - - The unmodified GAN Csiszar-function is identical to Jensen-Shannon (with - `self_normalized = False`). - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - logu - if self_normalized: - y += 0.5 * math_ops.expm1(logu) - return y - - -def dual_csiszar_function(logu, csiszar_function, name=None): - """Calculates the dual Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar-dual is defined as: - - ```none - f^*(u) = u f(1 / u) - ``` - - where `f` is some other Csiszar-function. - - For example, the dual of `kl_reverse` is `kl_forward`, i.e., - - ```none - f(u) = -log(u) - f^*(u) = u f(1 / u) = -u log(1 / u) = u log(u) - ``` - - The dual of the dual is the original function: - - ```none - f^**(u) = {u f(1/u)}^*(u) = u (1/u) f(1/(1/u)) = f(u) - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - dual_f_of_u: `float`-like `Tensor` of the result of calculating the dual of - `f` at `u = exp(logu)`. - """ - - with ops.name_scope(name, "dual_csiszar_function", [logu]): - return math_ops.exp(logu) * csiszar_function(-logu) - - -def symmetrized_csiszar_function(logu, csiszar_function, name=None): - """Symmetrizes a Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The symmetrized Csiszar-function is defined as: - - ```none - f_g(u) = 0.5 g(u) + 0.5 u g (1 / u) - ``` - - where `g` is some other Csiszar-function. - - We say the function is "symmetrized" because: - - ```none - D_{f_g}[p, q] = D_{f_g}[q, p] - ``` - - for all `p << >> q` (i.e., `support(p) = support(q)`). - - There exists alternatives for symmetrizing a Csiszar-function. For example, - - ```none - f_g(u) = max(f(u), f^*(u)), - ``` - - where `f^*` is the dual Csiszar-function, also implies a symmetric - f-Divergence. - - Example: - - When either of the following functions are symmetrized, we obtain the - Jensen-Shannon Csiszar-function, i.e., - - ```none - g(u) = -log(u) - (1 + u) log((1 + u) / 2) + u - 1 - h(u) = log(4) + 2 u log(u / (1 + u)) - ``` - - implies, - - ```none - f_g(u) = f_h(u) = u log(u) - (1 + u) log((1 + u) / 2) - = jensen_shannon(log(u)). - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - symmetrized_g_of_u: `float`-like `Tensor` of the result of applying the - symmetrization of `g` evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "symmetrized_csiszar_function", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * (csiszar_function(logu) - + dual_csiszar_function(logu, csiszar_function)) - - -def monte_carlo_csiszar_f_divergence( - f, - p_log_prob, - q, - num_draws, - use_reparametrization=None, - seed=None, - name=None): - """Monte-Carlo approximation of the Csiszar f-Divergence. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar f-Divergence for Csiszar-function f is given by: - - ```none - D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ] - ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ), - where x_j ~iid q(X) - ``` - - Tricks: Reparameterization and Score-Gradient - - When q is "reparameterized", i.e., a diffeomorphic transformation of a - parameterless distribution (e.g., - `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and - expectation, i.e., - `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}` - and `s_i = f(x_i), x_i ~iid q(X)`. - - However, if q is not reparameterized, TensorFlow's gradient will be incorrect - since the chain-rule stops at samples of unreparameterized distributions. In - this circumstance using the Score-Gradient trick results in an unbiased - gradient, i.e., - - ```none - grad[ E_q[f(X)] ] - = grad[ int dx q(x) f(x) ] - = int dx grad[ q(x) f(x) ] - = int dx [ q'(x) f(x) + q(x) f'(x) ] - = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ] - = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ] - = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ] - ``` - - Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is - usually preferable to set `use_reparametrization = True`. - - Example Application: - - The Csiszar f-Divergence is a useful framework for variational inference. - I.e., observe that, - - ```none - f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] ) - <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ] - := D_f[p(x, Z), q(Z | x)] - ``` - - The inequality follows from the fact that the "perspective" of `f`, i.e., - `(s, t) |-> t f(s / t))`, is convex in `(s, t)` when `s/t in domain(f)` and - `t` is a real. Since the above framework includes the popular Evidence Lower - BOund (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this framework - "Evidence Divergence Bound Optimization" (EDBO). - - Args: - f: Python `callable` representing a Csiszar-function in log-space, i.e., - takes `p_log_prob(q_samples) - q.log_prob(q_samples)`. - p_log_prob: Python `callable` taking (a batch of) samples from `q` and - returning the natural-log of the probability under distribution `p`. - (In variational inference `p` is the joint distribution.) - q: `tf.Distribution`-like instance; must implement: - `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`. - (In variational inference `q` is the approximate posterior distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - use_reparametrization: Python `bool`. When `None` (the default), - automatically set to: - `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`. - When `True` uses the standard Monte-Carlo average. When `False` uses the - score-gradient trick. (See above for details.) When `False`, consider - using `csiszar_vimco`. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - monte_carlo_csiszar_f_divergence: `float`-like `Tensor` Monte Carlo - approximation of the Csiszar f-Divergence. - - Raises: - ValueError: if `q` is not a reparameterized distribution and - `use_reparametrization = True`. A distribution `q` is said to be - "reparameterized" when its samples are generated by transforming the - samples of another distribution which does not depend on the - parameterization of `q`. This property ensures the gradient (with respect - to parameters) is valid. - TypeError: if `p_log_prob` is not a Python `callable`. - """ - with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]): - if use_reparametrization is None: - use_reparametrization = (q.reparameterization_type - == distribution.FULLY_REPARAMETERIZED) - elif (use_reparametrization and - q.reparameterization_type != distribution.FULLY_REPARAMETERIZED): - # TODO(jvdillon): Consider only raising an exception if the gradient is - # requested. - raise ValueError( - "Distribution `q` must be reparameterized, i.e., a diffeomorphic " - "transformation of a parameterless distribution. (Otherwise this " - "function has a biased gradient.)") - if not callable(p_log_prob): - raise TypeError("`p_log_prob` must be a Python `callable` function.") - return monte_carlo.expectation( - f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)), - samples=q.sample(num_draws, seed=seed), - log_prob=q.log_prob, # Only used if use_reparametrization=False. - use_reparametrization=use_reparametrization) - - -def csiszar_vimco(f, - p_log_prob, - q, - num_draws, - num_batch_draws=1, - seed=None, - name=None): - """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))]. - - This function generalizes "Variational Inference for Monte Carlo Objectives" - (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences. - - Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`, - consider using `monte_carlo_csiszar_f_divergence`. - - The VIMCO loss is: - - ```none - vimco = f(Avg{logu[i] : i=0,...,m-1}) - where, - logu[i] = log( p(x, h[i]) / q(h[i] | x) ) - h[i] iid~ q(H | x) - ``` - - Interestingly, the VIMCO gradient is not the naive gradient of `vimco`. - Rather, it is characterized by: - - ```none - grad[vimco] - variance_reducing_term - where, - variance_reducing_term = Sum{ grad[log q(h[i] | x)] * - (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) - : i=0, ..., m-1 } - h[j;i] = { u[j] j!=i - { GeometricAverage{ u[k] : k!=i} j==i - ``` - - (We omitted `stop_gradient` for brevity. See implementation for more details.) - - The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th - element has been replaced by the leave-`i`-out Geometric-average. - - This implementation prefers numerical precision over efficiency, i.e., - `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`. - (The constant may be fairly large, perhaps around 12.) - - Args: - f: Python `callable` representing a Csiszar-function in log-space. - p_log_prob: Python `callable` representing the natural-log of the - probability under distribution `p`. (In variational inference `p` is the - joint distribution.) - q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and - `log_prob(x)`. (In variational inference `q` is the approximate posterior - distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - num_batch_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - vimco: The Csiszar f-Divergence generalized VIMCO objective. - - Raises: - ValueError: if `num_draws < 2`. - """ - with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]): - if num_draws < 2: - raise ValueError("Must specify num_draws > 1.") - stop = array_ops.stop_gradient # For readability. - x = stop(q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed)) - logqx = q.log_prob(x) - logu = p_log_prob(x) - logqx - f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)] - dotprod = math_ops.reduce_sum( - logqx * stop(f_log_avg_u - f_log_sooavg_u), - axis=0) # Sum over iid samples. - # We now rewrite f_log_avg_u so that: - # `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`. - # To achieve this, we use a trick that - # `f(x) - stop(f(x)) == zeros_like(f(x))` - # but its gradient is grad[f(x)]. - # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence - # this trick loses no precision. For more discussion regarding the relevant - # portions of the IEEE754 standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - f_log_avg_u += dotprod - stop(dotprod) # Add zeros_like(dot_prod). - return math_ops.reduce_mean(f_log_avg_u, axis=0) # Avg over batches. - - -def csiszar_vimco_helper(logu, name=None): - """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`. - - `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e., - - ```none - logu[j] = log(u[j]) - u[j] = p(x, h[j]) / q(h[j] | x) - h[j] iid~ q(H | x) - ``` - - Args: - logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the - average of `u`. The sum of the gradient of `log_avg_u` is `1`. - log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the - average of `u`` except that the average swaps-out `u[i]` for the - leave-`i`-out Geometric-average. The mean of the gradient of - `log_sooavg_u` is `1`. Mathematically `log_sooavg_u` is, - ```none - log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1}) - h[j ; i] = { u[j] j!=i - { GeometricAverage{u[k] : k != i} j==i - ``` - - """ - with ops.name_scope(name, "csiszar_vimco_helper", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - - n = logu.shape.with_rank_at_least(1)[0].value - if n is None: - n = array_ops.shape(logu)[0] - log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype)) - nm1 = math_ops.cast(n - 1, dtype=logu.dtype) - else: - log_n = np.log(n).astype(logu.dtype.as_numpy_dtype) - nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype) - - # Throughout we reduce across axis=0 since this is presumed to be iid - # samples. - - log_max_u = math_ops.reduce_max(logu, axis=0) - log_sum_u_minus_log_max_u = math_ops.reduce_logsumexp( - logu - log_max_u, axis=0) - - # log_loosum_u[i] = - # = logsumexp(logu[j] : j != i) - # = log( exp(logsumexp(logu)) - exp(logu[i]) ) - # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i])) - # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1) - # = logu[i] + log(exp(logsumexp(logu) - logu[i]) - 1) - # = logu[i] + softplus_inverse(logsumexp(logu) - logu[i]) - d = log_sum_u_minus_log_max_u + (log_max_u - logu) - # We use `d != 0` rather than `d > 0.` because `d < 0.` should never - # happens; if it does we want to complain loudly (which `softplus_inverse` - # will). - d_ok = math_ops.not_equal(d, 0.) - safe_d = array_ops.where(d_ok, d, array_ops.ones_like(d)) - d_ok_result = logu + distribution_util.softplus_inverse(safe_d) - - inf = np.array(np.inf, dtype=logu.dtype.as_numpy_dtype) - - # When not(d_ok) and is_positive_and_largest then we manually compute the - # log_loosum_u. (We can efficiently do this for any one point but not all, - # hence we still need the above calculation.) This is good because when - # this condition is met, we cannot use the above calculation; its -inf. - # We now compute the log-leave-out-max-sum, replicate it to every - # point and make sure to select it only when we need to. - is_positive_and_largest = math_ops.logical_and( - logu > 0., - math_ops.equal(logu, log_max_u[array_ops.newaxis, ...])) - log_lomsum_u = math_ops.reduce_logsumexp( - array_ops.where(is_positive_and_largest, - array_ops.fill(array_ops.shape(logu), -inf), - logu), - axis=0, keep_dims=True) - log_lomsum_u = array_ops.tile( - log_lomsum_u, - multiples=1 + array_ops.pad([n-1], [[0, array_ops.rank(logu)-1]])) - - d_not_ok_result = array_ops.where( - is_positive_and_largest, - log_lomsum_u, - array_ops.fill(array_ops.shape(d), -inf)) - - log_loosum_u = array_ops.where(d_ok, d_ok_result, d_not_ok_result) - - # The swap-one-out-sum ("soosum") is n different sums, each of which - # replaces the i-th item with the i-th-left-out average, i.e., - # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i])) - # = exp(log_loosum_u[i]) + exp(looavg_logu[i]) - looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1 - log_soosum_u = math_ops.reduce_logsumexp( - array_ops.stack([log_loosum_u, looavg_logu]), - axis=0) - - log_avg_u = log_sum_u_minus_log_max_u + log_max_u - log_n - log_sooavg_u = log_soosum_u - log_n - - log_avg_u.set_shape(logu.shape.with_rank_at_least(1)[1:]) - log_sooavg_u.set_shape(logu.shape) - - return log_avg_u, log_sooavg_u diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py deleted file mode 100644 index d44fe6529a7ff0da0c6747e193fdb98a272a8da3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions for specifying custom gradients. - -@@custom_gradient - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - -__all__ = [ - "custom_gradient", -] - - -def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, - name=None): - """Enables specifying a custom gradient. - - This function works by clever application of `stop_gradient`. I.e., observe - that: - - ```none - h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) - ``` - - is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = - stop_gradient(g(x)).` - - In addition to scalar-domain/scalar-range functions, this function also - supports tensor-domain/scalar-range functions. However, in the latter case it - is necessary to reduce `x` to a scalar. This can be done by indicating the - `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior - to calling this function. - - Partial Custom Gradient: - - Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy = - None`. This is because a `Tensor` cannot have only a portion of its gradient - stopped. To circumvent this issue, one must manually `stop_gradient` the - relevant portions of `f`, `g`. For example see the unit-test, - `test_works_correctly_fx_gx_manually_stopped`. - - Args: - fx: `Tensor`. Output of function evaluated at `x`. - gx: `Tensor`. Gradient of function evaluated at `x`. - x: `Tensor`. Point of evaluation for `f, g`. - axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain - of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range. - If `None` `f` is assumed to render one scalar given all of `x`. Otherwise - `f` is assumed to output one scalar for each of `axis` dimensions of `x`. - fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually - have `stop_gradient` applied. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - fx: Floating-type `Tensor` equal to `f(x)` but which has gradient - `stop_gradient(g(x))`. - """ - with ops.name_scope(name, "custom_gradient", [fx, gx, x]): - fx = ops.convert_to_tensor(fx, name="fx") - # We don't want to bother eagerly computing `gx` since we may not even need - # it. - with ops.control_dependencies([fx]): - gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx") - gx = array_ops.identity(gx, name="gx") - # Proof of correctness: - # - # f(x) = x * stop[gx] + stop[fx - x * gx] - # = stop[fx] - # - # g(x) = grad[fx] - # = stop[gx] + grad[stop[fx - x * gx]] - # = stop[gx] + 0 - # - # Notice that when x is zero it still works: - # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx] - # - # The proof is similar for the tensor-domain case, except that `x` is - # replaced by `reduce_sum(x)`. - sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x") - if not fx_gx_manually_stopped: - fx = array_ops.stop_gradient(fx) - gx = array_ops.stop_gradient(gx) - # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write - # the code this way, rather than, e.g., - # `sum_x * stop(gx) + stop(fx - sum_x * gx)`. - # For more discussion regarding the relevant portions of the IEEE754 - # standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx diff --git a/tensorflow/contrib/bayesflow/python/ops/docstring_util.py b/tensorflow/contrib/bayesflow/python/ops/docstring_util.py deleted file mode 100644 index 081f2d5a8bfd437fd173f63b4226fb7df6ca921c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/docstring_util.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for programmable docstrings. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re -import six - - -def expand_docstring(**kwargs): - """Decorator to programmatically expand the docstring. - - Args: - **kwargs: Keyword arguments to set. For each key-value pair `k` and `v`, - the key is found as `@{k}` in the docstring and replaced with `v`. - - Returns: - Decorated function. - """ - def _fn_wrapped(fn): - """Original function with modified `__doc__` attribute.""" - doc = _trim(fn.__doc__) - for k, v in six.iteritems(kwargs): - # Capture each @{k} reference to replace with v. - # We wrap the replacement in a function so no backslash escapes - # are processed. - pattern = r'@\{' + str(k) + r'\}' - doc = re.sub(pattern, lambda match: v, doc) # pylint: disable=cell-var-from-loop - fn.__doc__ = doc - return fn - return _fn_wrapped - - -def _trim(docstring): - """Trims docstring indentation. - - In general, multi-line docstrings carry their level of indentation when - defined under a function or class method. This function standardizes - indentation levels by removing them. Taken from PEP 257 docs. - - Args: - docstring: Python string to trim indentation. - - Returns: - Trimmed docstring. - """ - if not docstring: - return '' - # Convert tabs to spaces (following the normal Python rules) - # and split into a list of lines: - lines = docstring.expandtabs().splitlines() - # Determine minimum indentation (first line doesn't count): - indent = None - for line in lines[1:]: - stripped = line.lstrip() - if stripped: - if indent is None: - indent = len(line) - len(stripped) - else: - indent = min(indent, len(line) - len(stripped)) - # Remove indentation (first line is special): - trimmed = [lines[0].strip()] - if indent is not None: - for line in lines[1:]: - trimmed.append(line[indent:].rstrip()) - # Strip off trailing and leading blank lines: - while trimmed and not trimmed[-1]: - trimmed.pop() - while trimmed and not trimmed[0]: - trimmed.pop(0) - # Return a single string: - return '\n'.join(trimmed) diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py deleted file mode 100644 index 49d747d538f5a4aa3134d28ba00a651cb509fa41..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Support for low discrepancy Halton sequences. - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.halton_sequence_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'sample', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py deleted file mode 100644 index 8cabf18903b5f15002470acdfb8fdd3ec31a7413..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Quasi Monte Carlo support: Halton sequence. - -@@sample -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - - -__all__ = [ - 'sample', -] - - -# The maximum dimension we support. This is limited by the number of primes -# in the _PRIMES array. -_MAX_DIMENSION = 1000 - - -def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): - r"""Returns a sample from the `m` dimensional Halton sequence. - - Warning: The sequence elements take values only between 0 and 1. Care must be - taken to appropriately transform the domain of a function if it differs from - the unit cube before evaluating integrals using Halton samples. It is also - important to remember that quasi-random numbers are not a replacement for - pseudo-random numbers in every context. Quasi random numbers are completely - deterministic and typically have significant negative autocorrelation (unless - randomized). - - Computes the members of the low discrepancy Halton sequence in dimension - `dim`. The d-dimensional sequence takes values in the unit hypercube in d - dimensions. Currently, only dimensions up to 1000 are supported. The prime - base for the `k`-th axes is the k-th prime starting from 2. For example, - if dim = 3, then the bases will be [2, 3, 5] respectively and the first - element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete - description of the Halton sequences see: - https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences - and their applications see: - https://en.wikipedia.org/wiki/Low-discrepancy_sequence. - - The user must supply either `num_samples` or `sample_indices` but not both. - The former is the number of samples to produce starting from the first - element. If `sample_indices` is given instead, the specified elements of - the sequence are generated. For example, sample_indices=tf.range(10) is - equivalent to specifying n=10. - - Example Use: - - ```python - bf = tf.contrib.bayesflow - - # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 - dim = 3 - sample = bf.halton_sequence.sample(dim, num_samples=num_samples) - - # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional - # hypercube. - powers = tf.range(1.0, limit=dim + 1) - integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1)) - true_value = 1.0 / tf.reduce_prod(powers + 1.0) - with tf.Session() as session: - values = session.run((integral, true_value)) - - # Produces a relative absolute error of 1.7%. - print ("Estimated: %f, True Value: %f" % values) - - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. - - - sample_indices = tf.range(start=1000, limit=1000 + num_samples, - dtype=tf.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) - - integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers, - axis=-1)) - with tf.Session() as session: - values = session.run((integral_leaped, true_value)) - # Now produces a relative absolute error of 0.05%. - print ("Leaped Estimated: %f, True Value: %f" % values) - ``` - - Args: - dim: Positive Python `int` representing each sample's `event_size.` Must - not be greater than 1000. - num_samples: (Optional) positive Python `int`. The number of samples to - generate. Either this parameter or sample_indices must be specified but - not both. If this parameter is None, then the behaviour is determined by - the `sample_indices`. - sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements - of the sequence to compute specified by their position in the sequence. - The entries index into the Halton sequence starting with 0 and hence, - must be whole numbers. For example, sample_indices=[0, 5, 6] will produce - the first, sixth and seventh elements of the sequence. If this parameter - is None, then the `num_samples` parameter must be specified which gives - the number of desired samples starting from the first sample. - dtype: (Optional) The dtype of the sample. One of `float32` or `float64`. - Default is `float32`. - name: (Optional) Python `str` describing ops managed by this function. If - not supplied the name of this function is used. - - Returns: - halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype - and `shape` `[num_samples, dim]` if `num_samples` was specified or shape - `[s, dim]` where s is the size of `sample_indices` if `sample_indices` - were specified. - - Raises: - ValueError: if both `sample_indices` and `num_samples` were specified or - if dimension `dim` is less than 1 or greater than 1000. - """ - if dim < 1 or dim > _MAX_DIMENSION: - raise ValueError( - 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION, - dim)) - if (num_samples is None) == (sample_indices is None): - raise ValueError('Either `num_samples` or `sample_indices` must be' - ' specified but not both.') - - dtype = dtype or dtypes.float32 - if not dtype.is_floating: - raise ValueError('dtype must be of `float`-type') - - with ops.name_scope(name, 'sample', values=[sample_indices]): - # Here and in the following, the shape layout is as follows: - # [sample dimension, event dimension, coefficient dimension]. - # The coefficient dimension is an intermediate axes which will hold the - # weights of the starting integer when expressed in the (prime) base for - # an event dimension. - indices = _get_indices(num_samples, sample_indices, dtype) - radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) - - max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices), - radixes) - - max_size = math_ops.reduce_max(max_sizes_by_axes) - - # The powers of the radixes that we will need. Note that there is a bit - # of an excess here. Suppose we need the place value coefficients of 7 - # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits - # for base 3. However, we can only create rectangular tensors so we - # store both expansions in a [2, 3] tensor. This leads to the problem that - # we might end up attempting to raise large numbers to large powers. For - # example, base 2 expansion of 1024 has 10 digits. If we were in 10 - # dimensions, then the 10th prime (29) we will end up computing 29^10 even - # though we don't need it. We avoid this by setting the exponents for each - # axes to 0 beyond the maximum value needed for that dimension. - exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1]) - weight_mask = exponents_by_axes > max_sizes_by_axes - capped_exponents = array_ops.where( - weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes) - weights = radixes ** capped_exponents - coeffs = math_ops.floor_div(indices, weights) - coeffs *= 1 - math_ops.cast(weight_mask, dtype) - coeffs = (coeffs % radixes) / radixes - return math_ops.reduce_sum(coeffs / weights, axis=-1) - - -def _get_indices(n, sample_indices, dtype, name=None): - """Generates starting points for the Halton sequence procedure. - - The k'th element of the sequence is generated starting from a positive integer - which must be distinct for each `k`. It is conventional to choose the starting - point as `k` itself (or `k+1` if k is zero based). This function generates - the starting integers for the required elements and reshapes the result for - later use. - - Args: - n: Positive `int`. The number of samples to generate. If this - parameter is supplied, then `sample_indices` should be None. - sample_indices: `Tensor` of dtype int32 and rank 1. The entries - index into the Halton sequence starting with 0 and hence, must be whole - numbers. For example, sample_indices=[0, 5, 6] will produce the first, - sixth and seventh elements of the sequence. If this parameter is not None - then `n` must be None. - dtype: The dtype of the sample. One of `float32` or `float64`. - Default is `float32`. - name: Python `str` name which describes ops created by this function. - - Returns: - indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`. - """ - with ops.name_scope(name, 'get_indices', [n, sample_indices]): - if sample_indices is None: - sample_indices = math_ops.range(n, dtype=dtype) - else: - sample_indices = math_ops.cast(sample_indices, dtype) - - # Shift the indices so they are 1 based. - indices = sample_indices + 1 - - # Reshape to make space for the event dimension and the place value - # coefficients. - return array_ops.reshape(indices, [-1, 1, 1]) - - -def _base_expansion_size(num, bases): - """Computes the number of terms in the place value expansion. - - Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of - `num` in base b (ak <> 0). This function computes and returns `k` for each - base `b` specified in `bases`. - - This can be inferred from the base `b` logarithm of `num` as follows: - $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ - - Args: - num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to - compute the base expansion size of. - bases: `Tensor` of the same dtype as num. The bases to compute the size - against. - - Returns: - Tensor of same dtype and shape as `bases` containing the size of num when - written in that base. - """ - return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1 - - -def _primes_less_than(n): - # Based on - # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 - """Returns sorted array of primes such that `2 <= prime < n`.""" - small_primes = np.array((2, 3, 5)) - if n <= 6: - return small_primes[small_primes < n] - sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool) - sieve[0] = False - m = int(n ** 0.5) // 3 + 1 - for i in range(m): - if not sieve[i]: - continue - k = 3 * i + 1 | 1 - sieve[k ** 2 // 3::2 * k] = False - sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False - return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] - -_PRIMES = _primes_less_than(7919+1) - -assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py deleted file mode 100644 index 7fd5652c5c3e085b23c05baef6e3a42b7a42e08f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member -from tensorflow.python.util import all_util - -_allowed_symbols = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", -] - -all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py deleted file mode 100644 index 82693c2b7bcdbca9f6f4a1d799be5728bb5d36bf..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ /dev/null @@ -1,1178 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. - -@@sample_chain -@@sample_annealed_importance_chain -@@kernel -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import util as distributions_util - -__all__ = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", -] - - -KernelResults = collections.namedtuple( - "KernelResults", - [ - "log_accept_ratio", - "current_grads_target_log_prob", # "Current result" means "accepted". - "current_target_log_prob", # "Current result" means "accepted". - "is_accepted", - "proposed_grads_target_log_prob", - "proposed_state", - "proposed_target_log_prob", - ]) - - -def _make_dummy_kernel_results( - dummy_state, - dummy_target_log_prob, - dummy_grads_target_log_prob): - return KernelResults( - log_accept_ratio=dummy_target_log_prob, - current_grads_target_log_prob=dummy_grads_target_log_prob, - current_target_log_prob=dummy_target_log_prob, - is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), - proposed_grads_target_log_prob=dummy_grads_target_log_prob, - proposed_state=dummy_state, - proposed_target_log_prob=dummy_target_log_prob, - ) - - -def sample_chain( - num_results, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - num_burnin_steps=0, - num_steps_between_results=0, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm - that takes a series of gradient-informed steps to produce a Metropolis - proposal. This function samples from an HMC Markov chain at `current_state` - and whose stationary distribution has log-unnormalized-density - `target_log_prob_fn()`. - - This function samples from multiple chains in parallel. It assumes that the - the leftmost dimensions of (each) `current_state` (part) index an independent - chain. The function `target_log_prob_fn()` sums log-probabilities across - event dimensions (i.e., current state (part) rightmost dimensions). Each - element of the output of `target_log_prob_fn()` represents the (possibly - unnormalized) log-probability of the joint distribution over (all) the current - state (parts). - - The `current_state` can be represented as a single `Tensor` or a `list` of - `Tensors` which collectively represent the current state. When specifying a - `list`, one must also specify a list of `step_size`s. - - Note: `target_log_prob_fn` is called exactly twice. - - Since HMC states are correlated, it is sometimes desirable to produce - additional intermediate states, and then discard them, ending up with a set of - states with decreased autocorrelation. See [1]. Such "thinning" is made - possible by setting `num_steps_between_results > 0`. The chain then takes - `num_steps_between_results` extra steps between the steps that make it into - the results. The extra steps are never materialized (in calls to `sess.run`), - and thus do not increase memory requirements. - - [1]: "Statistically efficient thinning of a Markov chain sampler." - Art B. Owen. April 2017. - http://statweb.stanford.edu/~owen/reports/bestthinning.pdf - - #### Examples: - - ##### Sample from a diagonal-variance Gaussian. - - ```python - tfd = tf.contrib.distributions - - def make_likelihood(true_variances): - return tfd.MultivariateNormalDiag( - scale_diag=tf.sqrt(true_variances)) - - dims = 10 - dtype = np.float32 - true_variances = tf.linspace(dtype(1), dtype(3), dims) - likelihood = make_likelihood(true_variances) - - states, kernel_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=likelihood.log_prob, - current_state=tf.zeros(dims), - step_size=0.5, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(states, axis=0) - sample_var = tf.reduce_mean( - tf.squared_difference(states, sample_mean), - axis=0) - ``` - - ##### Sampling from factor-analysis posteriors with known factors. - - I.e., - - ```none - for i=1..n: - w[i] ~ Normal(0, eye(d)) # prior - x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood - ``` - - where `F` denotes factors. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, factors): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) - - # Setup data. - num_weights = 10 - num_factors = 4 - num_chains = 100 - dtype = np.float32 - - prior = make_prior(num_weights, dtype) - weights = prior.sample(num_chains) - factors = np.random.randn(num_factors, num_weights).astype(dtype) - x = make_likelihood(weights, factors).sample(num_chains) - - def target_log_prob(w): - # Target joint is: `f(w) = p(w, x | factors)`. - return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) - - # Get `num_results` samples from `num_chains` independent chains. - chains_states, kernels_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target_log_prob, - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) - sample_var = tf.reduce_mean( - tf.squared_difference(chains_states, sample_mean), - axis=[0, 1]) - ``` - - Args: - num_results: Integer number of Markov chain draws. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - num_burnin_steps: Integer number of chain steps to take before starting to - collect results. - Default value: 0 (i.e., no burn-in). - num_steps_between_results: Integer number of chain steps between collecting - a result. Only one out of every `num_steps_between_samples + 1` steps is - included in the returned results. The number of returned chain states is - still equal to `num_results`. Default value: 0 (i.e., no thinning). - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob` at the `current_state` and wrt - the `current_state`. Must have same shape as `current_state`. The only - reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_chain"). - - Returns: - next_states: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state` but with a prepended `num_results`-size dimension. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - with ops.name_scope( - name, "hmc_sample_chain", - [num_results, current_state, step_size, num_leapfrog_steps, - num_burnin_steps, num_steps_between_results, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob, - ] = _prepare_args( - target_log_prob_fn, - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob) - num_results = ops.convert_to_tensor( - num_results, - dtype=dtypes.int32, - name="num_results") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - num_burnin_steps = ops.convert_to_tensor( - num_burnin_steps, - dtype=dtypes.int32, - name="num_burnin_steps") - num_steps_between_results = ops.convert_to_tensor( - num_steps_between_results, - dtype=dtypes.int32, - name="num_steps_between_results") - - def _run_chain(num_steps, current_state, kernel_results): - """Runs the chain(s) for `num_steps`.""" - def _loop_body(iter_, current_state, kernel_results): - return [iter_ + 1] + list(kernel( - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), - current_state, - kernel_results, - ], - ) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - return control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - def _scan_body(args_list, iter_): - """Closure which implements `tf.scan` body.""" - current_state, kernel_results = args_list - return _run_chain( - 1 + array_ops.where(math_ops.equal(iter_, 0), - num_burnin_steps, - num_steps_between_results), - current_state, - kernel_results) - - scan_kwargs = dict( - fn=_scan_body, - elems=math_ops.range(num_results), # iter_: used to choose burnin. - initializer=[ - current_state, - _make_dummy_kernel_results( - current_state, - current_target_log_prob, - current_grads_target_log_prob), - ]) - if seed is not None: - scan_kwargs["parallel_iterations"] = 1 - return functional_ops.scan(**scan_kwargs) - - -def sample_annealed_importance_chain( - proposal_log_prob_fn, - num_steps, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - name=None): - """Runs annealed importance sampling (AIS) to estimate normalizing constants. - - This function uses Hamiltonian Monte Carlo to sample from a series of - distributions that slowly interpolates between an initial "proposal" - distribution: - - `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - - and the target distribution: - - `exp(target_log_prob_fn(x) - target_log_normalizer)`, - - accumulating importance weights along the way. The product of these - importance weights gives an unbiased estimate of the ratio of the - normalizing constants of the initial distribution and the target - distribution: - - `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. - - Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three - times (although this may be reduced to two times, in the future). - - #### Examples: - - ##### Estimate the normalizing constant of a log-gamma distribution. - - ```python - tfd = tf.contrib.distributions - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 20 - dtype = np.float32 - - proposal = tfd.MultivatiateNormalDiag( - loc=tf.zeros([dims], dtype=dtype)) - - target = tfd.TransformedDistribution( - distribution=tfd.Gamma(concentration=dtype(2), - rate=dtype(3)), - bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), - event_shape=[dims]) - - chains_state, ais_weights, kernels_results = ( - hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal.log_prob, - num_steps=1000, - target_log_prob_fn=target.log_prob, - step_size=0.2, - current_state=proposal.sample(num_chains), - num_leapfrog_steps=2)) - - log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) - ``` - - ##### Estimate marginal likelihood of a Bayesian regression model. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, x): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, x, axes=[[0], [-1]])) - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 10 - dtype = np.float32 - - # Make training data. - x = np.random.randn(num_chains, dims).astype(dtype) - true_weights = np.random.randn(dims).astype(dtype) - y = np.dot(x, true_weights) + np.random.randn(num_chains) - - # Setup model. - prior = make_prior(dims, dtype) - def target_log_prob_fn(weights): - return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) - - proposal = tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - weight_samples, ais_weights, kernel_results = ( - hmc.sample_annealed_importance_chain( - num_steps=1000, - proposal_log_prob_fn=proposal.log_prob, - target_log_prob_fn=target_log_prob_fn - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2)) - log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - ``` - - Args: - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - num_steps: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). - - Returns: - next_state: `Tensor` or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at the final iteration. Has same shape as - input `current_state`. - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(current_state)`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - def make_convex_combined_log_prob_fn(iter_): - def _fn(*args): - p = proposal_log_prob_fn(*args) - t = target_log_prob_fn(*args) - dtype = p.dtype.base_dtype - beta = (math_ops.cast(iter_ + 1, dtype) - / math_ops.cast(num_steps, dtype)) - return (1. - beta) * p + beta * t - return _fn - - with ops.name_scope( - name, "hmc_sample_annealed_importance_chain", - [num_steps, current_state, step_size, num_leapfrog_steps, seed]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_log_prob, - current_grads_log_prob, - ] = _prepare_args( - make_convex_combined_log_prob_fn(iter_=0), - current_state, - step_size, - description="convex_combined_log_prob") - num_steps = ops.convert_to_tensor( - num_steps, - dtype=dtypes.int32, - name="num_steps") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - def _loop_body(iter_, ais_weights, current_state, kernel_results): - """Closure which implements `tf.while_loop` body.""" - current_state_parts = (list(current_state) - if _is_list_like(current_state) - else [current_state]) - # TODO(b/72994218): Consider refactoring things to avoid this unecessary - # call. - ais_weights += ((target_log_prob_fn(*current_state_parts) - - proposal_log_prob_fn(*current_state_parts)) - / math_ops.cast(num_steps, ais_weights.dtype)) - return [iter_ + 1, ais_weights] + list(kernel( - make_convex_combined_log_prob_fn(iter_), - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - array_ops.zeros_like(current_log_prob), # ais_weights - current_state, - _make_dummy_kernel_results(current_state, - current_log_prob, - current_grads_log_prob), - ]) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - - [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - return [current_state, ais_weights, kernel_results] - - -def kernel(target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs one iteration of Hamiltonian Monte Carlo. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function applies one step of HMC to - randomly update the variable `x`. - - This function can update multiple chains in parallel. It assumes that all - leftmost dimensions of `current_state` index independent chain states (and are - therefore updated independently). The output of `target_log_prob_fn()` should - sum log-probabilities across all event dimensions. Slices along the rightmost - dimensions may have different target distributions; for example, - `current_state[0, :]` could have a different target distribution from - `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of - independent chains is `tf.size(target_log_prob_fn(*current_state))`.) - - #### Examples: - - ##### Simple chain with warm-up. - - ```python - tfd = tf.contrib.distributions - - # Tuning acceptance rates: - dtype = np.float32 - target_accept_rate = 0.631 - num_warmup_iter = 500 - num_chain_iter = 500 - - x = tf.get_variable(name="x", initializer=dtype(1)) - step_size = tf.get_variable(name="step_size", initializer=dtype(1)) - - target = tfd.Normal(loc=dtype(0), scale=dtype(1)) - - next_x, other_results = hmc.kernel( - target_log_prob_fn=target.log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=3)[:4] - - x_update = x.assign(next_x) - - step_size_update = step_size.assign_add( - step_size * tf.where( - tf.exp(tf.minimum(other_results.log_accept_ratio), 0.) > - target_accept_rate, - 0.01, -0.01)) - - warmup = tf.group([x_update, step_size_update]) - - tf.global_variables_initializer().run() - - sess.graph.finalize() # No more graph building. - - # Warm up the sampler and adapt the step size - for _ in xrange(num_warmup_iter): - sess.run(warmup) - - # Collect samples without adapting step size - samples = np.zeros([num_chain_iter]) - for i in xrange(num_chain_iter): - _, x_, target_log_prob_, grad_ = sess.run([ - x_update, - x, - other_results.target_log_prob, - other_results.grads_target_log_prob]) - samples[i] = x_ - - print(samples.mean(), samples.std()) - ``` - - ##### Sample from more complicated posterior. - - I.e., - - ```none - W ~ MVN(loc=0, scale=sigma * eye(dims)) - for i=1...num_samples: - X[i] ~ MVN(loc=0, scale=eye(dims)) - eps[i] ~ Normal(loc=0, scale=1) - Y[i] = X[i].T * W + eps[i] - ``` - - ```python - tfd = tf.contrib.distributions - - def make_training_data(num_samples, dims, sigma): - dt = np.asarray(sigma).dtype - zeros = tf.zeros(dims, dtype=dt) - x = tfd.MultivariateNormalDiag( - loc=zeros).sample(num_samples, seed=1) - w = tfd.MultivariateNormalDiag( - loc=zeros, - scale_identity_multiplier=sigma).sample(seed=2) - noise = tfd.Normal( - loc=dt(0), - scale=dt(1)).sample(num_samples, seed=3) - y = tf.tensordot(x, w, axes=[[1], [0]]) + noise - return y, x, w - - def make_prior(sigma, dims): - # p(w | sigma) - return tfd.MultivariateNormalDiag( - loc=tf.zeros([dims], dtype=sigma.dtype), - scale_identity_multiplier=sigma) - - def make_likelihood(x, w): - # p(y | x, w) - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(x, w, axes=[[1], [0]])) - - # Setup assumptions. - dtype = np.float32 - num_samples = 150 - dims = 10 - num_iters = int(5e3) - - true_sigma = dtype(0.5) - y, x, true_weights = make_training_data(num_samples, dims, true_sigma) - - # Estimate of `log(true_sigma)`. - log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0)) - sigma = tf.exp(log_sigma) - - # State of the Markov chain. - weights = tf.get_variable( - name="weights", - initializer=np.random.randn(dims).astype(dtype)) - - prior = make_prior(sigma, dims) - - def joint_log_prob_fn(w): - # f(w) = log p(w, y | x) - return prior.log_prob(w) + make_likelihood(x, w).log_prob(y) - - weights_update = weights.assign( - hmc.kernel(target_log_prob_fn=joint_log_prob, - current_state=weights, - step_size=0.1, - num_leapfrog_steps=5)[0]) - - with tf.control_dependencies([weights_update]): - loss = -prior.log_prob(weights) - - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma]) - - sess.graph.finalize() # No more graph building. - - tf.global_variables_initializer().run() - - sigma_history = np.zeros(num_iters, dtype) - weights_history = np.zeros([num_iters, dims], dtype) - - for i in xrange(num_iters): - _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights]) - weights_history[i, :] = weights_ - sigma_history[i] = sigma_ - - true_weights_ = sess.run(true_weights) - - # Should converge to something close to true_sigma. - plt.plot(sigma_history); - plt.ylabel("sigma"); - plt.xlabel("iteration"); - ``` - - Args: - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `current_target_log_prob` at the `current_state` - and wrt the `current_state`. Must have same shape as `current_state`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_kernel"). - - Returns: - next_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - - Raises: - ValueError: if there isn't one `step_size` or a list with same length as - `current_state`. - """ - with ops.name_scope( - name, "hmc_kernel", - [current_state, step_size, num_leapfrog_steps, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [current_state_parts, step_sizes, current_target_log_prob, - current_grads_target_log_prob] = _prepare_args( - target_log_prob_fn, current_state, step_size, - current_target_log_prob, current_grads_target_log_prob, - maybe_expand=True) - independent_chain_ndims = distributions_util.prefer_static_rank( - current_target_log_prob) - current_momentums = [] - for s in current_state_parts: - current_momentums.append(random_ops.random_normal( - shape=array_ops.shape(s), - dtype=s.dtype.base_dtype, - seed=seed)) - seed = distributions_util.gen_new_seed( - seed, salt="hmc_kernel_momentums") - - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] = _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob, - current_grads_target_log_prob) - - energy_change = _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims) - log_accept_ratio = -energy_change - - # u < exp(log_accept_ratio), where u~Uniform[0,1) - # ==> log(u) < log_accept_ratio - random_value = random_ops.random_uniform( - shape=array_ops.shape(energy_change), - dtype=energy_change.dtype, - seed=seed) - random_negative = math_ops.log(random_value) - is_accepted = random_negative < log_accept_ratio - - accepted_target_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - - next_state_parts = [_choose(is_accepted, - proposed_state_part, - current_state_part, - independent_chain_ndims) - for current_state_part, proposed_state_part - in zip(current_state_parts, proposed_state_parts)] - - accepted_grads_target_log_prob = [ - _choose(is_accepted, - proposed_grad, - grad, - independent_chain_ndims) - for proposed_grad, grad - in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] - - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(next_state_parts), - KernelResults( - log_accept_ratio=log_accept_ratio, - current_grads_target_log_prob=accepted_grads_target_log_prob, - current_target_log_prob=accepted_target_log_prob, - is_accepted=is_accepted, - proposed_grads_target_log_prob=proposed_grads_target_log_prob, - proposed_state=maybe_flatten(proposed_state_parts), - proposed_target_log_prob=proposed_target_log_prob, - ), - ] - - -def _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Applies `num_leapfrog_steps` of the leapfrog integrator. - - Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. - - #### Examples: - - ##### Simple quadratic potential. - - ```python - tfd = tf.contrib.distributions - - dims = 10 - num_iter = int(1e3) - dtype = np.float32 - - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - - [ - next_momentums, - next_positions, - ] = hmc._leapfrog_integrator( - current_momentums=[momentum], - target_log_prob_fn=tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)).log_prob, - current_state_parts=[position], - step_sizes=0.1, - num_leapfrog_steps=3)[:2] - - sess.graph.finalize() # No more graph building. - - momentum_ = np.random.randn(dims).astype(dtype) - position_ = np.random.randn(dims).astype(dtype) - - positions = np.zeros([num_iter, dims], dtype) - for i in xrange(num_iter): - position_, momentum_ = sess.run( - [next_momentums[0], next_position[0]], - feed_dict={position: position_, momentum: momentum_}) - positions[i] = position_ - - plt.plot(positions[:, 0]); # Sinusoidal. - ``` - - Args: - current_momentums: Tensor containing the value(s) of the momentum - variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like - `*current_state_parts` and returns its (possibly unnormalized) log-density - under the target distribution. - current_state_parts: Python `list` of `Tensor`s representing the current - state(s) of the Markov chain(s). The first `independent_chain_ndims` of - the `Tensor`(s) index different chains. - step_sizes: Python `list` of `Tensor`s representing the step size for the - leapfrog integrator. Must broadcast with the shape of - `current_state_parts`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. When - possible, it's often helpful to match per-variable step sizes to the - standard deviations of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn(*current_state_parts)`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob_fn(*current_state_parts`) wrt - `current_state_parts`. Must have same shape as `current_state_parts`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_leapfrog_integrator"). - - Returns: - proposed_momentums: Updated value of the momentum. - proposed_state_parts: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state_parts`. - proposed_target_log_prob: `Tensor` representing the value of - `target_log_prob_fn` at `next_state`. - proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt - `next_state`. - - Raises: - ValueError: if `len(momentums) != len(state_parts)`. - ValueError: if `len(state_parts) != len(step_sizes)`. - ValueError: if `len(state_parts) != len(grads_target_log_prob)`. - TypeError: if `not target_log_prob.dtype.is_floating`. - """ - def _loop_body(step, - current_momentums, - current_state_parts, - ignore_current_target_log_prob, # pylint: disable=unused-argument - current_grads_target_log_prob): - return [step + 1] + list(_leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob)) - - with ops.name_scope( - name, "hmc_leapfrog_integrator", - [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps, - current_target_log_prob, current_grads_target_log_prob]): - if len(current_momentums) != len(current_state_parts): - raise ValueError("`momentums` must be in one-to-one correspondence " - "with `state_parts`") - num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps, - name="num_leapfrog_steps") - current_target_log_prob, current_grads_target_log_prob = ( - _maybe_call_fn_and_grads( - target_log_prob_fn, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob)) - return control_flow_ops.while_loop( - cond=lambda iter_, *args: iter_ < num_leapfrog_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - current_momentums, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob, - ], - back_prop=False)[1:] # Lop-off "iter_". - - -def _leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob, - name=None): - """Applies one step of the leapfrog integrator.""" - with ops.name_scope( - name, "_leapfrog_step", - [current_momentums, current_state_parts, step_sizes, - current_grads_target_log_prob]): - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(current_momentums, - step_sizes, - current_grads_target_log_prob)] - proposed_state_parts = [x + ss * m for x, ss, m - in zip(current_state_parts, - step_sizes, - proposed_momentums)] - proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) - if not proposed_target_log_prob.dtype.is_floating: - raise TypeError("`target_log_prob_fn` must produce a `Tensor` " - "with `float` `dtype`.") - proposed_grads_target_log_prob = gradients_ops.gradients( - proposed_target_log_prob, proposed_state_parts) - if any(g is None for g in proposed_grads_target_log_prob): - raise ValueError( - "Encountered `None` gradient. Does your target `target_log_prob_fn` " - "access all `tf.Variable`s via `tf.get_variable`?\n" - " current_state_parts: {}\n" - " proposed_state_parts: {}\n" - " proposed_grads_target_log_prob: {}".format( - current_state_parts, - proposed_state_parts, - proposed_grads_target_log_prob)) - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(proposed_momentums, - step_sizes, - proposed_grads_target_log_prob)] - return [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] - - -def _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims, - name=None): - """Helper to `kernel` which computes the energy change.""" - with ops.name_scope( - name, "compute_energy_change", - ([current_target_log_prob, proposed_target_log_prob, - independent_chain_ndims] + - current_momentums + proposed_momentums)): - # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy - # since they're a mouthful and lets us inline more. - lk0, lk1 = [], [] - for current_momentum, proposed_momentum in zip(current_momentums, - proposed_momentums): - axis = math_ops.range(independent_chain_ndims, - array_ops.rank(current_momentum)) - lk0.append(_log_sum_sq(current_momentum, axis)) - lk1.append(_log_sum_sq(proposed_momentum, axis)) - - lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1), - axis=-1) - lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), - axis=-1) - lp0 = -current_target_log_prob # potential - lp1 = -proposed_target_log_prob # proposed_potential - x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], - axis=-1) - - # The sum is NaN if any element is NaN or we see both +Inf and -Inf. - # Thus we will replace such rows with infinite energy change which implies - # rejection. Recall that float-comparisons with NaN are always False. - is_sum_determinate = ( - math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) & - math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1)) - is_sum_determinate = array_ops.tile( - is_sum_determinate[..., array_ops.newaxis], - multiples=array_ops.concat([ - array_ops.ones(array_ops.rank(is_sum_determinate), - dtype=dtypes.int32), - [4], - ], axis=0)) - x = array_ops.where(is_sum_determinate, - x, - array_ops.fill(array_ops.shape(x), - value=x.dtype.as_numpy_dtype(np.inf))) - - return math_ops.reduce_sum(x, axis=-1) - - -def _choose(is_accepted, - accepted, - rejected, - independent_chain_ndims, - name=None): - """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where.""" - def _expand_is_accepted_like(x): - with ops.name_scope("_choose"): - expand_shape = array_ops.concat([ - array_ops.shape(is_accepted), - array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)], - dtype=dtypes.int32), - ], axis=0) - multiples = array_ops.concat([ - array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32), - array_ops.shape(x)[independent_chain_ndims:], - ], axis=0) - m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape), - multiples) - m.set_shape(x.shape) - return m - with ops.name_scope(name, "_choose", values=[ - is_accepted, accepted, rejected, independent_chain_ndims]): - return array_ops.where(_expand_is_accepted_like(accepted), - accepted, - rejected) - - -def _maybe_call_fn_and_grads(fn, - fn_arg_list, - fn_result=None, - grads_fn_result=None, - description="target_log_prob"): - """Helper which computes `fn_result` and `grads` if needed.""" - fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list) - else [fn_arg_list]) - if fn_result is None: - fn_result = fn(*fn_arg_list) - if not fn_result.dtype.is_floating: - raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format( - description)) - if grads_fn_result is None: - grads_fn_result = gradients_ops.gradients( - fn_result, fn_arg_list) - if len(fn_arg_list) != len(grads_fn_result): - raise ValueError("`{}` must be in one-to-one correspondence with " - "`grads_{}`".format(*[description]*2)) - if any(g is None for g in grads_fn_result): - raise ValueError("Encountered `None` gradient.") - return fn_result, grads_fn_result - - -def _prepare_args(target_log_prob_fn, state, step_size, - target_log_prob=None, grads_target_log_prob=None, - maybe_expand=False, description="target_log_prob"): - """Helper which processes input args to meet list-like assumptions.""" - state_parts = list(state) if _is_list_like(state) else [state] - state_parts = [ops.convert_to_tensor(s, name="state") - for s in state_parts] - target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads( - target_log_prob_fn, - state_parts, - target_log_prob, - grads_target_log_prob, - description) - step_sizes = list(step_size) if _is_list_like(step_size) else [step_size] - step_sizes = [ - ops.convert_to_tensor( - s, name="step_size", dtype=target_log_prob.dtype) - for s in step_sizes] - if len(step_sizes) == 1: - step_sizes *= len(state_parts) - if len(state_parts) != len(step_sizes): - raise ValueError("There should be exactly one `step_size` or it should " - "have same length as `current_state`.") - maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0] - return [ - maybe_flatten(state_parts), - maybe_flatten(step_sizes), - target_log_prob, - grads_target_log_prob, - ] - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _log_sum_sq(x, axis=None): - """Computes log(sum(x**2)).""" - return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py deleted file mode 100644 index a742b7c1aa593d6c08bf9d8d597c99c9fc4e7aed..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Probabilistic neural layers. - -See ${python/contrib.bayesflow.layers}. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import * -from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational import * -from tensorflow.contrib.bayesflow.python.ops.layers_util import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'Convolution1DReparameterization', - 'Convolution2DReparameterization', - 'Convolution3DReparameterization', - 'Convolution1DFlipout', - 'Convolution2DFlipout', - 'Convolution3DFlipout', - 'Conv1DReparameterization', - 'Conv2DReparameterization', - 'Conv3DReparameterization', - 'Conv1DFlipout', - 'Conv2DFlipout', - 'Conv3DFlipout', - 'convolution1d_reparameterization', - 'convolution2d_reparameterization', - 'convolution3d_reparameterization', - 'convolution1d_flipout', - 'convolution2d_flipout', - 'convolution3d_flipout', - 'conv1d_reparameterization', - 'conv2d_reparameterization', - 'conv3d_reparameterization', - 'conv1d_flipout', - 'conv2d_flipout', - 'conv3d_flipout', - 'DenseReparameterization', - 'DenseLocalReparameterization', - 'DenseFlipout', - 'dense_reparameterization', - 'dense_local_reparameterization', - 'dense_flipout', - 'default_loc_scale_fn', - 'default_mean_field_normal_fn', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py deleted file mode 100644 index cb80718f719ff31fb8ba5066170342fc69630780..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py +++ /dev/null @@ -1,2486 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Convolutional variational layer classes and their functional aliases. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import docstring_util -from tensorflow.contrib.bayesflow.python.ops import layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.layers import utils -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util - -doc_args = """activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer.""" - - -class _ConvVariational(layers_lib.Layer): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - rank: An integer, the rank of the convolution, e.g. "2" for 2D - convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, ..., - channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(_ConvVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self.rank = rank - self.filters = filters - self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size") - self.strides = utils.normalize_tuple(strides, rank, "strides") - self.padding = utils.normalize_padding(padding) - self.data_format = utils.normalize_data_format(data_format) - self.dilation_rate = utils.normalize_tuple( - dilation_rate, rank, "dilation_rate") - self.activation = activation - self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2) - self.kernel_posterior_fn = kernel_posterior_fn - self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn - self.kernel_prior_fn = kernel_prior_fn - self.kernel_divergence_fn = kernel_divergence_fn - self.bias_posterior_fn = bias_posterior_fn - self.bias_posterior_tensor_fn = bias_posterior_tensor_fn - self.bias_prior_fn = bias_prior_fn - self.bias_divergence_fn = bias_divergence_fn - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if self.data_format == "channels_first": - channel_axis = 1 - else: - channel_axis = -1 - if input_shape[channel_axis].value is None: - raise ValueError("The channel dimension of the inputs " - "should be defined. Found `None`.") - input_dim = input_shape[channel_axis].value - kernel_shape = self.kernel_size + (input_dim, self.filters) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel_posterior = self.kernel_posterior_fn( - dtype, kernel_shape, "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel_prior_fn is None: - self.kernel_prior = None - else: - self.kernel_prior = self.kernel_prior_fn( - dtype, kernel_shape, "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias_posterior_fn is None: - self.bias_posterior = None - else: - self.bias_posterior = self.bias_posterior_fn( - dtype, (self.filters,), "bias_posterior", - self.trainable, self.add_variable) - - if self.bias_prior_fn is None: - self.bias_prior = None - else: - self.bias_prior = self.bias_prior_fn( - dtype, (self.filters,), "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2, - axes={channel_axis: input_dim}) - self._convolution_op = nn_ops.Convolution( - input_shape, - filter_shape=tensor_shape.TensorShape(kernel_shape), - dilation_rate=self.dilation_rate, - strides=self.strides, - padding=self.padding.upper(), - data_format=utils.convert_data_format(self.data_format, - self.rank + 2)) - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) - if not self._built_kernel_divergence: - kernel_posterior = self.kernel_posterior - kernel_prior = self.kernel_prior - if isinstance(self.kernel_posterior, independent_lib.Independent): - kernel_posterior = kernel_posterior.distribution - if isinstance(self.kernel_prior, independent_lib.Independent): - kernel_prior = kernel_prior.distribution - self._apply_divergence(self.kernel_divergence_fn, - kernel_posterior, - kernel_prior, - self.kernel_posterior_tensor, - name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - bias_posterior = self.bias_posterior - bias_prior = self.bias_prior - if isinstance(self.bias_posterior, independent_lib.Independent): - bias_posterior = bias_posterior.distribution - if isinstance(self.bias_prior, independent_lib.Independent): - bias_prior = bias_prior.distribution - self._apply_divergence(self.bias_divergence_fn, - bias_posterior, - bias_prior, - self.bias_posterior_tensor, - name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_bias(self, inputs): - if self.bias_posterior is None: - self.bias_posterior_tensor = None - return inputs - self.bias_posterior_tensor = self.bias_posterior_tensor_fn( - self.bias_posterior) - outputs = inputs - if self.data_format == "channels_first": - if self.rank == 1: - # nn.bias_add does not accept a 1D input tensor. - bias = array_ops.reshape(self.bias_posterior_tensor, - (1, self.filters, 1)) - outputs += bias - if self.rank == 2: - outputs = nn.bias_add(outputs, - self.bias_posterior_tensor, - data_format="NCHW") - if self.rank == 3: - # As of Mar 2017, direct addition is significantly slower than - # bias_add when computing gradients. To use bias_add, we collapse Z - # and Y into a single dimension to obtain a 4D input tensor. - outputs_shape = outputs.shape.as_list() - outputs_4d = array_ops.reshape(outputs, - [outputs_shape[0], outputs_shape[1], - outputs_shape[2] * outputs_shape[3], - outputs_shape[4]]) - outputs_4d = nn.bias_add(outputs_4d, - self.bias_posterior_tensor, - data_format="NCHW") - outputs = array_ops.reshape(outputs_4d, outputs_shape) - else: - outputs = nn.bias_add(outputs, - self.bias_posterior_tensor, - data_format="NHWC") - return outputs - - def _apply_divergence(self, divergence_fn, posterior, prior, - posterior_tensor, name): - if (divergence_fn is None or - posterior is None or - prior is None): - divergence = None - return - divergence = standard_ops.identity( - divergence_fn( - posterior, prior, posterior_tensor), - name=name) - self.add_loss(divergence) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if self.data_format == "channels_last": - space = input_shape[1:-1] - new_space = [] - for i in range(len(space)): - new_dim = utils.conv_output_length( - space[i], - self.kernel_size[i], - padding=self.padding, - stride=self.strides[i], - dilation=self.dilation_rate[i]) - new_space.append(new_dim) - return tensor_shape.TensorShape([input_shape[0]] + new_space + - [self.filters]) - else: - space = input_shape[2:] - new_space = [] - for i in range(len(space)): - new_dim = utils.conv_output_length( - space[i], - self.kernel_size[i], - padding=self.padding, - stride=self.strides[i], - dilation=self.dilation_rate[i]) - new_space.append(new_dim) - return tensor_shape.TensorShape([input_shape[0], self.filters] + - new_space) - - -class _ConvReparameterization(_ConvVariational): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - rank: An integer, the rank of the convolution, e.g. "2" for 2D - convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, ..., - channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(_ConvReparameterization, self).__init__( - rank=rank, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - def _apply_variational_kernel(self, inputs): - self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( - self.kernel_posterior) - self.kernel_posterior_affine = None - self.kernel_posterior_affine_tensor = None - outputs = self._convolution_op(inputs, self.kernel_posterior_tensor) - return outputs - - -class Conv1DReparameterization(_ConvReparameterization): - """1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.Conv1DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, length, - channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(Conv1DReparameterization, self).__init__( - rank=1, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -@docstring_util.expand_docstring(args=doc_args) -def conv1d_reparameterization( - inputs, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.conv1d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - # pylint: enable=g-doc-args - layer = Conv1DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv2DReparameterization(_ConvReparameterization): - """2D convolution layer (e.g. spatial convolution over images). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.Conv2DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, height, - width, channels)` while `channels_first` corresponds to inputs with - shape `(batch, channels, height, width)`. - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(Conv2DReparameterization, self).__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -@docstring_util.expand_docstring(args=doc_args) -def conv2d_reparameterization( - inputs, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.conv2d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - # pylint: enable=g-doc-args - layer = Conv2DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv3DReparameterization(_ConvReparameterization): - """3D convolution layer (e.g. spatial convolution over volumes). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.Conv3DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, depth, - height, width, channels)` while `channels_first` corresponds to inputs - with shape `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(Conv3DReparameterization, self).__init__( - rank=3, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -@docstring_util.expand_docstring(args=doc_args) -def conv3d_reparameterization( - inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.conv3d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - # pylint: enable=g-doc-args - layer = Conv3DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class _ConvFlipout(_ConvVariational): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - rank: An integer, the rank of the convolution, e.g. "2" for 2D - convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, ..., - channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(_ConvFlipout, self).__init__( - rank=rank, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - self.seed = seed - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`{}` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format( - type(self).__name__, self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), - scale=self.kernel_posterior.distribution.scale) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - - outputs = self._convolution_op( - inputs, self.kernel_posterior.distribution.loc) - - input_shape = array_ops.shape(inputs) - output_shape = array_ops.shape(outputs) - batch_shape = array_ops.expand_dims(input_shape[0], 0) - channels = input_shape[-1] - - sign_input = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(channels, 0)], 0), - dtype=inputs.dtype, - seed=self.seed) - sign_output = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(self.filters, 0)], 0), - dtype=inputs.dtype, - seed=distribution_util.gen_new_seed( - self.seed, salt="conv_flipout")) - for _ in range(self.rank): - sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) - sign_output = array_ops.expand_dims(sign_output, 1) - - sign_input = array_ops.tile( # tile for element-wise op broadcasting - sign_input, - [1] + [input_shape[i + 1] for i in range(self.rank)] + [1]) - sign_output = array_ops.tile( - sign_output, - [1] + [output_shape[i + 1] for i in range(self.rank)] + [1]) - - perturbed_inputs = self._convolution_op( - inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output - - outputs += perturbed_inputs - return outputs - - -class Conv1DFlipout(_ConvFlipout): - """1D convolution layer (e.g. temporal convolution) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.Conv1DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, length, - channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(Conv1DFlipout, self).__init__( - rank=1, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -@docstring_util.expand_docstring(args=doc_args) -def conv1d_flipout( - inputs, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - @{args} - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.conv1d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - # pylint: enable=g-doc-args - layer = Conv1DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv2DFlipout(_ConvFlipout): - """2D convolution layer (e.g. spatial convolution over images) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.Conv2DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, height, - width, channels)` while `channels_first` corresponds to inputs with - shape `(batch, channels, height, width)`. - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(Conv2DFlipout, self).__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -@docstring_util.expand_docstring(args=doc_args) -def conv2d_flipout( - inputs, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.conv2d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - # pylint: enable=g-doc-args - layer = Conv2DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv3DFlipout(_ConvFlipout): - """3D convolution layer (e.g. spatial convolution over volumes) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.Conv3DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or - `channels_first`. The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape `(batch, depth, - height, width, channels)` while `channels_first` corresponds to inputs - with shape `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - """ - # pylint: enable=g-doc-args - super(Conv3DFlipout, self).__init__( - rank=3, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -@docstring_util.expand_docstring(args=doc_args) -def conv3d_flipout( - inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - @{args} - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.conv3d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. - International Conference on Learning Representations, 2018. - """ - # pylint: enable=g-doc-args - layer = Conv3DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -# Aliases - -Convolution1DReparameterization = Conv1DReparameterization -Convolution2DReparameterization = Conv2DReparameterization -Convolution3DReparameterization = Conv3DReparameterization -convolution1d_reparameterization = conv1d_reparameterization -convolution2d_reparameterization = conv2d_reparameterization -convolution3d_reparameterization = conv3d_reparameterization -Convolution1DFlipout = Conv1DFlipout -Convolution2DFlipout = Conv2DFlipout -Convolution3DFlipout = Conv3DFlipout -convolution1d_flipout = conv1d_flipout -convolution2d_flipout = conv2d_flipout -convolution3d_flipout = conv3d_flipout diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py deleted file mode 100644 index 1f1d8fda2a5db4db33a2b6e5d7f027c4b509011a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py +++ /dev/null @@ -1,955 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Dense Bayesian layer using KL-divergence based variational inference. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import docstring_util -from tensorflow.contrib.bayesflow.python.ops import layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util - - -doc_args = """units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name.""" - - -class _DenseVariational(layers_lib.Layer): - """Abstract densely-connected class (private, used as implementation base). - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - @{args} - """ - # pylint: enable=g-doc-args - super(_DenseVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self.units = units - self.activation = activation - self.input_spec = layers_lib.InputSpec(min_ndim=2) - self.kernel_posterior_fn = kernel_posterior_fn - self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn - self.kernel_prior_fn = kernel_prior_fn - self.kernel_divergence_fn = kernel_divergence_fn - self.bias_posterior_fn = bias_posterior_fn - self.bias_posterior_tensor_fn = bias_posterior_tensor_fn - self.bias_prior_fn = bias_prior_fn - self.bias_divergence_fn = bias_divergence_fn - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - in_size = input_shape.with_rank_at_least(2)[-1].value - if in_size is None: - raise ValueError("The last dimension of the inputs to `Dense` " - "should be defined. Found `None`.") - self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel_posterior = self.kernel_posterior_fn( - dtype, [in_size, self.units], "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel_prior_fn is None: - self.kernel_prior = None - else: - self.kernel_prior = self.kernel_prior_fn( - dtype, [in_size, self.units], "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias_posterior_fn is None: - self.bias_posterior = None - else: - self.bias_posterior = self.bias_posterior_fn( - dtype, [self.units], "bias_posterior", - self.trainable, self.add_variable) - - if self.bias_prior_fn is None: - self.bias_prior = None - else: - self.bias_prior = self.bias_prior_fn( - dtype, [self.units], "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) # pylint: disable=not-callable - if not self._built_kernel_divergence: - kernel_posterior = self.kernel_posterior - kernel_prior = self.kernel_prior - if isinstance(self.kernel_posterior, independent_lib.Independent): - kernel_posterior = kernel_posterior.distribution - if isinstance(self.kernel_prior, independent_lib.Independent): - kernel_prior = kernel_prior.distribution - self._apply_divergence(self.kernel_divergence_fn, - kernel_posterior, - kernel_prior, - self.kernel_posterior_tensor, - name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - bias_posterior = self.bias_posterior - bias_prior = self.bias_prior - if isinstance(self.bias_posterior, independent_lib.Independent): - bias_posterior = bias_posterior.distribution - if isinstance(self.bias_prior, independent_lib.Independent): - bias_prior = bias_prior.distribution - self._apply_divergence(self.bias_divergence_fn, - bias_posterior, - bias_prior, - self.bias_posterior_tensor, - name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_bias(self, inputs): - if self.bias_posterior is None: - self.bias_posterior_tensor = None - return inputs - self.bias_posterior_tensor = self.bias_posterior_tensor_fn( - self.bias_posterior) - return nn.bias_add(inputs, self.bias_posterior_tensor) - - def _apply_divergence(self, divergence_fn, posterior, prior, - posterior_tensor, name): - if (divergence_fn is None or - posterior is None or - prior is None): - divergence = None - return - divergence = standard_ops.identity( - divergence_fn( - posterior, prior, posterior_tensor), - name=name) - self.add_loss(divergence) - - def _matmul(self, inputs, kernel): - if inputs.shape.ndims <= 2: - return standard_ops.matmul(inputs, kernel) - # To handle broadcasting, we must use `tensordot`. - return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) - if input_shape[-1].value is None: - raise ValueError( - "The innermost dimension of input_shape must be defined, " - "but saw: {}".format(input_shape)) - return input_shape[:-1].concatenate(self.units) - - -class DenseReparameterization(_DenseVariational): - """Densely-connected layer class with reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the reparameterization estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseReparameterization( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - @{args} - """ - # pylint: enable=g-doc-args - super(DenseReparameterization, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - - def _apply_variational_kernel(self, inputs): - self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( - self.kernel_posterior) - self.kernel_posterior_affine = None - self.kernel_posterior_affine_tensor = None - return self._matmul(inputs, self.kernel_posterior_tensor) - - -@docstring_util.expand_docstring(args=doc_args) -def dense_reparameterization( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Densely-connected layer with reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the reparameterization estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - @{args} - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_reparameterization( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - # pylint: enable=g-doc-args - layer = DenseReparameterization( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class DenseLocalReparameterization(_DenseVariational): - """Densely-connected layer class with local reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the local reparameterization estimator [1], which performs a - Monte Carlo approximation of the distribution on the hidden units - induced by the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseLocalReparameterization( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseLocalReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses local reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Variational Dropout and the Local Reparameterization Trick." - Diederik P. Kingma, Tim Salimans, Max Welling. - Neural Information Processing Systems, 2015. - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - @{args} - """ - # pylint: enable=g-doc-args - super(DenseLocalReparameterization, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`DenseLocalReparameterization` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format(self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=self._matmul(inputs, self.kernel_posterior.distribution.loc), - scale=standard_ops.sqrt(self._matmul( - standard_ops.square(inputs), - standard_ops.square(self.kernel_posterior.distribution.scale)))) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - return self.kernel_posterior_affine_tensor - - -@docstring_util.expand_docstring(args=doc_args) -def dense_local_reparameterization( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Densely-connected layer with local reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the local reparameterization estimator [1], which performs a - Monte Carlo approximation of the distribution on the hidden units - induced by the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - @{args} - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_local_reparameterization( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_local_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses local reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Variational Dropout and the Local Reparameterization Trick." - Diederik P. Kingma, Tim Salimans, Max Welling. - Neural Information Processing Systems, 2015. - """ - # pylint: enable=g-doc-args - layer = DenseLocalReparameterization( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class DenseFlipout(_DenseVariational): - """Densely-connected layer class with Flipout estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the Flipout estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. Flipout uses roughly twice as many floating point operations - as the reparameterization estimator but has the advantage of - significantly lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseFlipout( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - @docstring_util.expand_docstring(args=doc_args) - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - # pylint: disable=g-doc-args - """Construct layer. - - Args: - @{args} - """ - # pylint: enable=g-doc-args - super(DenseFlipout, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - self.seed = seed - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`DenseFlipout` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format(self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), - scale=self.kernel_posterior.distribution.scale) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - - input_shape = array_ops.shape(inputs) - batch_shape = input_shape[:-1] - - sign_input = layers_util.random_sign( - input_shape, - dtype=inputs.dtype, - seed=self.seed) - sign_output = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(self.units, 0)], 0), - dtype=inputs.dtype, - seed=distribution_util.gen_new_seed( - self.seed, salt="dense_flipout")) - perturbed_inputs = self._matmul( - inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output - - outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc) - outputs += perturbed_inputs - return outputs - - -@docstring_util.expand_docstring(args=doc_args) -def dense_flipout( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - # pylint: disable=g-doc-args - """Densely-connected layer with Flipout estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the Flipout estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. Flipout uses roughly twice as many floating point operations - as the reparameterization estimator but has the advantage of - significantly lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - @{args} - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_flipout( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - # pylint: enable=g-doc-args - layer = DenseFlipout( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py deleted file mode 100644 index 8c1fb203f7328e8260e49b4326d813fbe133613e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_util.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for probabilistic layers. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib - - -def default_loc_scale_fn( - is_singular=False, - loc_initializer=init_ops.random_normal_initializer(stddev=0.1), - untransformed_scale_initializer=init_ops.random_normal_initializer( - mean=-3., stddev=0.1), - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. - - This function produces a closure which produces `loc`, `scale` using - `tf.get_variable`. The closure accepts the following arguments: - - dtype: Type of parameter's event. - shape: Python `list`-like representing the parameter's event shape. - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` indicating if `scale is None`. Default: `False`. - loc_initializer: Initializer function for the `loc` parameters. - The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. Default value: `tf.random_normal_initializer(mean=-3., - stddev=0.1)`. This implies the softplus transformed result has mean - approximately `0.05` and std. deviation approximately `0.005`. - loc_regularizer: Regularizer function for the `loc` parameters. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. The default (`None`) is to use the `tf.get_variable` default. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. The default - (`None`) is to use the `tf.get_variable` default. - - Returns: - default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` - parameters from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates `loc`, `scale` parameters.""" - loc = add_variable_fn( - name=name + "_loc", - shape=shape, - initializer=loc_initializer, - regularizer=loc_regularizer, - constraint=loc_constraint, - dtype=dtype, - trainable=trainable) - if is_singular: - return loc, None - untransformed_scale = add_variable_fn( - name=name + "_untransformed_scale", - shape=shape, - initializer=untransformed_scale_initializer, - regularizer=untransformed_scale_regularizer, - constraint=untransformed_scale_constraint, - dtype=dtype, - trainable=trainable) - scale = (np.finfo(dtype.as_numpy_dtype).eps + - nn_ops.softplus(untransformed_scale)) - return loc, scale - return _fn - - -def default_mean_field_normal_fn( - is_singular=False, - loc_initializer=None, - untransformed_scale_initializer=None, - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Creates a function to build Normal distributions with trainable params. - - This function produces a closure which produces `tf.distributions.Normal` - parameterized by a loc` and `scale` each created using `tf.get_variable`. The - produced closure accepts the following arguments: - - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` if `True`, forces the special case limit of - `scale->0`, i.e., a `Deterministic` distribution. - loc_initializer: Initializer function for the `loc` parameters. - If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - loc_regularizer: Regularizer function for the `loc` parameters. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. - - Returns: - make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` - using from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - loc_scale_fn_ = default_loc_scale_fn( - is_singular, - loc_initializer, - untransformed_scale_initializer, - loc_regularizer, - untransformed_scale_regularizer, - loc_constraint, - untransformed_scale_constraint) - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates multivariate `Deterministic` or `Normal` distribution.""" - loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) - if scale is None: - dist = deterministic_lib.Deterministic(loc=loc) - else: - dist = normal_lib.Normal(loc=loc, scale=scale) - reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] - return independent_lib.Independent( - dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) - return _fn - - -def random_sign(shape, dtype=dtypes.float32, seed=None): - """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" - random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, - dtype=dtypes.int32, - seed=seed) - return math_ops.cast(2 * random_bernoulli - 1, dtype) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py deleted file mode 100644 index f3a645eafc249d1c39e0d4a238ae7ec8755c78d8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - "effective_sample_size", - "potential_scale_reduction", -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py deleted file mode 100644 index 0424b6952bc89ce7fe5b00b0135c9a5fe1faa8cf..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for Markov Chain Monte Carlo (MCMC) sampling. - -@@effective_sample_size -@@potential_scale_reduction -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops import sample_stats -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - -__all__ = [ - "effective_sample_size", - "potential_scale_reduction", -] - - -def effective_sample_size(states, - filter_threshold=0., - filter_beyond_lag=None, - name=None): - """Estimate a lower bound on effective sample size for each independent chain. - - Roughly speaking, "effective sample size" (ESS) is the size of an iid sample - with the same variance as `state`. - - More precisely, given a stationary sequence of possibly correlated random - variables `X_1, X_2,...,X_N`, each identically distributed ESS is the number - such that - - ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.``` - - If the sequence is uncorrelated, `ESS = N`. In general, one should expect - `ESS <= N`, with more highly correlated sequences having smaller `ESS`. - - #### Example of using ESS to estimate standard error. - - ``` - tfd = tf.contrib.distributions - tfb = tf.contrib.bayesflow - - target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) - - # Get 1000 states from one chain. - states = tfb.hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target.log_prob, - current_state=tf.constant([0., 0.]), - step_size=0.05, - num_leapfrog_steps=20, - num_burnin_steps=200) - states.shape - ==> (1000, 2) - - ess = effective_sample_size(states) - ==> Shape (2,) Tensor - - mean, variance = tf.nn.moments(states, axis=0) - standard_error = tf.sqrt(variance / ess) - ``` - - Some math shows that, with `R_k` the auto-correlation sequence, - `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have - - ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]``` - - This function estimates the above by first estimating the auto-correlation. - Since `R_k` must be estimated using only `N - k` samples, it becomes - progressively noisier for larger `k`. For this reason, the summation over - `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many - MCMC methods generate chains where `R_k > 0`, a reasonable critera is to - truncate at the first index where the estimated auto-correlation becomes - negative. - - The arguments `filter_beyond_lag`, `filter_threshold` are filters intended to - remove noisy tail terms from `R_k`. They combine in an "OR" manner meaning - terms are removed if they were to be filtered under the `filter_beyond_lag` OR - `filter_threshold` criteria. - - Args: - states: `Tensor` or list of `Tensor` objects. Dimension zero should index - identically distributed states. - filter_threshold: `Tensor` or list of `Tensor` objects. - Must broadcast with `state`. The auto-correlation sequence is truncated - after the first appearance of a term less than `filter_threshold`. - Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`, - setting to any number less than `-1` has the same effect. - filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be - `int`-like and scalar valued. The auto-correlation sequence is truncated - to this length. Setting to `None` means we do not filter based on number - of lags. - name: `String` name to prepend to created ops. - - Returns: - ess: `Tensor` or list of `Tensor` objects. The effective sample size of - each component of `states`. Shape will be `states.shape[1:]`. - - Raises: - ValueError: If `states` and `filter_threshold` or `states` and - `filter_beyond_lag` are both lists with different lengths. - """ - states_was_list = _is_list_like(states) - - # Convert all args to lists. - if not states_was_list: - states = [states] - - filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag, - "filter_beyond_lag") - filter_threshold = _broadcast_maybelist_arg(states, filter_threshold, - "filter_threshold") - - # Process items, one at a time. - with ops.name_scope(name, "effective_sample_size"): - ess_list = [ - _effective_sample_size_single_state(s, ml, mlt) - for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold) - ] - - if states_was_list: - return ess_list - return ess_list[0] - - -def _effective_sample_size_single_state(states, filter_beyond_lag, - filter_threshold): - """ESS computation for one single Tensor argument.""" - - with ops.name_scope( - "effective_sample_size_single_state", - values=[states, filter_beyond_lag, filter_threshold]): - - states = ops.convert_to_tensor(states, name="states") - dt = states.dtype - - # filter_beyond_lag == None ==> auto_corr is the full sequence. - auto_corr = sample_stats.auto_correlation( - states, axis=0, max_lags=filter_beyond_lag) - if filter_threshold is not None: - filter_threshold = ops.convert_to_tensor( - filter_threshold, dtype=dt, name="filter_threshold") - # Get a binary mask to zero out values of auto_corr below the threshold. - # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, - # mask[i, ...] = 0, otherwise. - # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] - # Building step by step, - # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. - # Step 1: mask = [False, False, True, False] - mask = auto_corr < filter_threshold - # Step 2: mask = [0, 0, 1, 1] - mask = math_ops.cast(mask, dtype=dt) - # Step 3: mask = [0, 0, 1, 2] - mask = math_ops.cumsum(mask, axis=0) - # Step 4: mask = [1, 1, 0, 0] - mask = math_ops.maximum(1. - mask, 0.) - auto_corr *= mask - - # With R[k] := auto_corr[k, ...], - # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]} - # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1) - # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]} - # where M is the filter_beyond_lag truncation point chosen above. - - # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total - # ndims the same as auto_corr - n = _axis_size(states, axis=0) - k = math_ops.range(0., _axis_size(auto_corr, axis=0)) - nk_factor = (n - k) / n - if auto_corr.shape.ndims is not None: - new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1) - else: - new_shape = array_ops.concat( - ([-1], - array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)), - axis=0) - nk_factor = array_ops.reshape(nk_factor, new_shape) - - return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0)) - - -def potential_scale_reduction(chains_states, - independent_chain_ndims=1, - name=None): - """Gelman and Rubin's potential scale reduction factor for chain convergence. - - Given `N > 1` states from each of `C > 1` independent chains, the potential - scale reduction factor, commonly referred to as R-hat, measures convergence of - the chains (to the same target) by testing for equality of means. - Specifically, R-hat measures the degree to which variance (of the means) - between chains exceeds what one would expect if the chains were identically - distributed. See [1], [2]. - - Some guidelines: - - * The initial state of the chains should be drawn from a distribution - overdispersed with respect to the target. - * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1. - Before that, R-hat > 1 (except in pathological cases, e.g. if the chain - paths were identical). - * The above holds for any number of chains `C > 1`. Increasing `C` does - improves effectiveness of the diagnostic. - * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of - course this is problem depedendent. See [2]. - * R-hat only measures non-convergence of the mean. If higher moments, or other - statistics are desired, a different diagnostic should be used. See [2]. - - #### Examples - - Diagnosing convergence by monitoring 10 chains that each attempt to - sample from a 2-variate normal. - - ```python - tfd = tf.contrib.distributions - tfb = tf.contrib.bayesflow - - target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) - - # Get 10 (2x) overdispersed initial states. - initial_state = target.sample(10) * 2. - ==> (10, 2) - - # Get 1000 samples from the 10 independent chains. - chains_states, _ = tfb.hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target.log_prob, - current_state=initial_state, - step_size=0.05, - num_leapfrog_steps=20, - num_burnin_steps=200) - chains_states.shape - ==> (1000, 10, 2) - - rhat = tfb.mcmc_diagnostics.potential_scale_reduction( - chains_states, independent_chain_ndims=1) - - # The second dimension needed a longer burn-in. - rhat.eval() - ==> [1.05, 1.3] - ``` - - To see why R-hat is reasonable, let `X` be a random variable drawn uniformly - from the combined states (combined over all chains). Then, in the limit - `N, C --> infinity`, with `E`, `Var` denoting expectation and variance, - - ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].``` - - Using the law of total variance, the numerator is the variance of the combined - states, and the denominator is the total variance minus the variance of the - the individual chain means. If the chains are all drawing from the same - distribution, they will have the same mean, and thus the ratio should be one. - - [1] "Inference from Iterative Simulation Using Multiple Sequences" - Andrew Gelman and Donald B. Rubin - Statist. Sci. Volume 7, Number 4 (1992), 457-472. - [2] "General Methods for Monitoring Convergence of Iterative Simulations" - Stephen P. Brooks and Andrew Gelman - Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4. - - Args: - chains_states: `Tensor` or Python `list` of `Tensor`s representing the - state(s) of a Markov Chain at each result step. The `ith` state is - assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`. - Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain. - Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent - chains to be tested for convergence to the same target. - The remaining dimensions, `A`, can have any shape (even empty). - independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the - number of giving the number of dimensions, from `dim = 1` to `dim = D`, - holding independent chain results to be tested for convergence. - name: `String` name to prepend to created ops. Default: - `potential_scale_reduction`. - - Returns: - `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for - the state(s). Same `dtype` as `state`, and shape equal to - `state.shape[1 + independent_chain_ndims:]`. - - Raises: - ValueError: If `independent_chain_ndims < 1`. - """ - chains_states_was_list = _is_list_like(chains_states) - if not chains_states_was_list: - chains_states = [chains_states] - - # tensor_util.constant_value returns None iff a constant value (as a numpy - # array) is not efficiently computable. Therefore, we try constant_value then - # check for None. - icn_const_ = tensor_util.constant_value( - ops.convert_to_tensor(independent_chain_ndims)) - if icn_const_ is not None: - independent_chain_ndims = icn_const_ - if icn_const_ < 1: - raise ValueError( - "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format( - independent_chain_ndims)) - - with ops.name_scope(name, "potential_scale_reduction"): - rhat_list = [ - _potential_scale_reduction_single_state(s, independent_chain_ndims) - for s in chains_states - ] - - if chains_states_was_list: - return rhat_list - return rhat_list[0] - - -def _potential_scale_reduction_single_state(state, independent_chain_ndims): - """potential_scale_reduction for one single state `Tensor`.""" - with ops.name_scope( - "potential_scale_reduction_single_state", - values=[state, independent_chain_ndims]): - # We assume exactly one leading dimension indexes e.g. correlated samples - # from each Markov chain. - state = ops.convert_to_tensor(state, name="state") - sample_ndims = 1 - - sample_axis = math_ops.range(0, sample_ndims) - chain_axis = math_ops.range(sample_ndims, - sample_ndims + independent_chain_ndims) - sample_and_chain_axis = math_ops.range( - 0, sample_ndims + independent_chain_ndims) - - n = _axis_size(state, sample_axis) - m = _axis_size(state, chain_axis) - - # In the language of [2], - # B / n is the between chain variance, the variance of the chain means. - # W is the within sequence variance, the mean of the chain variances. - b_div_n = _reduce_variance( - math_ops.reduce_mean(state, sample_axis, keepdims=True), - sample_and_chain_axis, - biased=False) - w = math_ops.reduce_mean( - _reduce_variance(state, sample_axis, keepdims=True, biased=True), - sample_and_chain_axis) - - # sigma^2_+ is an estimate of the true variance, which would be unbiased if - # each chain was drawn from the target. c.f. "law of total variance." - sigma_2_plus = w + b_div_n - - return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n) - - -# TODO(b/72873233) Move some variant of this to sample_stats. -def _reduce_variance(x, axis=None, biased=True, keepdims=False): - with ops.name_scope("reduce_variance"): - x = ops.convert_to_tensor(x, name="x") - mean = math_ops.reduce_mean(x, axis=axis, keepdims=True) - biased_var = math_ops.reduce_mean( - math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims) - if biased: - return biased_var - n = _axis_size(x, axis) - return (n / (n - 1.)) * biased_var - - -def _axis_size(x, axis=None): - """Get number of elements of `x` in `axis`, as type `x.dtype`.""" - if axis is None: - return math_ops.cast(array_ops.size(x), x.dtype) - return math_ops.cast( - math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype) - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _broadcast_maybelist_arg(states, secondary_arg, name): - """Broadcast a listable secondary_arg to that of states.""" - if _is_list_like(secondary_arg): - if len(secondary_arg) != len(states): - raise ValueError("Argument `%s` was a list of different length ({}) than " - "`states` ({})".format(name, len(states))) - else: - secondary_arg = [secondary_arg] * len(states) - - return secondary_arg diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py deleted file mode 100644 index e7fcbc65ef379e84a140a06e020549f74f905a99..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions to create a Markov Chain Monte Carlo Metropolis step.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.metropolis_hastings_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'kernel', - 'evolve', - 'proposal_uniform', - 'proposal_normal', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py deleted file mode 100644 index 05aa134ed5c11092316af5f3e45ba07fdb491e90..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py +++ /dev/null @@ -1,527 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Metropolis-Hastings and proposal distributions. - -@@kernel -@@evolve -@@proposal_uniform -@@proposal_normal -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops - -__all__ = [ - "kernel", - "evolve", - "proposal_uniform", - "proposal_normal", -] - - -KernelResults = collections.namedtuple( - "KernelResults", - [ - "log_accept_ratio", - "current_target_log_prob", # "Current result" means "accepted". - "is_accepted", - "proposed_state", - ]) - - -def kernel(target_log_prob_fn, - proposal_fn, - current_state, - seed=None, - current_target_log_prob=None, - name=None): - """Runs the Metropolis-Hastings transition kernel. - - This function can update multiple chains in parallel. It assumes that all - leftmost dimensions of `current_state` index independent chain states (and are - therefore updated independently). The output of `target_log_prob_fn()` should - sum log-probabilities across all event dimensions. Slices along the rightmost - dimensions may have different target distributions; for example, - `current_state[0, :]` could have a different target distribution from - `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of - independent chains is `tf.size(target_log_prob_fn(*current_state))`.) - - Args: - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - proposal_fn: Python callable which takes an argument like `current_state` - (or `*current_state` if it's a list) and returns a tuple of proposed - states of same shape as `state`, and a log ratio `Tensor` of same shape - as `current_target_log_prob`. The log ratio is the log-probability of - `state` given proposed states minus the log-probability of proposed - states given `state`. If the proposal is symmetric, set the second value - to `None`: this enables more efficient computation than explicitly - supplying a tensor of zeros. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: A name of the operation (optional). - - Returns: - next_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - - #### Examples - - We illustrate Metropolis-Hastings on a Normal likelihood with - unknown mean. - - ```python - tfd = tf.contrib.distributions - tfp = tf.contrib.bayesflow - - loc = tf.get_variable("loc", initializer=1.) - x = tf.constant([0.0] * 50) - - def make_target_log_prob_fn(x): - def target_log_prob_fn(loc): - prior = tfd.Normal(loc=0., scale=1.) - likelihood = tfd.Independent( - tfd.Normal(loc=loc, scale=0.1), - reinterpreted_batch_ndims=1) - return prior.log_prob(loc) + likelihood.log_prob(x) - return target_log_prob_fn - - next_state, kernel_results = tfp.metropolis_hastings.kernel( - target_log_prob_fn=make_target_log_prob_fn(x), - proposal_fn=tfp.metropolis_hastings.proposal_normal(), - current_state=loc) - loc_update = loc.assign(next_state) - ``` - - We illustrate Metropolis-Hastings on a Normal likelihood with - unknown mean and variance. We apply 4 chains. - - ```python - tfd = tf.contrib.distributions - tfp = tf.contrib.bayesflow - - num_chains = 4 - loc = tf.get_variable("loc", shape=[num_chains], - initializer=tf.random_normal_initializer()) - scale = tf.get_variable("scale", shape=[num_chains], - initializer=tf.ones_initializer()) - x = tf.constant([0.0] * 50) - - def make_target_log_prob_fn(x): - data = tf.reshape(x, shape=[-1, 1]) - def target_log_prob_fn(loc, scale): - prior_loc = tfd.Normal(loc=0., scale=1.) - prior_scale = tfd.InverseGamma(concentration=1., rate=1.) - likelihood = tfd.Independent( - tfd.Normal(loc=loc, scale=scale), - reinterpreted_batch_ndims=1) - return (prior_loc.log_prob(loc) + - prior_scale.log_prob(scale) + - likelihood.log_prob(data)) - return target_log_prob_fn - - def proposal_fn(loc, scale): - loc_proposal = tfp.metropolis_hastings.proposal_normal() - scale_proposal = tfp.metropolis_hastings.proposal_uniform(minval=-1.) - proposed_loc, _ = loc_proposal(loc) - proposed_scale, _ = scale_proposal(scale) - proposed_scale = tf.maximum(proposed_scale, 0.01) - return [proposed_loc, proposed_scale], None - - next_state, kernel_results = tfp.metropolis_hastings.kernel( - target_log_prob_fn=make_target_log_prob_fn(x), - proposal_fn=proposal_fn, - current_state=[loc, scale]) - train_op = tf.group(loc.assign(next_state[0]), - scale.assign(next_state[1])) - ``` - - """ - with ops.name_scope( - name, "metropolis_hastings_kernel", - [current_state, seed, current_target_log_prob]): - with ops.name_scope("initialize"): - maybe_expand = lambda x: list(x) if _is_list_like(x) else [x] - current_state_parts = maybe_expand(current_state) - if current_target_log_prob is None: - current_target_log_prob = target_log_prob_fn(*current_state_parts) - - proposed_state, log_transit_ratio = proposal_fn(*current_state_parts) - proposed_state_parts = maybe_expand(proposed_state) - - proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) - - with ops.name_scope( - "accept_reject", - [current_state_parts, proposed_state_parts, - current_target_log_prob, proposed_target_log_prob]): - log_accept_ratio = proposed_target_log_prob - current_target_log_prob - if log_transit_ratio is not None: - # If the log_transit_ratio is None, then assume the proposal is - # symmetric, i.e., - # log p(old | new) - log p(new | old) = 0. - log_accept_ratio += log_transit_ratio - - # u < exp(log_accept_ratio), where u~Uniform[0,1) - # ==> log(u) < log_accept_ratio - random_value = random_ops.random_uniform( - array_ops.shape(log_accept_ratio), - dtype=log_accept_ratio.dtype, - seed=seed) - random_negative = math_ops.log(random_value) - is_accepted = random_negative < log_accept_ratio - next_state_parts = [array_ops.where(is_accepted, - proposed_state_part, - current_state_part) - for proposed_state_part, current_state_part in - zip(proposed_state_parts, current_state_parts)] - accepted_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(next_state_parts), - KernelResults( - log_accept_ratio=log_accept_ratio, - current_target_log_prob=accepted_log_prob, - is_accepted=is_accepted, - proposed_state=maybe_flatten(proposed_state_parts), - ), - ] - - -def evolve(initial_sample, - initial_log_density, - initial_log_accept_ratio, - target_log_prob_fn, - proposal_fn, - n_steps=1, - seed=None, - name=None): - """Performs `n_steps` of the Metropolis-Hastings update. - - Given a probability density function, `f(x)` and a proposal scheme which - generates new points from old, this `Op` returns a tensor - which may be used to generate approximate samples from the target distribution - using the Metropolis-Hastings algorithm. These samples are from a Markov chain - whose equilibrium distribution matches the target distribution. - - The probability distribution may have an unknown normalization constan. - We parameterize the probability density as follows: - - ```none - f(x) = exp(L(x) + constant) - ``` - - Here `L(x)` is any continuous function with an (possibly unknown but finite) - upper bound, i.e. there exists a number beta such that - `L(x)< beta < infinity` for all x. The constant is the normalization needed - to make `f(x)` a probability density (as opposed to just a finite measure). - - Although `initial_sample` can be arbitrary, a poor choice may result in a - slow-to-mix chain. In many cases the best choice is the one that maximizes - the target density, i.e., choose `initial_sample` such that - `f(initial_sample) >= f(x)` for all `x`. - - - If the support of the distribution is a strict subset of R^n (but of non zero - measure), then the unnormalized log-density `L(x)` should return `-infinity` - outside the support domain. This effectively forces the sampler to only - explore points in the regions of finite support. - - Usage: - This function is meant to be wrapped up with some of the common proposal - schemes (e.g. random walk, Langevin diffusion etc) to produce a more user - friendly interface. However, it may also be used to create bespoke samplers. - - The following example, demonstrates the use to generate a 1000 uniform random - walk Metropolis samplers run in parallel for the normal target distribution. - - ```python - n = 3 # dimension of the problem - - # Generate 1000 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = tf.get_variable( - "state", - initializer=tf.random_normal([1000, n], - mean=3.0, - dtype=tf.float64, - seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return -tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = tf.get_variable( - "state_log_density", - initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = tf.get_variable( - "log_acceptance_ratio", - initializer=tf.zeros([1000], dtype=tf.float64)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = tf.initialize_all_variables() - with tf.Session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps and print out the mean across - # the chains every 100 iterations. - for n_iter in range(10): - # Executing the stepper advances the chain to the next state. - sess.run(stepper) - # Print out the current value of the mean(sample) for every dimension. - print(np.mean(sess.run(state), 0)) - # Estimated covariance matrix - samples = sess.run(state) - print(np.cov(samples, rowvar=False)) - ``` - - Args: - initial_sample: A float-like `tf.Variable` of any shape that can - be consumed by the `target_log_prob_fn` and `proposal_fn` - callables. - initial_log_density: Float-like `tf.Variable` with `dtype` and shape - equivalent to `target_log_prob_fn(initial_sample)`, i.e., matching - the result of `target_log_prob_fn` invoked at `current_state`. - initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching - `initial_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio after propagating the chain for `n_steps`. - target_log_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `target_log_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, here `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `target_log_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair should be a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric, i.e. - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to None instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - n_steps: A positive `int` or a scalar `int32` tensor. Sets the number of - iterations of the chain. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - forward_step: an `Op` to step the Markov chain forward for `n_steps`. - """ - - with ops.name_scope(name, "metropolis_hastings", [initial_sample]): - current_state = initial_sample - current_target_log_prob = initial_log_density - log_accept_ratio = initial_log_accept_ratio - - def step(i, current_state, current_target_log_prob, log_accept_ratio): - """Wrap single Markov chain iteration in `while_loop`.""" - next_state, kernel_results = kernel( - target_log_prob_fn=target_log_prob_fn, - proposal_fn=proposal_fn, - current_state=current_state, - current_target_log_prob=current_target_log_prob, - seed=seed) - accepted_log_prob = kernel_results.current_target_log_prob - log_accept_ratio = kernel_results.log_accept_ratio - return i + 1, next_state, accepted_log_prob, log_accept_ratio - - (_, accepted_state, accepted_target_log_prob, accepted_log_accept_ratio) = ( - control_flow_ops.while_loop( - cond=lambda i, *ignored_args: i < n_steps, - body=step, - loop_vars=[ - 0, # i - current_state, - current_target_log_prob, - log_accept_ratio, - ], - parallel_iterations=1 if seed is not None else 10, - # TODO(b/73775595): Confirm optimal setting of swap_memory. - swap_memory=1)) - - forward_step = control_flow_ops.group( - state_ops.assign(current_target_log_prob, accepted_target_log_prob), - state_ops.assign(current_state, accepted_state), - state_ops.assign(log_accept_ratio, accepted_log_accept_ratio)) - - return forward_step - - -def proposal_uniform(step_size=1., - seed=None, - name=None): - """Returns a callable that adds a random uniform tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). It adds a - sample from a random uniform distribution drawn from [-stepsize, stepsize] - to its input. It also returns the log of the ratio of the probability of - moving from the input point to the proposed point, but since this log ratio is - identically equal to 0 (because the probability of drawing a value `x` from - the symmetric uniform distribution is the same as the probability of drawing - `-x`), it simply returns None for the second element of the returned tuple. - - Args: - step_size: A positive `float` or a scalar tensor of real dtype - controlling the scale of the uniform distribution. - If step_size = a, then draws are made uniformly from [-a, a]. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, "proposal_uniform", [step_size]): - step_size = ops.convert_to_tensor(step_size, name="step_size") - - def proposal_fn(input_state, name=None): - """Adds a uniform perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensor` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - with ops.name_scope(name, "proposer", [input_state]): - input_state = ops.convert_to_tensor(input_state, name="input_state") - return input_state + random_ops.random_uniform( - array_ops.shape(input_state), - minval=-step_size, - maxval=step_size, - seed=seed), None - return proposal_fn - - -def proposal_normal(scale=1., - seed=None, - name=None): - """Returns a callable that adds a random normal tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). The callable - adds a sample from a normal distribution with the supplied standard deviation - and zero mean to its input argument (called the proposal point). - The callable returns a tuple with the proposal point as the first element. - The second element is identically `None`. It is included so the callable is - compatible with the expected signature of the proposal scheme argument in the - `metropolis_hastings` function. A value of `None` indicates that the - probability of going from the input point to the proposal point is equal to - the probability of going from the proposal point to the input point. - - Args: - scale: A positive `float` or a scalar tensor of any real dtype controlling - the scale of the normal distribution. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, "proposal_normal", [scale]): - scale = ops.convert_to_tensor(scale, name="scale") - - def proposal_fn(input_state, name=None): - """Adds a normal perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensor` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - - with ops.name_scope(name, "proposer", [input_state]): - input_state = ops.convert_to_tensor(input_state, name="input_state") - return input_state + random_ops.random_normal( - array_ops.shape(input_state), - mean=0., - stddev=scale, - dtype=scale.dtype, - seed=seed), None - return proposal_fn - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py deleted file mode 100644 index 7786656398e3c87704227be95b3cd23a38785249..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An optimizer module for stochastic gradient Langevin dynamics.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import training_ops - - -class SGLDOptimizer(optimizer.Optimizer): - """An optimizer module for stochastic gradient Langevin dynamics. - - This implements the preconditioned Stochastic Gradient Langevin Dynamics - optimizer [1]. The optimization variable is regarded as a sample from the - posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in - each dimension according to RMSProp [2]. - - Note: If a prior is included in the loss, it should be scaled by - `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches - in the data. I.e., it should be divided by the `num_pseudo_batches` term - described below. - - [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural - Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. - ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 - [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf - - Args: - learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the - optimizer. Must be tuned to the specific function being minimized. - preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential - decay rate of the rescaling of the preconditioner (RMSprop). (This is - "alpha" in [1]). Should be smaller than but nearly `1` to approximate - sampling from the posterior. (Default: `0.95`) - num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of - minibatches in the data set. Trades off noise and prior with the SGD - likelihood term. Note: Assumes the loss is taken as the mean over a - minibatch. Otherwise if the sum was taken, divide this number by the - batch size. (Default: `1`) - burnin: Scalar `int`-like `Tensor`. The number of iterations to collect - gradient statistics to update the preconditioner before starting to draw - noisy samples. (Default: `25`) - diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of - the preconditioner to prevent the preconditioner from degenerating. - (Default: `1e-8`) - name: Python `str` describing ops managed by this function. - (Default: `"SGLDOptimizer"`) - variable_scope: Variable scope used for calls to `tf.get_variable`. - If `None`, a new variable scope is created using name - `ops.get_default_graph().unique_name(name or default_name)`. - - Raises: - InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in - `(0,1]`. - """ - - def __init__(self, - learning_rate, - preconditioner_decay_rate=0.95, - num_pseudo_batches=1, - burnin=25, - diagonal_bias=1e-8, - name=None, - variable_scope=None): - default_name = 'SGLDOptimizer' - with ops.name_scope(name, default_name, [ - learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, - diagonal_bias - ]): - if variable_scope is None: - var_scope_name = ops.get_default_graph().unique_name( - name or default_name) - with varscope_ops.variable_scope(var_scope_name) as scope: - self._variable_scope = scope - else: - self._variable_scope = variable_scope - - self._preconditioner_decay_rate = ops.convert_to_tensor( - preconditioner_decay_rate, name='preconditioner_decay_rate') - self._num_pseudo_batches = ops.convert_to_tensor( - num_pseudo_batches, name='num_pseudo_batches') - self._burnin = ops.convert_to_tensor(burnin, name='burnin') - self._diagonal_bias = ops.convert_to_tensor( - diagonal_bias, name='diagonal_bias') - self._learning_rate = ops.convert_to_tensor( - learning_rate, name='learning_rate') - - with varscope_ops.variable_scope(self._variable_scope): - self._counter = varscope_ops.get_variable( - 'counter', initializer=0, trainable=False) - - self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._preconditioner_decay_rate, - message='`preconditioner_decay_rate` must be non-negative'), - check_ops.assert_less_equal( - self._preconditioner_decay_rate, - 1., - message='`preconditioner_decay_rate` must be at most 1.'), - ], self._preconditioner_decay_rate) - - self._num_pseudo_batches = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._num_pseudo_batches, - 0, - message='`num_pseudo_batches` must be greater than zero') - ], self._num_pseudo_batches) - - self._burnin = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin, message='`burnin` must be non-negative'), - check_ops.assert_integer( - self._burnin, message='`burnin` must be an integer') - ], self._burnin) - - self._diagonal_bias = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._diagonal_bias, - message='`diagonal_bias` must be non-negative') - ], self._diagonal_bias) - - super(SGLDOptimizer, self).__init__(use_locking=False, - name=name or default_name) - - def _create_slots(self, var_list): - for v in var_list: - init_rms = init_ops.ones_initializer(dtype=v.dtype) - self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), - v.dtype, 'rms', self._name) - - def _prepare(self): - # We need to put the conversion and check here because a user will likely - # want to decay the learning rate dynamically. - self._learning_rate_tensor = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._learning_rate, message='`learning_rate` must be non-negative') - ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) - self._decay_tensor = ops.convert_to_tensor( - self._preconditioner_decay_rate, name='preconditioner_decay_rate') - - super(SGLDOptimizer, self)._prepare() - - def _apply_dense(self, grad, var): - rms = self.get_slot(var, 'rms') - - with ops.control_dependencies([ - self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, - var.dtype.base_dtype))]): - new_grad = self._apply_noisy_update(rms, grad) - - return training_ops.apply_gradient_descent( - var, - math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, - use_locking=self._use_locking).op - - def _apply_sparse(self, grad, var): - rms = self.get_slot(var, 'rms') - - with ops.control_dependencies([ - self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, - var.dtype.base_dtype))]): - new_grad = self._apply_noisy_update(rms, grad) - - return training_ops.apply_gradient_descent( - var, - math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, - use_locking=self._use_locking).op - - def _finish(self, update_ops, name_scope): - update_ops.append([self._counter.assign_add(1)]) - return control_flow_ops.group(*update_ops, name=name_scope) - - @property - def variable_scope(self): - """Variable scope of all calls to `tf.get_variable`.""" - return self._variable_scope - - def _apply_noisy_update(self, mom, grad): - # Compute and apply the gradient update following - # preconditioned Langevin dynamics - stddev = array_ops.where( - array_ops.squeeze(self._counter > self._burnin), - math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), - array_ops.zeros([], grad.dtype)) - - preconditioner = math_ops.rsqrt( - mom + math_ops.cast(self._diagonal_bias, grad.dtype)) - return ( - 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, - grad.dtype) + - random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * - stddev * math_ops.sqrt(preconditioner)) - - def _update_momentum(self, mom, grad, decay): - # Keep an exponentially weighted moving average of squared gradients. - # Not thread safe - return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py deleted file mode 100644 index eadf6f4d5fa1c776e2c71c66c4b64b8f5ac98359..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions related to managing `tf.Variable`s.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -from tensorflow.contrib.bayesflow.python.ops.variable_utils_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member -from tensorflow.python.util import all_util - -_allowed_symbols = [ - "externalize_variables_as_args", -] - -all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py deleted file mode 100644 index ca3d75b5bfee093449026c7d1d62e3bdeff6b096..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions related to managing `tf.Variable`s. - -@@externalize_variables_as_args -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -from tensorflow.python.framework import ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.ops import variables as variables_ops - -__all__ = [ - "externalize_variables_as_args", -] - - -# Cause all warnings to always be triggered. -# Not having this means subsequent calls wont trigger the warning. -warnings.simplefilter("always") - - -def externalize_variables_as_args(fn, - fn_args=(), - ancestor_variables=None, - possible_ancestor_vars=None, - assert_variable_override=False, - name=None): - """"Converts variables within a callable into explicit args. - - Makes a new callable from `fn` which has arguments `list(fn_args) + - list(ancestor_variables)`. If `ancestor_variables` is not specified, it is - inferred by checking which of `possible_ancestor_vars` actually influences the - return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`). - By default `possible_ancestor_vars` is `tf.trainable_variables() + - tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`. - - #### Examples: - - ```python - num_samples = 2 - num_dims = 1 - dtype = np.float32 - - def foo(x): - x = tf.convert_to_tensor(x, dtype=dtype, name="x") - s = x.shape.as_list() - y = tf.get_variable( - name="y", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) - return x + y - - x = tf.constant(dtype([0.1, 0.2])) - - wrapped_foo, discovered_ancestor_variables = ( - externalize_variables_as_args(foo, [x])) - - new_x = dtype([[1.], [2.]]) - new_y = dtype([[3.], [4.]]) - new_result = wrapped_foo(new_x, new_y) - # ==> [[4.], [6.]] - - discovered_ancestor_variables == [tf.get_variable("y", dtype)] - # ==> [True] - ``` - - Args: - fn: Python callable which returns a `Tensor` and accepts `*fn_args`. - fn_args: Python list of args to `fn`. Represents dummy arguments passed to - `fn` to trace its execution; actual values are unimportant. These args are - only used to construct the output of `fn` and to resolve the ancestor - `tf.Variable`s. - Default value: `()` (i.e., `fn` takes no args). - ancestor_variables: Python list of `tf.Variable`s. When `None` the list is - expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing - the `ancestor_variables` the internal call to `fn` is avoided. - Default value: `None` (i.e., `tf.Variable` dependencies are discovered). - possible_ancestor_vars: Python list of possible `tf.Variable`s which might - be a dependency of computing `fn(*fn_args)`. - Default value: `None` (i.e., expanded as described above). - assert_variable_override: Python `bool` indicating that not finding a - `tf.Variable` in the override list is an exception. - Default value: `False` (i.e., missing a `Variable` triggers a `warning`). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "externalize_variables_as_args"). - - Returns: - wrapped_fn: Python callable taking arguments like - `*(list(fn_args) + discovered_ancestor_variables)`. - discovered_ancestor_variables: Python list of `tf.Variable`s known to be a - dependency of `fn(*fn_args)`. - - Raises: - ValueError: if `assert_variable_override` is `True` and `Variable` is - requested but not overridden. - """ - def _make_bypassing_custom_getter_fn(new_var_dict): - """Return dict value rather than what would otherwise be dict key.""" - def _custom_getter(getter, *args, **kwargs): - v = getter(*args, **kwargs) - new_v = new_var_dict.get(v, None) - if new_v is None: - msg = "Variable \"{}\" not found in bypass dict.".format(v) - if assert_variable_override: - raise ValueError(msg) - warnings.warn(msg) - return v - return new_v - return _custom_getter - - with ops.name_scope(name, "externalize_variables_as_args"): - if ancestor_variables is not None and not ancestor_variables: - return fn, () - if ancestor_variables is None: - y = fn(*fn_args) # Side-effect: adds trainable vars. - if possible_ancestor_vars is None: - possible_ancestor_vars = ( - variables_ops.trainable_variables() + - ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - # TODO(b/72873296): Add a dedicated op for identifying ancestors. - ancestors = [v for g, v - in zip(gradients_ops.gradients(y, possible_ancestor_vars), - possible_ancestor_vars) - if g is not None] - ancestor_variables = sorted(ancestors, key=lambda v: v.name) - n = len(fn_args) - def _fn(*args): - with ops.name_scope("wrapped_fn"): - vars_dict = dict( - (k, ops.convert_to_tensor( - v, dtype=k.dtype.base_dtype, name=k.op.name)) - for k, v in zip(ancestor_variables, args[n:])) - with varscope_ops.variable_scope( - varscope_ops.get_variable_scope(), - reuse=True, - custom_getter=_make_bypassing_custom_getter_fn(vars_dict)): - return fn(*args[:n]) - return _fn, ancestor_variables diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py deleted file mode 100644 index 4d5f0cfe9713a011b32c5aba8d429847d81f33e2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An optimizer module for constant stochastic gradient descent.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import training_ops - - -class VariationalSGDOptimizer(optimizer.Optimizer): - """An optimizer module for constant stochastic gradient descent. - - This implements an optimizer module for the constant stochastic gradient - descent algorithm [1]. The optimization variable is regarded as an - approximate sample from the posterior . - - Note: If a prior is included in the loss, it should be scaled by - `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches - in the data. I.e., it should be divided by the `num_pseudo_batches` term - described below. - - [1]: "Stochastic Gradient Descent as Approximate Bayesian Inference - Stephan Mandt, Matthew D. Hoffman, David M. Blei. - ArXiv:1704.04289, 2017. https://arxiv.org/abs/1704.04289 - - Args: - batch_size: Scalar `int`-like `Tensor`. The number of examples in a - minibatch in the data set. Note: Assumes the loss is taken as the mean - over a minibatch. Otherwise if the sum was taken set this to 1. - total_num_examples: Scalar `int`-like `Tensor`. The total number of examples - in the data set. - max_learning_rate: Scalar `float`-like `Tensor`. A maximum allowable - effective coordinate-wise learning rate. The algorithm scales down any - effective learning rate (i.e. after preconditioning) that is larger than - this. (Default: `1`) - preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential - decay rate of the rescaling of the preconditioner (RMSprop). (This is - "alpha" in [1]). Should be smaller than but nearly `1` to approximate - sampling from the posterior. (Default: `0.95`) - burnin: Scalar `int`-like `Tensor`. The number of iterations to collect - gradient statistics to update the preconditioner before starting to draw - noisy samples. (Default: `25`) - burnin_max_learning_rate: Scalar `float`-like `Tensor`. Maximum learning - rate to use during the burnin period. - (Default: `1e-8`) - use_single_learning_rate: Boolean Indicates whether one single learning - rate is used or coordinate_wise learning rates are used. - (Default: `False`) - name: Python `str` describing ops managed by this function. - (Default: `"VariationalSGDOptimizer"`) - variable_scope: Variable scope used for calls to `tf.get_variable`. - If `None`, a new variable scope is created using name - `ops.get_default_graph().unique_name(name or default_name)`. - - Raises: - InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in - `(0,1]`. - """ - - def __init__(self, - batch_size, - total_num_examples, - max_learning_rate=1.0, - preconditioner_decay_rate=0.95, - burnin=25, - burnin_max_learning_rate=1e-6, - use_single_learning_rate=False, - name=None, - variable_scope=None): - default_name = 'VariationalSGDOptimizer' - with ops.name_scope(name, default_name, [ - max_learning_rate, preconditioner_decay_rate, batch_size, burnin, - burnin_max_learning_rate - ]): - if variable_scope is None: - var_scope_name = ops.get_default_graph().unique_name( - name or default_name) - with varscope_ops.variable_scope(var_scope_name) as scope: - self._variable_scope = scope - else: - self._variable_scope = variable_scope - - self._preconditioner_decay_rate = ops.convert_to_tensor( - preconditioner_decay_rate, name='preconditioner_decay_rate') - self._batch_size = ops.convert_to_tensor(batch_size, name='batch_size') - self._total_num_examples = ops.convert_to_tensor( - total_num_examples, name='total_num_examples') - self._burnin = ops.convert_to_tensor(burnin, name='burnin') - self._burnin_max_learning_rate = ops.convert_to_tensor( - burnin_max_learning_rate, name='burnin_max_learning_rate') - self._max_learning_rate = ops.convert_to_tensor( - max_learning_rate, name='max_learning_rate') - self._use_single_learning_rate = use_single_learning_rate - - with varscope_ops.variable_scope(self._variable_scope): - self._counter = varscope_ops.get_variable( - 'counter', initializer=0, trainable=False) - - self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._preconditioner_decay_rate, - message='`preconditioner_decay_rate` must be non-negative'), - check_ops.assert_less_equal( - self._preconditioner_decay_rate, - 1., - message='`preconditioner_decay_rate` must be at most 1.'), - ], self._preconditioner_decay_rate) - - self._batch_size = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._batch_size, - 0, - message='`batch_size` must be greater than zero') - ], self._batch_size) - - self._total_num_examples = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._total_num_examples, - 0, - message='`total_num_examples` must be greater than zero') - ], self._total_num_examples) - - self._burnin = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin, message='`burnin` must be non-negative'), - check_ops.assert_integer( - self._burnin, message='`burnin` must be an integer') - ], self._burnin) - - self._burnin_max_learning_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin_max_learning_rate, - message='`burnin_max_learning_rate` must be non-negative') - ], self._burnin_max_learning_rate) - - self._max_learning_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._max_learning_rate, - message='`max_learning_rate` must be non-negative') - ], self._max_learning_rate) - - super(VariationalSGDOptimizer, self).__init__( - use_locking=False, name=name or default_name) - - def _create_slots(self, var_list): - for v in var_list: - init_moment = init_ops.zeros_initializer(dtype=v.dtype) - self._get_or_make_slot_with_initializer( - v, init_moment, v.get_shape(), v.dtype, 'first_moment', self._name) - self._get_or_make_slot_with_initializer( - v, init_moment, v.get_shape(), v.dtype, 'second_moment', self._name) - - def _prepare(self): - self._decay_tensor = ops.convert_to_tensor( - self._preconditioner_decay_rate, name='preconditioner_decay_rate') - self._batch_size_tensor = ops.convert_to_tensor( - self._batch_size, name='batch_size_tensor') - - super(VariationalSGDOptimizer, self)._prepare() - - def _get_coordinatewise_learning_rate(self, grad, var): - # Compute the learning rate using a moving average for the diagonal of BB^T - avg_first = self.get_slot(var, 'first_moment') - avg_second = self.get_slot(var, 'second_moment') - decay_tensor = math_ops.cast(self._decay_tensor, var.dtype) - batch_size = math_ops.cast(self._batch_size_tensor, var.dtype) - - # Create an estimator for the moving average of gradient mean and variance - # via Welford's algorithm - if isinstance(grad, ops.Tensor): - delta = grad - avg_first - first_moment_update = avg_first.assign_add( - array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype), - 1. - decay_tensor) * delta) - - with ops.control_dependencies([first_moment_update]): - second_moment_update = avg_second.assign_add( - math_ops.cast(self._counter < 1, var.dtype) * - -(1. - decay_tensor) * ( - avg_second - decay_tensor * math_ops.square(delta))) - diag_preconditioner = control_flow_ops.with_dependencies( - [second_moment_update], - clip_ops.clip_by_value(avg_second, 1e-12, 1e12)) - elif isinstance(grad, ops.IndexedSlices): - delta = grad.values - array_ops.gather_nd(avg_first, grad.indices) - first_moment_update = state_ops.scatter_add( - avg_first, - grad.indices, - array_ops.where(self._counter < 1, - math_ops.cast(1., var.dtype), - 1. - decay_tensor) * delta) - - with ops.control_dependencies([first_moment_update]): - avg_second = state_ops.scatter_add( - avg_second, - grad.indices, - math_ops.cast(self._counter < 1, var.dtype) * - -(1. - decay_tensor) * ( - array_ops.gather_nd(avg_second, grad.indices) - decay_tensor * - math_ops.square(delta))) - avg_second = array_ops.gather_nd(avg_second, grad.indices) - # TODO(b/70783772) - diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12) - else: - raise errors.InvalidArgumentError( - None, None, 'grad must of type Tensor or IndexedSlice') - - diag_preconditioner *= batch_size - - if self._use_single_learning_rate: - diag_preconditioner = math_ops.reduce_mean(diag_preconditioner) - - # From Theorem 2 Corollary 1 of Mandt et al. 2017 - return 2. * batch_size / ( - math_ops.cast(self._total_num_examples, var.dtype.base_dtype) * - diag_preconditioner) - - def _apply_dense(self, grad, var): - - max_learning_rate = array_ops.where(self._counter < self._burnin, - self._burnin_max_learning_rate, - self._max_learning_rate) - - learn_rates = clip_ops.clip_by_value( - self._get_coordinatewise_learning_rate(grad, var), 0.0, - math_ops.cast(max_learning_rate, var.dtype.base_dtype)) - - newgrad = grad * learn_rates - return training_ops.apply_gradient_descent( - var, - math_ops.cast(1.0, var.dtype), - newgrad, - use_locking=self._use_locking).op - - def _apply_sparse(self, grad, var): - - max_learning_rate = array_ops.where(self._counter < self._burnin, - self._burnin_max_learning_rate, - self._max_learning_rate) - - learn_rate = clip_ops.clip_by_value( - self._get_coordinatewise_learning_rate(grad, var), 0.0, - math_ops.cast(max_learning_rate, var.dtype)) - delta = grad.values * learn_rate - - return state_ops.scatter_sub(var, grad.indices, delta, - use_locking=self._use_locking) - - def _finish(self, update_ops, name_scope): - update_ops.append([self._counter.assign_add(1)]) - return control_flow_ops.group(*update_ops, name=name_scope) - - @property - def variable_scope(self): - """Variable scope of all calls to `tf.get_variable`.""" - return self._variable_scope diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 289f5bb3140974d8c37f4938ceef27275b099f9a..dcd235f876c87b4d7d85c0f1d0fc2e855ced99ea 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -13,20 +13,23 @@ load("//tensorflow:tensorflow.bzl", "py_test") filegroup( name = "all_files", srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], + include = ["**/*"], + exclude = ["**/OWNERS"], ), visibility = ["//tensorflow:__subpackages__"], ) py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", + deps = [ + "custom_export_strategy", + ":custom_loss_head", + ":estimator", + ":model", + ":trainer_hooks", + ], ) py_library( @@ -149,7 +152,7 @@ py_library( py_test( name = "dnn_tree_combined_estimator_test", - size = "small", + size = "medium", srcs = ["dnn_tree_combined_estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 23ba76210b3b68d0d0b2eef9d4040882654bdad9..d9b0d89a03dce40d34f76bb1262d26bb587a2dc7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -54,7 +54,7 @@ def make_custom_export_strategy(name, An `ExportStrategy`. """ base_strategy = saved_model_export_utils.make_export_strategy( - serving_input_fn=export_input_fn) + serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index cec3892b57655dc967b4e7926f7f5a6a30084487..2e7b8cba05b89feaac3f47e13d26e7ae37a7b0ae 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -25,15 +25,20 @@ from __future__ import division from __future__ import print_function import six - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.layers.python.layers import optimizers +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn @@ -46,6 +51,52 @@ from tensorflow.python.training import training_util _DNN_LEARNING_RATE = 0.001 +_CORE_MODE_TO_CONTRIB_MODE_ = { + model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN, + model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER +} + + +def _core_mode_to_contrib_mode(mode): + return _CORE_MODE_TO_CONTRIB_MODE_[mode] + + +def _export_outputs_to_output_alternatives(export_outputs): + """Converts EstimatorSpec.export_outputs to output_alternatives. + + Args: + export_outputs: export_outputs created by create_estimator_spec. + Returns: + converted output_alternatives. + """ + output = dict() + if export_outputs is not None: + for key, value in export_outputs.items(): + if isinstance(value, export_output.ClassificationOutput): + exported_predictions = { + prediction_key.PredictionKey.SCORES: value.scores, + prediction_key.PredictionKey.CLASSES: value.classes + } + output[key] = (constants.ProblemType.CLASSIFICATION, + exported_predictions) + return output + return None + + +def _estimator_spec_to_model_fn_ops(estimator_spec, is_regression): + alternatives = [] + if not is_regression: + _export_outputs_to_output_alternatives(estimator_spec.export_outputs) + + return model_fn.ModelFnOps( + mode=_core_mode_to_contrib_mode(estimator_spec.mode), + predictions=estimator_spec.predictions, + loss=estimator_spec.loss, + train_op=estimator_spec.train_op, + eval_metric_ops=estimator_spec.eval_metric_ops, + output_alternatives=alternatives) + def _get_optimizer(optimizer): if callable(optimizer): @@ -59,16 +110,26 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, - dnn_feature_columns, tree_learner_config, num_trees, - tree_examples_per_layer, - config=None, dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, - tree_feature_columns=None, - tree_center_bias=True): +def _dnn_tree_combined_model_fn(features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=False, + use_core_versions=False, + is_regression=False): """DNN and GBDT combined model_fn. Args: @@ -106,6 +167,9 @@ def _dnn_tree_combined_model_fn( set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + is_regression: Whether the problem is regression or not. Returns: A `ModelFnOps` object. @@ -135,11 +199,17 @@ def _dnn_tree_combined_model_fn( "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=dnn_partitioner) as input_layer_scope: - input_layer = layers.input_from_feature_columns( - columns_to_tensors=features, - feature_columns=dnn_feature_columns, - weight_collections=[dnn_parent_scope], - scope=input_layer_scope) + if use_core_versions: + input_layer = feature_column_lib.input_layer( + features=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope]) + else: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) previous_layer = input_layer for layer_id, num_hidden_units in enumerate(dnn_hidden_units): with variable_scope.variable_scope( @@ -222,24 +292,51 @@ def _dnn_tree_combined_model_fn( del loss return control_flow_ops.no_op() - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = _estimator_spec_to_model_fn_ops(dnn_train_op, + is_regression).train_op + + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = _estimator_spec_to_model_fn_ops(tree_train_op, + is_regression).train_op + + model_fn_ops = _estimator_spec_to_model_fn_ops(model_fn_ops, is_regression) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op if tree_center_bias: num_trees += 1 @@ -277,7 +374,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -322,6 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -336,8 +436,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, @@ -366,7 +466,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -411,6 +512,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.regression_head( label_name=label_name, @@ -426,11 +529,26 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias) + features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config, + dnn_optimizer, + dnn_activation_fn, + dnn_dropout, + dnn_input_layer_partitioner, + dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, + tree_center_bias, + use_core_versions, + is_regression=True) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, @@ -460,7 +578,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -500,6 +619,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( @@ -507,8 +628,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 83d58c561008e8a5a69eb503d1605bb9e940f281..f495edc62f0909880c170ccb4cf5d11e3f20f55c 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -19,15 +19,17 @@ from __future__ import division from __future__ import print_function import tempfile - from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest @@ -100,6 +102,35 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateDontThrowExceptionWithCore(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + # Use core head + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + classifier = estimator.DNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + # Use core feature columns + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[], + use_core_versions=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 01752416b347dd0a5e646283b6b5572592df4690..70454aa6dbdb19297028a3f80822719bef5a0f72 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -81,7 +81,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): n_classes=n_classes, weight_column_name=weight_column_name, enable_centered_bias=False, - loss_fn=loss_fn) + loss_fn=loss_fn, + label_keys=label_keys) if learner_config.num_classes == 0: learner_config.num_classes = n_classes elif learner_config.num_classes != n_classes: diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 754b7bc3270d647fc381033b769eadd7b791771e..3bf33186ec13f5ff991db938d59849c0124a30a0 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -137,6 +137,61 @@ class TreeEnsembleDeserializeOp : public OpKernel { } }; +class TreeEnsembleUsedHandlersOp : public OpKernel { + public: + explicit TreeEnsembleUsedHandlersOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("num_all_handlers", &num_handlers_)); + } + + void Compute(OpKernelContext* context) override { + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); + + // Get the stamp token. + const Tensor* stamp_token_t; + OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); + int64 stamp_token = stamp_token_t->scalar()(); + + // Only the Chief should run this Op and it is guaranteed to be in + // a consistent state so the stamps must always match. + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); + + Tensor* output_used_handlers_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output("used_handlers_mask", {num_handlers_}, + &output_used_handlers_t)); + auto output_used_handlers = output_used_handlers_t->vec(); + + Tensor* output_num_used_handlers_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("num_used_handlers", {}, + &output_num_used_handlers_t)); + int handler_idx = 0; + std::vector used_handlers = ensemble_resource->GetUsedHandlers(); + output_num_used_handlers_t->scalar()() = used_handlers.size(); + for (int64 i = 0; i < num_handlers_; ++i) { + if (handler_idx >= used_handlers.size() || + used_handlers[handler_idx] > i) { + output_used_handlers(i) = false; + } else { + OP_REQUIRES(context, used_handlers[handler_idx] == i, + errors::InvalidArgument("Handler IDs should be sorted.")); + ++handler_idx; + output_used_handlers(i) = true; + } + } + } + + private: + int64 num_handlers_; +}; + REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource); REGISTER_KERNEL_BUILDER( @@ -155,5 +210,7 @@ REGISTER_KERNEL_BUILDER(Name("TreeEnsembleSerialize").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("TreeEnsembleDeserialize").Device(DEVICE_CPU), TreeEnsembleDeserializeOp); +REGISTER_KERNEL_BUILDER(Name("TreeEnsembleUsedHandlers").Device(DEVICE_CPU), + TreeEnsembleUsedHandlersOp); } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0f4c2298f56be48bb32f52d5d44cff8afe284f1e..0b28f81e7ca9a1228adc5bde19c429265e0aa9b8 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -253,7 +253,7 @@ class CreateQuantileAccumulatorOp : public OpKernel { private: float epsilon_; int32 num_quantiles_; - // An upperbound on the number of enteries that the summaries might have + // An upper bound on the number of entries that the summaries might have // for a feature. int64 max_elements_; bool generate_quantiles_; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 7f8dea1d3c2a04b725843f6e2932a0cdfbc7733c..1bfeed306641111718984b2097512e5ec3fa8630 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -361,27 +361,10 @@ class GrowTreeEnsembleOp : public OpKernel { // Increment attempt stats. ensemble_resource->IncrementAttempts(); - // In case we want to do feature selection and we have reached the limit, - // build a list of handlers used so far to avoid adding new features. - std::vector allowed_handlers; - if (learner_config_.constraints().max_number_of_unique_feature_columns() > - 0) { - allowed_handlers = ensemble_resource->GetUsedHandlers(); - // TODO(soroush): We can disable handlers that are not going to be used to - // avoid unnecessary computations. - if (allowed_handlers.size() < - learner_config_.constraints() - .max_number_of_unique_feature_columns()) { - // We have not reached the limit yet. Empty the list of allow features - // which means we can keep adding new features. - allowed_handlers.clear(); - } - } - // Find best splits for each active partition. std::map best_splits; - FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list, - gains_list, splits_list, &best_splits); + FindBestSplitsPerPartition(context, partition_ids_list, gains_list, + splits_list, &best_splits); // No-op if no new splits can be considered. if (best_splits.empty()) { @@ -422,19 +405,12 @@ class GrowTreeEnsembleOp : public OpKernel { // and finds the best split for each partition. void FindBestSplitsPerPartition( OpKernelContext* const context, - const std::vector& allowed_handlers, // Empty means all handlers. const OpInputList& partition_ids_list, const OpInputList& gains_list, const OpInputList& splits_list, std::map* best_splits) { // Find best split per partition going through every feature candidate. // TODO(salehay): Is this worth parallelizing? for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { - if (!allowed_handlers.empty()) { - if (!std::binary_search(allowed_handlers.begin(), - allowed_handlers.end(), handler_id)) { - continue; - } - } const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); const auto& splits = splits_list[handler_id].vec(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index cd925f6b65e569538212e9c26aef0abc8482960b..794ba2bcb0aafa26c5e1c90fcd66caf9dd5bf7d5 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -137,7 +137,7 @@ struct NodeStats { Eigen::MatrixXf hessian = TensorToEigenMatrix(grad_stats.second.t, grad_dim, grad_dim); // I is an identity matrix. - // The gain in general form is -g^T (H+l2 I)^-1 g. + // The gain in general form is g^T (H+l2 I)^-1 g. // The node weights are -(H+l2 I)^-1 g. Eigen::MatrixXf identity; identity.setIdentity(grad_dim, grad_dim); @@ -240,7 +240,7 @@ struct NodeStats { // given regularized Hessian and gradient vector g. void CalculateWeightAndGain(const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g) { - // The gain in general form is -g^T (Hessian_and_regularization)^-1 g. + // The gain in general form is g^T (Hessian_and_regularization)^-1 g. // The node weights are -(Hessian_and_regularization)^-1 g. Eigen::VectorXf weight; // If we want to calculate x = K^-1 v, instead of explicitly calculating diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index cf4f9a097a3368465fd4d9afb981bbaa68b4df49..35b059f3496dbc8fb2b3d4fe6ec6b55a9d73dd0c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -54,7 +54,7 @@ Status BatchFeatures::Initialize( TF_CHECK_AND_RETURN_IF_ERROR( dense_float_feature.dim_size(1) == 1, errors::InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = ", + "Dense float features may not be multivalent: dim_size(1) = ", dense_float_feature.dim_size(1))); dense_float_feature_columns_.emplace_back(dense_float_feature); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 609519e8b1153a27d987c5f9ca9bfcc9ee6717d6..cfe9101e7435cd798569f3e52a87fc8ed7b6a239 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -59,7 +59,7 @@ TEST_F(BatchFeaturesTest, DenseFloatFeatures_Multivalent) { BatchFeatures batch_features(1); auto dense_vec = AsTensor({3.0f, 7.0f}, {1, 2}); auto expected_error = InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = 2"); + "Dense float features may not be multivalent: dim_size(1) = 2"); EXPECT_EQ(expected_error, batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {})); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index db34db998a7442c69f2ab468f4557d991429f4ee..ce67db797ded54f5023eaa89369d4781aad31a7c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -54,7 +54,7 @@ Status DropoutUtils::DropOutTrees( if (probability_of_skipping_dropout < 0 || probability_of_skipping_dropout > 1) { return errors::InvalidArgument( - "Probability of skiping dropout must be in [0,1] range"); + "Probability of skipping dropout must be in [0,1] range"); } const auto num_trees = weights.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h index 928bfbfe5c9394ab4083aabced4c8e1149bb10aa..77c16da5410fe65b20839c7b6bc677067d7ff297 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h @@ -66,7 +66,7 @@ class DropoutUtils { // Current weights and num_updates will be updated as a result of this // func std::vector* current_weights, - // How many weight assignements have been done for each tree already. + // How many weight assignments have been done for each tree already. std::vector* num_updates); }; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc index 0138aae3dbd3773241cb6644db625b99f9bf1372..cc7604745e6bb90837eeca1123faa88dc914e4fc 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc @@ -34,7 +34,7 @@ TEST_F(SparseColumnIterableTest, Empty) { } TEST_F(SparseColumnIterableTest, Iterate) { - // 8 examples having 7 sparse features with the 3rd and 7th multi-valent. + // 8 examples having 7 sparse features with the 3rd and 7th multivalent. // This can be visualized like the following: // Instance | Sparse | // 0 | x | diff --git a/tensorflow/contrib/boosted_trees/ops/model_ops.cc b/tensorflow/contrib/boosted_trees/ops/model_ops.cc index 0786c4166410720e8d4d70960e5747ff111076d8..9d6343c7e80f369bf6a5465821c5f4bacb984cd0 100644 --- a/tensorflow/contrib/boosted_trees/ops/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/model_ops.cc @@ -110,5 +110,32 @@ stamp_token: Token to use as the new value of the resource stamp. tree_ensemble_config: Serialized proto of the ensemble. )doc"); +REGISTER_OP("TreeEnsembleUsedHandlers") + .Attr("num_all_handlers: int >= 0") + .Input("tree_ensemble_handle: resource") + .Input("stamp_token: int64") + .Output("num_used_handlers: int64") + .Output("used_handlers_mask: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + c->set_output(0, c->Scalar()); + int num_all_handlers; + c->GetAttr("num_all_handlers", &num_all_handlers).IgnoreError(); + c->set_output(1, {c->Vector(num_all_handlers)}); + + return Status::OK(); + }) + .Doc(R"doc( +Returns the mask of used handlers along with the number of non-zero elements in +this mask. Used in feature selection. + +tree_ensemble_handle: Handle to the tree ensemble. +stamp_token: Token to use as the new value of the resource stamp. +num_used_handlers: number of feature column handlers used in the model. +used_handlers_mask: A boolean vector of showing which handlers are used in the + model. +)doc"); + } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index ae99d53a2cf805d70d60746cd44f73f7fd9dc6e2..6aa52463987b55a54b7308765920cbe94c15b8d1 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -272,6 +272,20 @@ REGISTER_OP("Quantiles") .Input("sparse_indices: num_sparse_features * int64") .Output("dense_quantiles: num_dense_features * int32") .Output("sparse_quantiles: num_sparse_features * int32") + .SetShapeFn([](InferenceContext* c) { + int num_dense_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_dense_features", &num_dense_features)); + int num_sparse_features; + TF_RETURN_IF_ERROR( + c->GetAttr("num_sparse_features", &num_sparse_features)); + // Set output shapes (dense_quantiles and sparse_quantiles) by the + // relevant inputs (dense_values and sparse_values). Note that the output + // has an additional dimension for dimension_ids. + for (int i = 0; i < num_dense_features + num_sparse_features; ++i) { + c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 2})); + } + return Status::OK(); + }) .Doc(R"doc( Computes quantile for each a given list of dense and sparse feature values using the given buckets. diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 4407c4d981785a279b6296f4726a221cacb4c5b1..81411aa84ae848cfaa1392e82a1e38c3df19cdb6 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,7 +53,7 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the dimensiong + // If feature column is multivalent, this holds the index of the dimension // for the split. Defaults to 0. int32 dimension_id = 5; float threshold = 2; diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 27c288bbf78b3b593d0807e92ac7fd9afc4d2725..63b9c5fddf0d9967d53077608664b59d9ae00481 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -310,6 +310,22 @@ class ModelOpsTest(test_util.TensorFlowTestCase): # The third tree was added after the save. self.assertAllClose(result.eval(), [[-1.1], [-1.1]]) + def testUsedHandlers(self): + with self.test_session(): + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree_ensemble_config.growing_metadata.used_handler_ids.append(1) + tree_ensemble_config.growing_metadata.used_handler_ids.append(5) + stamp_token = 3 + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=stamp_token, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="create_tree") + resources.initialize_resources(resources.shared_resources()).run() + result = model_ops.tree_ensemble_used_handlers( + tree_ensemble_handle, stamp_token, num_all_handlers=6) + self.assertAllEqual([0, 1, 0, 0, 0, 1], result.used_handlers_mask.eval()) + self.assertEqual(2, result.num_used_handlers.eval()) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf..cf55759aaabfb265466f4bbf8b2806d4347ca0b1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -120,8 +120,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): """Sets up the prediction tests. Create a batch of two examples having one dense float, two sparse float - single valued, one sparse float multidimensionl and one sparse int features. - The data looks like the following: + single valued, one sparse float multidimensional and one sparse int + features. The data looks like the following: | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM | 0 | 7 | -3 | | 9,1 | __, 5.0 | 1 | -2 | | 4 | | 3, ___ @@ -810,7 +810,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # building. This tree should never be dropped. num_trees = 10 with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 10 trees with some weights. for i in range(0, num_trees): @@ -951,7 +951,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testDropOutZeroProb(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 1000 trees with some weights. for i in range(0, 999): @@ -994,7 +994,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testAveragingAllTrees(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() adjusted_tree_ensemble_config = ( tree_config_pb2.DecisionTreeEnsembleConfig()) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 81f58de28cbe98bb996c6665114eeb0030ee52f9..074623699d9d82f999c9cbc483ddcd8a959f4bad 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -482,7 +482,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): """Sets up the quantile op tests. Create a batch of 4 examples having 2 dense and 4 sparse features. - Forth sparse feature is multivalent (3 dimensional) + Fourth sparse feature is multivalent (3 dimensional) The data looks like this | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 8ca1aabacaf53b66aaba184962922294427d6803..3e524efbeac74ff754d63cae92b3e194411cb2de 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -1588,7 +1588,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual( 2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates) - def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self): + def testGrowExistingEnsembleTreeWithFeatureSelectionUsedHandlers(self): """Test growing a tree with feature selection.""" with self.test_session() as session: # Create existing ensemble with one root split and one bias tree. @@ -1649,7 +1649,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): num_trees_attempted: 2 num_layers_attempted: 2 used_handler_ids: 2 - used_handler_ids: 5 } """, tree_ensemble_config) tree_ensemble_handle = model_ops.tree_ensemble_variable( @@ -1668,183 +1667,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - # There are 2 handler_ids in used_handler_ids already but one of them - # is handler 2, so we can still grow trees. - learner_config.constraints.max_number_of_unique_feature_columns = 2 - learner_config = learner_config.SerializeToString() - # Prepare handler inputs. - handler1_partitions = np.array([0], dtype=np.int32) - handler1_gains = np.array([7.62], dtype=np.float32) - handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] - handler2_partitions = np.array([0], dtype=np.int32) - handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] - handler3_partitions = np.array([0], dtype=np.int32) - handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] - - # Grow tree ensemble. - grow_op = training_ops.grow_tree_ensemble( - tree_ensemble_handle, - stamp_token=0, - next_stamp_token=1, - learning_rate=1, - partition_ids=[ - handler1_partitions, handler2_partitions, handler3_partitions - ], - gains=[handler1_gains, handler2_gains, handler3_gains], - splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, - dropout_seed=123, - center_bias=True) - session.run(grow_op) - - # Expect a new tree to be added with the split from handler 1. - _, serialized = session.run( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)) - tree_ensemble_config.ParseFromString(serialized) - self.assertEqual(3, len(tree_ensemble_config.trees)) - self.assertEqual( - 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) - - def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self): - """Test growing a tree with feature selection with empty ensemble.""" - with self.test_session() as session: - # Create existing ensemble with one root split and one bias tree. - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble_config.SerializeToString(), - name="tree_ensemble") - resources.initialize_resources(resources.shared_resources()).run() - - # Prepare learner config. - learner_config = _gen_learner_config( - num_classes=2, - l1_reg=0, - l2_reg=0, - tree_complexity=0, - max_depth=1, - min_node_weight=0, - pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - learner_config.constraints.max_number_of_unique_feature_columns = 2 - learner_config = learner_config.SerializeToString() - # Prepare handler inputs. - handler1_partitions = np.array([0], dtype=np.int32) - handler1_gains = np.array([7.62], dtype=np.float32) - handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] - handler2_partitions = np.array([0], dtype=np.int32) - handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] - handler3_partitions = np.array([0], dtype=np.int32) - handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] - - # Grow tree ensemble. - grow_op = training_ops.grow_tree_ensemble( - tree_ensemble_handle, - stamp_token=0, - next_stamp_token=1, - learning_rate=1, - partition_ids=[ - handler1_partitions, handler2_partitions, handler3_partitions - ], - gains=[handler1_gains, handler2_gains, handler3_gains], - splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, - dropout_seed=123, - center_bias=True) - session.run(grow_op) - - _, serialized = session.run( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)) - tree_ensemble_config.ParseFromString(serialized) - self.assertEqual(1, len(tree_ensemble_config.trees)) - self.assertEqual( - 1, len(tree_ensemble_config.growing_metadata.used_handler_ids)) - - def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self): - """Test growing a tree with feature selection with empty ensemble.""" - with self.test_session() as session: - # Create existing ensemble with one root split and one bias tree. - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" - trees { - nodes { - leaf { - vector { - value: -0.32 - value: 0.28 - } - } - } - } - trees { - nodes { - categorical_id_binary_split { - feature_column: 3 - feature_id: 7 - left_id: 1 - right_id: 2 - } - node_metadata { - gain: 1.3 - } - } - nodes { - leaf { - sparse_vector { - index: 0 - value: 2.3 - } - } - } - nodes { - leaf { - sparse_vector { - index: 0 - value: -0.9 - } - } - } - } - tree_weights: 0.7 - tree_weights: 1 - tree_metadata { - num_tree_weight_updates: 1 - num_layers_grown: 1 - is_finalized: true - } - tree_metadata { - num_tree_weight_updates: 5 - num_layers_grown: 1 - is_finalized: true - } - growing_metadata { - num_trees_attempted: 2 - num_layers_attempted: 2 - used_handler_ids: 4 - used_handler_ids: 5 - } - """, tree_ensemble_config) - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble_config.SerializeToString(), - name="tree_ensemble") - resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = _gen_learner_config( - num_classes=2, - l1_reg=0, - l2_reg=0, - tree_complexity=0, - max_depth=1, - min_node_weight=0, - pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config.constraints.max_number_of_unique_feature_columns = 3 learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1876,12 +1700,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): _, serialized = session.run( model_ops.tree_ensemble_serialize(tree_ensemble_handle)) tree_ensemble_config.ParseFromString(serialized) - # We can't grow a tree since we have reached the limit of 2 unique - # features [4, 5] and the only available splits are from - # handlers [0, 1, 2]. - self.assertEqual(2, len(tree_ensemble_config.trees)) - self.assertEqual( - 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + self.assertEqual(3, len(tree_ensemble_config.trees)) + # 2 was already used. handler 0 is being added in this tree. + self.assertAllEqual( + [0, 2], tree_ensemble_config.growing_metadata.used_handler_ids) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 7a5f509047d46549ba81039a23d29ec987ca7920..25b2c9e2fd72bd018717e8a87fce726f26bad968 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_serialize # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_stamp_token +from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_used_handlers # pylint: enable=unused-import from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 97d57e8b23608d4c3a8719426a75056fc6417d1d..1b184d296b329cee481db67992e77d1e33e18035 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -184,7 +184,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Finalizes quantile summary stream and resets it for next iteration. Args: - stamp_token: Exepcted current token. + stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: A list of quantiles or approximate boundaries. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index f0b66dcbbe1c5167b9993e66b30b1dc8a839c380..85b909e4f2556c520a5bffe46d5954683d9dda5a 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -57,6 +57,8 @@ PREDICTIONS = "predictions" PARTITION_IDS = "partition_ids" NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" +NUM_USED_HANDLERS = "num_used_handlers" +USED_HANDLERS_MASK = "used_handlers_mask" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -70,7 +72,8 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): +def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, + used_handlers): """Returns predictions for the given logits and n_classes. Args: @@ -79,6 +82,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): that contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. + used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a + boolean mask.. Returns: A dict of predictions. @@ -89,6 +94,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): result[PARTITION_IDS] = partition_ids result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees + result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers + result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask return result @@ -361,6 +368,13 @@ class GradientBoostedDecisionTreeModel(object): """ ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) + num_handlers = ( + len(self._dense_floats) + len(self._sparse_float_shapes) + + len(self._sparse_int_shapes)) + # Used during feature selection. + used_handlers = model_ops.tree_ensemble_used_handlers( + ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) + # We don't need dropout info - we can always restore it based on the # seed. apply_dropout, seed = _dropout_params(mode, ensemble_stats) @@ -395,7 +409,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats) + ensemble_stats, used_handlers) def predict(self, mode): """Returns predictions given the features and mode. @@ -710,12 +724,28 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack(active_handlers_current_layer, - array_ops.ones( - [len(handlers)], dtype=dtypes.bool)) + active_handlers = array_ops.stack([ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) + if self._learner_config.constraints.max_number_of_unique_feature_columns: + target = ( + self._learner_config.constraints.max_number_of_unique_feature_columns) + + def _feature_selection_active_handlers(): + # The active list for current and the next iteration. + used_handlers = array_ops.reshape(predictions_dict[USED_HANDLERS_MASK], + [-1, 1]) + used_handlers = array_ops.concat([used_handlers, used_handlers], axis=1) + return math_ops.logical_and(used_handlers, active_handlers) + + active_handlers = ( + control_flow_ops.cond(predictions_dict[NUM_USED_HANDLERS] >= target, + _feature_selection_active_handlers, + lambda: active_handlers)) + # Prepare empty gradients and hessians when handlers are not ready. empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index dba51d4f527792d2a8dedc693f74c07119fd231d..6411f57a5419123e799af9231a04fce8ae7724d4 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -47,6 +47,38 @@ def _squared_loss(label, unused_weights, predictions): return loss +def _append_to_leaf(leaf, c_id, w): + """Helper method for building tree leaves. + + Appends weight contributions for the given class index to a leaf node. + + Args: + leaf: leaf node to append to. + c_id: class Id for the weight update. + w: weight contribution value. + """ + leaf.sparse_vector.index.append(c_id) + leaf.sparse_vector.value.append(w) + + +def _set_float_split(split, feat_col, thresh, l_id, r_id): + """Helper method for building tree float splits. + + Sets split feature column, threshold and children. + + Args: + split: split node to update. + feat_col: feature column for the split. + thresh: threshold to split on forming rule x <= thresh. + l_id: left child Id. + r_id: right child Id. + """ + split.feature_column = feat_col + split.threshold = thresh + split.left_id = l_id + split.right_id = r_id + + class GbdtTest(test_util.TensorFlowTestCase): def setUp(self): @@ -917,6 +949,350 @@ class GbdtTest(test_util.TensorFlowTestCase): output.trees[0].nodes[2].leaf.sparse_vector.value[0], atol=1e-4, rtol=1e-4) + def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + # Feature 1 is predictive but it won't be used because we have reached the + # limit of num_used_handlers >= max_number_of_unique_feature_columns + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([True, False], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + # On second run, expect a trivial split to be chosen to basically + # predict the average. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + dense_float_binary_split { + feature_column: 0 + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: -0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + + def testTrainFnChiefFeatureSelectionWithGoodSplits(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + # Feature 1 is predictive and is in our selected features so it will be + # used even when we're at the limit. + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([False, True], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + dense_float_binary_split { + feature_column: 1 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0.5 + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + + def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree = tree_ensemble_config.trees.add() + + _set_float_split(tree.nodes.add() + .sparse_float_binary_split_default_right.split, 2, 4.0, + 1, 2) + _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) + _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) + tree_ensemble_config.tree_weights.append(1.0) + metadata = tree_ensemble_config.tree_metadata.add() + metadata.is_finalized = False + metadata.num_layers_grown = 1 + tree_ensemble_config = tree_ensemble_config.SerializeToString() + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config=tree_ensemble_config, + name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + # Both features will be disabled since the feature selection limit is + # already reached. + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + # We have somehow reached our limit 1. Both of the handlers will be + # disabled. + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([False, False], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertEquals(output.growing_metadata.num_layers_attempted, 1) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + # Make sure the trees are not modified, but the num_layers_attempted is + # incremented so that eventually the training stops. + self.assertEquals(len(output.trees), 1) + self.assertEquals(len(output.trees[0].nodes), 3) + + self.assertEquals(output.growing_metadata.num_layers_attempted, 2) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 6b03df2b8eb636ad888d050a3b2b29eabc3f8934..1a124eca364424b651de86bfaac6f33ad131804b 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -110,5 +110,6 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training", ], + grpc_enabled = True, main = "python/training/tpu_cluster_resolver_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index b04822fa9d66465e34a545d3b00c399bbb196514..1c480b25134b1e54200e0ddb780bd7bb0f122341 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -53,11 +53,16 @@ class ClusterResolver(object): raise NotImplementedError( 'cluster_spec is not implemented for {}.'.format(self)) + @abc.abstractmethod + def master(self): + """...""" + raise NotImplementedError('master is not implemented for {}.'.format(self)) + class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec): + def __init__(self, cluster_spec, master=''): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() @@ -65,10 +70,18 @@ class SimpleClusterResolver(ClusterResolver): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec + if not isinstance(master, str): + raise TypeError('master must be a string.') + self._master = master + def cluster_spec(self): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec + def master(self): + """Returns the master address to use when creating a session.""" + return self._master + class UnionClusterResolver(ClusterResolver): """Performs a union on underlying ClusterResolvers. @@ -87,9 +100,13 @@ class UnionClusterResolver(ClusterResolver): Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. + ValueError: If there are no arguments passed. """ super(UnionClusterResolver, self).__init__() + if not args: + raise ValueError('At least one ClusterResolver is required.') + for cluster_resolver in args: if not isinstance(cluster_resolver, ClusterResolver): raise TypeError('All arguments must be a sub-class of ' @@ -169,3 +186,7 @@ class UnionClusterResolver(ClusterResolver): merged_cluster[job_name].update(task_dict) return ClusterSpec(merged_cluster) + + def master(self): + """master returns the master address from the first cluster resolver.""" + return self._cluster_resolvers[0].master() diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index dbfb77723cdaab66e29bb41b764593bb5fd61b35..d9c97d53eb3663f6ab2f7b40395592dc7638b896 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -234,5 +234,7 @@ class UnionClusterResolverTest(test.TestCase): self._verifyClusterSpecEquality(cluster_spec, expected_proto) +# TODO(saeta): Include tests for master resolution + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index d6f2eced93ba4fda5ac27f9412b6f729981f4f40..3f5824128948453634bc5e5a7d6fdeedae60f5bd 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -134,3 +134,6 @@ class GceClusterResolver(ClusterResolver): worker_list.sort() return ClusterSpec({self._job_name: worker_list}) + + def master(self): + return '' diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index a6a6e642e4e4c721b94821a70d55d6fe931347d6..300b19733e2b4d1b912f966e94ae0286ed9c694d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -18,12 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from six.moves.urllib.request import Request from six.moves.urllib.request import urlopen from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat _GOOGLE_API_CLIENT_INSTALLED = True try: @@ -33,6 +35,9 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' + + class TPUClusterResolver(ClusterResolver): """Cluster Resolver for Google Cloud TPUs. @@ -46,13 +51,30 @@ class TPUClusterResolver(ClusterResolver): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) resp = urlopen(req) - return resp.read() + return compat.as_bytes(resp.read()) + + def _shouldResolve(self): + if (self._tpu == compat.as_bytes('') or + self._tpu == compat.as_bytes('local') or + self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('grpc://'))): + return False + return True + + def _inGke(self): + """When running in GKE, the environment variable will be set.""" + return _GKE_ENV_VARIABLE in os.environ + + def _gkeMaster(self): + return os.environ[_GKE_ENV_VARIABLE].split(',')[0] def __init__(self, - tpu_names, + tpu=None, zone=None, project=None, - job_name='tpu_worker', + job_name='worker', + coordinator_name='coordinator', + coordinator_address=None, credentials='default', service=None): """Creates a new TPUClusterResolver object. @@ -61,7 +83,11 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - tpu_names: A list of names of the target Cloud TPUs. + tpu: Either a string, or a list of strings corresponding to the TPUs to + use. If the single string is the empty string, the string 'local', or a + string that begins with 'grpc://' or '/bns', then it is assumed to not + correspond with a Cloud TPU and will instead be passed as the session + master and no ClusterSpec propagation will be done. zone: Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service. @@ -69,6 +95,12 @@ class TPUClusterResolver(ClusterResolver): empty, we will try to discover the project name of the GCE VM from the GCE metadata service. job_name: Name of the TensorFlow job the TPUs belong to. + coordinator_name: The name to use for the coordinator. Set to None if the + coordinator should not be included in the computed ClusterSpec. + coordinator_address: The address of the coordinator (typically an ip:port + pair). If set to None, a TF server will be started. If coordinator_name + is None, a TF server will not be started even if coordinator_address is + None. credentials: GCE Credentials. If None, then we use default credentials from the oauth2client service: The GCE API object returned by the googleapiclient.discovery @@ -77,29 +109,47 @@ class TPUClusterResolver(ClusterResolver): Raises: ImportError: If the googleapiclient is not installed. + ValueError: If no TPUs are specified. """ + if isinstance(tpu, list): + if not tpu: + raise ValueError('At least one TPU must be specified.') + if len(tpu) != 1: + raise NotImplementedError( + 'Using multiple TPUs in a single session is not yet implemented') + tpu = tpu[0] + + # When using GKE with Cloud TPUs, the env variable will be set. + if tpu is None and self._inGke(): + tpu = self._gkeMaster() + + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes + self._job_name = job_name + self._credentials = credentials + + should_resolve = self._shouldResolve() - if not project: - project = self._requestComputeMetadata('/project/project-id') + if not project and should_resolve: + project = compat.as_str( + self._requestComputeMetadata('project/project-id')) - if not zone: - zone_path = self._requestComputeMetadata('/instance/zone') + if not zone and should_resolve: + zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] self._project = project self._zone = zone - self._tpu_names = tpu_names - self._job_name = job_name - self._credentials = credentials - if credentials == 'default': + if credentials == 'default' and should_resolve: if _GOOGLE_API_CLIENT_INSTALLED: self._credentials = GoogleCredentials.get_application_default() - if service is None: + if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver') + 'TPU cluster resolver. Execute: `pip install ' + '--upgrade google-api-python-client` to install with ' + 'pip.') self._service = discovery.build( 'tpu', 'v1alpha1', @@ -107,25 +157,41 @@ class TPUClusterResolver(ClusterResolver): else: self._service = service - def get_master(self): - """Get the ClusterSpec grpc master path. + self._coordinator_name = coordinator_name + if coordinator_name and not coordinator_address and should_resolve: + self._start_local_server() + else: + self._coordinator_address = coordinator_address - This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the - ClusterSpec returned by the cluster_spec function. This is suitable for use - for the `master` argument in tf.Session() when you are using one TPU. + def master(self): + """Get the Master string to be used for the session. + + In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of + first instance in the ClusterSpec returned by the cluster_spec function. + + If a non-TPU name is used when constructing a TPUClusterResolver, that will + be returned instead (e.g. If the tpus argument's value when constructing + this TPUClusterResolver was 'grpc://10.240.1.2:8470', + 'grpc://10.240.1.2:8470' will be returned). Returns: - string, the grpc path of the first instance in the ClusterSpec. + string, the connection string to use when creating a session. Raises: ValueError: If none of the TPUs specified exists. """ + if not self._shouldResolve(): + return self._tpu + job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: raise ValueError('No TPUs exists with the specified names exist.') return 'grpc://' + job_tasks[0] + def get_master(self): + return self.master() + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. @@ -134,17 +200,54 @@ class TPUClusterResolver(ClusterResolver): Returns: A ClusterSpec containing host information returned from Cloud TPUs. - """ - worker_list = [] - for tpu_name in self._tpu_names: - full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, tpu_name) - request = self._service.projects().locations().nodes().get(name=full_name) - response = request.execute() - - if 'health' in response and response['health'] == 'HEALTHY': - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) - - return ClusterSpec({self._job_name: worker_list}) + Raises: + RuntimeError: If the provided TPU is not healthy. + """ + if not self._shouldResolve(): + return server_lib.ClusterSpec({}) + + full_name = 'projects/%s/locations/%s/nodes/%s' % ( + self._project, self._zone, compat.as_text(self._tpu)) + request = self._service.projects().locations().nodes().get(name=full_name) + response = request.execute() + + if 'health' in response and response['health'] != 'HEALTHY': + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, + response['health'])) + + if 'networkEndpoints' in response: + worker_list = [ + '%s:%s' % (endpoint['ipAddress'], endpoint['port']) + for endpoint in response['networkEndpoints'] + ] + else: + # Fall back to the deprecated response format + instance_url = '%s:%s' % (response['ipAddress'], response['port']) + worker_list = [instance_url] + + cluster_spec = {self._job_name: worker_list} + + if self._coordinator_address: + cluster_spec[self._coordinator_name] = [self._coordinator_address] + + return server_lib.ClusterSpec(cluster_spec) + + def _start_local_server(self): + address = self._requestComputeMetadata('instance/network-interfaces/0/ip') + self._server = server_lib.Server( + { + 'local': ['0.0.0.0:0'] + }, protocol='grpc', config=None, start=True) + # self._server.target is of the form: grpc://ipaddress:port + target = compat.as_bytes(self._server.target) + splits = target.split(compat.as_bytes(':')) + assert len(splits) == 3, self._server.target + assert splits[0] == compat.as_bytes('grpc'), self._server.target + self._coordinator_port = compat.as_text(splits[2]) + self._coordinator_address = '%s:%s' % ( + address, compat.as_text(self._coordinator_port)) + + def __deepcopy__(self, memo): + # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. + return self diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 4fd34629cf74f90869c77b8cb098d3c585a49404..48c3f6bb4f2d1643982e03d9ed68db14c10c184a 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib - +from tensorflow.python.util import compat mock = test.mock @@ -50,10 +52,12 @@ class MockNodeClass(object): def mock_request_compute_metadata(cls, *args, **kwargs): del cls, kwargs # Unused. - if args[0] == '/project/project-id': + if args[0] == 'project/project-id': return 'test-project' - elif args[0] == '/instance/zone': + elif args[0] == 'instance/zone': return 'projects/test-project/locations/us-central1-c' + elif args[0] == 'instance/network-interfaces/0/ip': + return '10.128.1.2' return '' @@ -71,18 +75,17 @@ class TPUClusterResolverTest(test.TestCase): expected_proto: Expected protobuf """ self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) - self.assertProtoEquals( - expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_cluster_def()).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_dict()).as_cluster_def()) - def mock_service_client( - self, - tpu_map=None): + def mock_service_client(self, tpu_map=None): if tpu_map is None: tpu_map = {} @@ -98,8 +101,7 @@ class TPUClusterResolverTest(test.TestCase): return mock_client - @mock.patch.object(TPUClusterResolver, - '_requestComputeMetadata', + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', mock_request_compute_metadata) def testRetrieveProjectAndZoneFromMetadata(self): tpu_map = { @@ -113,17 +115,26 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver( project=None, zone=None, - tpu_names=['test-tpu-1'], + tpu=['test-tpu-1'], credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + job { + name: 'coordinator' + tasks { key: 0 value: '10.128.1.2:%s' } + } + job { + name: 'worker' + tasks { key: 0 value: '10.1.2.3:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) - def testSimpleSuccessfulRetrieval(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', @@ -133,116 +144,230 @@ class TPUClusterResolverTest(test.TestCase): } tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu_names=['test-tpu-1'], + project=None, + zone=None, + tpu=['test-tpu-1'], + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testMultipleSuccessfulRetrieval(self): + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu=['test-tpu-1'], + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' } - tasks { key: 1 value: '10.1.2.3:8470' } } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testHealthyTpuNodeRetrieval(self): + def testNewNetworkEndpointFormat(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { - 'ipAddress': '10.7.8.9', - 'port': '8470', - 'health': 'UNHEALTHY' + 'health': 'HEALTHY', + 'networkEndpoints': [{ + 'ipAddress': '10.2.3.4', + 'port': 8470, + }] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + tpu='test-tpu-1', + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { - name: 'tpu_worker' - tasks { - key: 0 - value: '10.1.2.3:8470' - } - } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual('grpc://10.2.3.4:8470', tpu_cluster_resolver.master()) - def testGetMasterMultipleEntries(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testPodResolution(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] + } + } + + tpu_cluster_resolver = TPUClusterResolver( + tpu='test-tpu-1', + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'coordinator', + tasks { key: 0 value: '10.128.1.2:%s'} + } + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + + def testPodResolutionNoCoordinator(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu='test-tpu-1', + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) - self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) def testGetMasterNoEntries(self): tpu_map = {} + with self.assertRaises(ValueError): + TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu=[], + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + # TODO(saeta): Convert to parameterized test when included in OSS TF. + def verifyShouldResolve(self, tpu, should_resolve): tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=[], + tpu=tpu, + coordinator_name=None, credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - with self.assertRaises(ValueError): - tpu_cluster_resolver.get_master() + service=self.mock_service_client(tpu_map={})) + self.assertEqual(should_resolve, tpu_cluster_resolver._shouldResolve(), + "TPU: '%s'" % tpu) + + def testShouldResolveNoName(self): + self.verifyShouldResolve('', False) + + def testShouldResolveLocal(self): + self.verifyShouldResolve('local', False) + + def testShouldResolveGrpc(self): + self.verifyShouldResolve('grpc://10.1.2.3:8470', False) + + def testShouldResolveBns(self): + self.verifyShouldResolve('/bns/foo/bar', False) + + def testShouldResolveName(self): + self.verifyShouldResolve('mytpu', True) + + def testShouldResolveList(self): + self.verifyShouldResolve(['myothertpu'], True) + + def testShouldResolveGrpcPrefix(self): + self.verifyShouldResolve('grpctpu', True) + + def testNoCallComputeMetadata(self): + tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') + self.assertEqual( + compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) + self.assertEqual( + server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec()) + + def testGkeEnvironment(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' + self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + tpu_cluster_resolver = TPUClusterResolver() + self.assertTrue(tpu_cluster_resolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver._gkeMaster())) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.get_master())) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 8f85a75ee466dbac524a1266dc2522109ca77cd5..fe83bb32046cd75328c92a74cdb4fdb6ce44560e 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -26,7 +26,7 @@ The CMake files in this directory can build the core TensorFlow runtime, an example C++ binary, and a PIP package containing the runtime and Python bindings. -### Pre-requisites +### Prerequisites * CMake version 3.5 or later. @@ -34,14 +34,16 @@ bindings. * [SWIG](http://www.swig.org/download.html) -* Additional pre-requisites for Microsoft Windows: +* Additional prerequisites for Microsoft Windows: - Visual Studio 2015 - Python 3.5 - - NumPy 1.11.0 or later -* Additional pre-requisites for Linux: +* Additional prerequisites for Linux: - Python 2.7 or later - [Docker](https://www.docker.com/) (for automated testing) + +* Python dependencies: + - wheel - NumPy 1.11.0 or later ### Known-good configurations @@ -102,7 +104,7 @@ ops or APIs. Step-by-step Windows build ========================== -1. Install the pre-requisites detailed above, and set up your environment. +1. Install the prerequisites detailed above, and set up your environment. * The following commands assume that you are using the Windows Command Prompt (`cmd.exe`). You will need to set up your environment to use the diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index 836889895567f679d9960e29ece1600d1a7a58eb..98a8c7e736e5c8c407b90e8eac440cdc7ab21579 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) -set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip) +set(cub_HASH SHA256=6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index a9f43a3ecba4830533efcc13f8c4c1c61fe1ef78..cc218e8ab8ce211a85aa3ece318558dd24049c83 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) +set(GRPC_TAG 575bda39755b98d1f7099406bb57a6e3b2074874) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -35,6 +35,7 @@ else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index f3a37ff5088e3f9e54e38c0edb5777c27b26969f..b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) +set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index aba8a5244e17d717293deec6d9b6e8e725ef010e..ab464bc99a43138130bb2758ae28ecef29805c31 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a) +set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index aaae18a313dd082b428654091c9411600c981ec9..6f059c7225dd0938b758e8f9c28ec36fcff6db4c 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -42,7 +42,6 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11") add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11") set (NSYNC_OS_CPP_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/c++11/src/per_thread_waiter.cc" "platform/c++11/src/yield.cc" "platform/c++11/src/time_rep_timespec.cc" @@ -52,6 +51,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/win32") add_compile_options ("/TP") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/win32/src/clock_gettime.c" "platform/win32/src/pthread_key_win32.cc" ${NSYNC_OS_CPP_SRC} @@ -68,6 +68,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/posix/src/clock_gettime.c" "platform/posix/src/nsync_semaphore_mutex.c" ) @@ -75,9 +76,11 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") "platform/posix/src/start_thread.c" ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") + include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/linux/src/nsync_semaphore_futex.c" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -87,6 +90,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -96,6 +100,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -105,6 +110,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index bfe53c01b3b5fb9db8a5d8fa280d1d7f98974882..112b690511cea1ad5f306af718a8e32995033cf6 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -82,6 +82,7 @@ tensorflow/python/kernel_tests tensorflow/python/kernel_tests/distributions tensorflow/python/kernel_tests/linalg tensorflow/python/kernel_tests/random +tensorflow/python/kernel_tests/testdata tensorflow/python/layers tensorflow/python/lib tensorflow/python/lib/core @@ -147,8 +148,6 @@ tensorflow/contrib/crf tensorflow/contrib/crf/python tensorflow/contrib/crf/python/ops tensorflow/contrib/cudnn_rnn -tensorflow/contrib/cudnn_rnn/kernels -tensorflow/contrib/cudnn_rnn/ops tensorflow/contrib/cudnn_rnn/python tensorflow/contrib/cudnn_rnn/python/layers tensorflow/contrib/cudnn_rnn/python/ops @@ -165,6 +164,7 @@ tensorflow/contrib/distributions/python tensorflow/contrib/distributions/python/ops tensorflow/contrib/distributions/python/ops/bijectors tensorflow/contrib/eager +tensorflow/contrib/eager/proto tensorflow/contrib/eager/python tensorflow/contrib/estimator tensorflow/contrib/estimator/python diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57..c03c0c80fe62a4f95d0fcf240ee25725a19d86f0 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -4,6 +4,7 @@ tensorflow/python tensorflow/contrib/boosted_trees/proto tensorflow/contrib/cloud/kernels tensorflow/contrib/decision_trees/proto +tensorflow/contrib/eager/proto tensorflow/contrib/gdr tensorflow/contrib/lite/toco tensorflow/contrib/mpi diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 96ac60d095dbc84470ff1be92f4bf52bb420fc52..a54cbff33b66d63d7229fa2f50b8a4ca962111ed 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,6 +63,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc" ) +file(GLOB_RECURSE tf_core_cpu_whitelisted_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc" +) +list(REMOVE_ITEM tf_core_cpu_exclude_srcs ${tf_core_cpu_whitelisted_srcs}) list(REMOVE_ITEM tf_core_cpu_srcs ${tf_core_cpu_exclude_srcs}) if (tensorflow_ENABLE_GPU) @@ -79,6 +85,7 @@ if (tensorflow_ENABLE_GPU) "${tensorflow_source_dir}/tensorflow/core/*test*.cc" ) list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_gpu_exclude_srcs}) + list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_cpu_whitelisted_srcs}) list(APPEND tf_core_cpu_srcs ${tf_core_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 998f99ecc19f88921dce14fde892912fb699ad08..ed018b4fed8e47632f632723f19cc755f2079f86 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -67,8 +67,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 59e094812aaf4da2549d96314fc550e5635f9de8..d6712aa2b48795bb6faf5153c9a8774a7d8bf3c1 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -21,6 +21,7 @@ set(tf_op_lib_names "checkpoint_ops" "control_flow_ops" "ctc_ops" + "cudnn_rnn_ops" "data_flow_ops" "dataset_ops" "functional_ops" @@ -84,7 +85,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b730ebd3baacafe8ae401e8987104f3062372954..31e715b654c8baa53e25f54b4854c94e80c88049 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -326,6 +326,7 @@ GENERATE_PYTHON_OP_LIB("checkpoint_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") +GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") GENERATE_PYTHON_OP_LIB("image_ops") @@ -348,6 +349,7 @@ GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") GENERATE_PYTHON_OP_LIB("string_ops") +GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) @@ -366,8 +368,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" @@ -419,8 +419,6 @@ GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) -GENERATE_PYTHON_OP_LIB("summary_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 6d36d5fc5c2854b2d7d2542a3cb12e033e193b88..9738bbeb9aebaeb67495127528e26634887d392c 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -100,8 +100,7 @@ if(WIN32) endif(WIN32) target_include_directories(tensorflow PUBLIC - $ - $) + $) install(TARGETS tensorflow EXPORT tensorflow_export RUNTIME DESTINATION bin @@ -133,10 +132,6 @@ install(DIRECTORY ${tensorflow_source_dir}/tensorflow/stream_executor/ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src/google/ DESTINATION include/google FILES_MATCHING PATTERN "*.h") -# nsync headers -install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/ - DESTINATION include/external/nsync - FILES_MATCHING PATTERN "*.h") # Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/ DESTINATION include/Eigen) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 1c4ebd7f0c1113bcd0857fb0858df2248499f920..b86a8f1ec236d820c2c8bbfec059d8eaed851c59 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -195,9 +195,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" # Loading resources in contrib doesn't seem to work on Windows @@ -208,6 +210,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py" # Test is flaky on Windows GPU builds (b/38283730). "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" + # Disable following manual tag in BUILD. + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + ) if (WIN32) set(tf_test_src_py_exclude @@ -222,6 +227,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/grpc_large_data_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different @@ -475,6 +481,10 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*_test.cc" ) + list(REMOVE_ITEM tf_test_src_simple + ${tf_core_profiler_test_srcs} + ) + set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fec358c4e1067dc8dc8173d1b9d05dc90b90ca05..fa86ad38c975a95171883adba152e32cd3905082 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -9,52 +9,10 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -tf_custom_op_library( - name = "python/ops/_cudnn_rnn_ops.so", - srcs = [ - "kernels/cudnn_rnn_ops.cc", - "ops/cudnn_rnn_ops.cc", - ], - deps = [ - "//tensorflow/core/kernels:bounds_check_lib", - "@farmhash_archive//:farmhash", - ], -) - -tf_kernel_library( - name = "cudnn_rnn_kernels", - srcs = ["kernels/cudnn_rnn_ops.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", - "//third_party/eigen3", - "@farmhash_archive//:farmhash", - ], -) - -tf_gen_op_libs( - op_lib_names = ["cudnn_rnn_ops"], - deps = [ - "//tensorflow/core:lib", - ], -) - -tf_gen_op_wrapper_py( - name = "cudnn_rnn_ops", - deps = [":cudnn_rnn_ops_op_lib"], -) tf_custom_op_py_library( name = "cudnn_rnn_py", @@ -64,20 +22,13 @@ tf_custom_op_py_library( "python/layers/cudnn_rnn.py", "python/ops/cudnn_rnn_ops.py", ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":cudnn_rnn_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:cudnn_rnn_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", @@ -173,23 +124,6 @@ cuda_py_test( ], ) -tf_cc_test( - name = "cudnn_rnn_ops_test_cc", - size = "small", - srcs = [ - "ops/cudnn_rnn_ops_test.cc", - ], - deps = [ - ":cudnn_rnn_ops_op_lib", - "//tensorflow/core", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index e87162f0ee9cc4eed795555171f55a93639e83cf..622241a1774545529a4cdcb974333b53c8f56caa 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -17,27 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -_cudnn_rnn_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) - CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" CUDNN_LSTM = "lstm" diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 0458199ff771bc45603106411550a39448e515b8..9e25a77d9fd3fecdf82fdc69de97671c8ca6bb2b 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -9,6 +9,10 @@ load( "tf_custom_op_library", "tf_gen_op_libs", ) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", +) py_library( name = "data", @@ -17,6 +21,7 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", @@ -29,7 +34,11 @@ py_library( tf_custom_op_library( name = "_dataset_ops.so", srcs = ["ops/dataset_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"], + deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] + + if_static( + extra_deps = ["//tensorflow/core:lib_proto_parsing"], + otherwise = [], + ), ) tf_gen_op_libs( diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index fcdccdd26ca1824bf13f1fd0cfd80b20ca8a10c3..766721d8d2c2cc22a290d07f064471cb67c07d90 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,20 +23,25 @@ removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter +@@SqlDataset @@batch_and_drop_remainder +@@bucket_by_sequence_length @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ignore_errors +@@make_batched_features_dataset @@make_saveable_from_iterator @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave +@@prefetch_to_device @@read_batch_features @@rejection_resample @@scan @@shuffle_and_repeat +@@sliding_window_batch @@sloppy_interleave @@unbatch @@ -58,15 +63,21 @@ from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element +from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat +from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch +from tensorflow.python.data.ops.iterator_ops import Iterator +from tensorflow.python.ops.parsing_ops import parse_single_example_v2 as parse_single_example # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 9bd6a42da2d93263e84a759cffdc5a9e8f9742fd..c87da7dfaa5943f7918c370f63362673844c7f0e 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -10,6 +10,7 @@ cc_library( name = "prefetching_kernels", srcs = ["prefetching_kernels.cc"], deps = [ + "//tensorflow/core:core_cpu_headers_lib", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index d3df14bdd03476e9ee4015b374512e5bb9893a63..79d1fc3494d7fd223c52b3086686f732d3875767 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" @@ -35,38 +36,38 @@ using FunctionBufferCallback = std::function; class FunctionBufferingResource : public ResourceBase { public: FunctionBufferingResource(FunctionLibraryRuntime* lib, + std::unique_ptr pflr, const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, const std::vector& func_args, int64 thread_pool_size) : lib_(lib), + pflr_(std::move(pflr)), func_(func), buffer_size_(buffer_size), source_device_(source_device), target_device_(target_device), func_args_(func_args), - thread_pool_(new thread::ThreadPool(Env::Default(), ThreadOptions(), - "buffer_resource", thread_pool_size, - false /* low_latency_hint */)), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), cancelled_(false) { - runner_ = [this](std::function c) { - thread_pool_->Schedule(std::move(c)); - }; + if (thread_pool_size > 0) { + thread_pool_ = new thread::ThreadPool(Env::Default(), ThreadOptions(), + "buffer_resource", thread_pool_size, + false /* low_latency_hint */); + runner_ = [this](std::function c) { + thread_pool_->Schedule(std::move(c)); + }; + } } ~FunctionBufferingResource() override { Cancel(); - { - mutex_lock l(mu_); - while (is_buffering_) { - cond_var_.wait(l); - } + if (thread_pool_ != nullptr) { + delete thread_pool_; } - delete thread_pool_; } string DebugString() override { @@ -100,6 +101,20 @@ class FunctionBufferingResource : public ResourceBase { void Cancel() LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); cancelled_ = true; + while (is_buffering_) { + cond_var_.wait(l); + } + } + + // Cancels all pending operations and then clears out the state. + void Reset() LOCKS_EXCLUDED(mu_) { + Cancel(); + mutex_lock l(mu_); + buffer_.clear(); + requests_.clear(); + is_buffering_ = false; + end_of_sequence_ = false; + cancelled_ = false; } // If the buffer has anything, runs `callback` on the first element in the @@ -172,7 +187,9 @@ class FunctionBufferingResource : public ResourceBase { FunctionLibraryRuntime::Options opts; // Copied from CapturedFunction::generate_step_id(); opts.step_id = -std::abs(static_cast(random::New64())); - opts.runner = &runner_; + if (runner_ != nullptr) { + opts.runner = &runner_; + } opts.source_device = source_device_; AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); @@ -191,13 +208,12 @@ class FunctionBufferingResource : public ResourceBase { mutex_lock l(mu_); BufferElement buffer_element; buffer_element.status = status; - if (!status.ok()) { + if (status.ok()) { + buffer_element.value.swap(*rets); + } else { end_of_sequence_ = true; is_buffering_ = false; - buffer_.push_back(std::move(buffer_element)); - return; } - buffer_element.value.swap(*rets); buffer_.push_back(std::move(buffer_element)); if (!requests_.empty()) { buffer_front = std::move(buffer_.front()); @@ -205,7 +221,7 @@ class FunctionBufferingResource : public ResourceBase { callback = std::move(requests_.front()); requests_.pop_front(); } - if (buffer_.size() < buffer_size_) { + if (buffer_.size() < buffer_size_ && !end_of_sequence_) { restart_buffering = true; } else { is_buffering_ = false; @@ -222,12 +238,13 @@ class FunctionBufferingResource : public ResourceBase { mutex mu_; FunctionLibraryRuntime* lib_; + std::unique_ptr pflr_; NameAttrList func_; const int64 buffer_size_; const string source_device_; const string target_device_; const std::vector func_args_; - thread::ThreadPool* thread_pool_; + thread::ThreadPool* thread_pool_ = nullptr; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); @@ -241,7 +258,7 @@ class FunctionBufferingResource : public ResourceBase { class FunctionBufferResourceHandleOp : public OpKernel { public: explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { + : OpKernel(ctx), flib_def_(nullptr) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); @@ -249,6 +266,17 @@ class FunctionBufferResourceHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_)); } + ~FunctionBufferResourceHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + void Compute(OpKernelContext* ctx) override { const Tensor* string_arg; OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg)); @@ -267,28 +295,39 @@ class FunctionBufferResourceHandleOp : public OpKernel { const string& source_device = ctx->device()->name(); - ContainerInfo cinfo; - OP_REQUIRES_OK(ctx, cinfo.Init(ctx->resource_manager(), def())); - // Create the resource. - FunctionBufferingResource* buffer; - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->LookupOrCreate( - cinfo.container(), cinfo.name(), &buffer, - [lib, &source_device, &target_device, func_args, - this](FunctionBufferingResource** ptr) { - *ptr = new FunctionBufferingResource( - lib, func_, buffer_size_, source_device, target_device, - func_args, thread_pool_size_); - return Status::OK(); - })); - OP_REQUIRES_OK(ctx, buffer->Instantiate()); + mutex_lock l(mu_); + if (!initialized_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); + FunctionLibraryRuntime* clone_lib; + std::unique_ptr pflr; + OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib)); + // Create the resource. + FunctionBufferingResource* buffer; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &buffer, + [clone_lib, &pflr, &source_device, &target_device, func_args, + this](FunctionBufferingResource** ptr) { + *ptr = new FunctionBufferingResource( + clone_lib, std::move(pflr), func_, buffer_size_, + source_device, target_device, func_args, thread_pool_size_); + return Status::OK(); + })); + OP_REQUIRES_OK(ctx, buffer->Instantiate()); + initialized_ = true; + } OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( - ctx, 0, cinfo.container(), cinfo.name(), + ctx, 0, cinfo_.container(), cinfo_.name(), MakeTypeIndex())); } private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + std::unique_ptr flib_def_; NameAttrList func_; int64 buffer_size_; string container_; @@ -374,4 +413,62 @@ REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") FunctionBufferingResourceGetNextOp); #endif // TENSORFLOW_USE_SYCL +// Resets the FunctionBufferingResource, cancelling all pending requests and +// clearing out the buffer. +class FunctionBufferingResourceResetOp : public OpKernel { + public: + explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + ~FunctionBufferingResourceResetOp() override {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle; + OP_REQUIRES_OK(ctx, + HandleFromInput(ctx, "function_buffer_resource", &handle)); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, handle, &buffer)); + core::ScopedUnref s(buffer); + + buffer->Reset(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#endif // TENSORFLOW_USE_SYCL + +class IteratorGetDeviceOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + // NOTE(mrry): We do not currently Validate that the handle + // corresponds to a real IteratorResource, because that symbol is + // not exposed from the framework library. + Tensor* device_name_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &device_name_t)); + // NOTE(mrry): Since the operation's input is a resource, we must be + // colocated with it, and so we can simply return the current device's + // name without looking at the input. + device_name_t->scalar()() = ctx->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), + IteratorGetDeviceOp); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 4b3edde85fc755f1c7694a555b867317e81f149d..63e19ae3f837c9d3cfb1221df64360ee74117f13 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -166,14 +166,10 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { params.runner = [pool](std::function c) { pool->Schedule(std::move(c)); }; - params.stats_aggregator_getter = [ctx]() { - return ctx->stats_aggregator(); - }; + params.stats_aggregator_getter = ctx->stats_aggregator_getter(); params.lib = ctx->lib(); params.function_library = ctx->function_library(); - params.allocator_getter = [ctx](AllocatorAttributes attrs) { - return ctx->allocator(attrs); - }; + params.allocator_getter = ctx->allocator_getter(); IteratorContext threadpool_ctx(params); return input_impl_->GetNext(&threadpool_ctx, out_tensors, end_of_sequence); diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index a4c1212da11a2410461a120ed5f7116e80e4b903..bd96448d64e94c04da6d6b1d6506342631d5b3fb 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -37,6 +37,14 @@ REGISTER_OP("UniqueDataset") Creates a dataset that contains the unique elements of `input_dataset`. )doc"); +REGISTER_OP("IteratorGetDevice") + .Input("resource: resource") + .Output("device: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Returns the name of the device on which `resource` has been placed. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") @@ -75,6 +83,15 @@ output: A list of return values. output_types: The type list for the return values. )doc"); +REGISTER_OP("FunctionBufferingResourceReset") + .Input("function_buffer_resource: resource") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Resets the FunctionBufferingResource. + +function_buffer_resource: The FunctionBufferingResource handle. +)doc"); + REGISTER_OP("ThreadPoolDataset") .Input("input_dataset: variant") .Input("thread_pool: resource") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 82cd276ce8073b1e66bbc620fa845733aaaca4d4..0b3bf63f79430a7b0fb0a1b72f0b287f1370eb60 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -168,8 +168,10 @@ py_test( srcs = ["interleave_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ + "manual", "no_oss", "no_pip", + "notap", ], deps = [ ":dataset_serialization_test", @@ -295,6 +297,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", ], ) @@ -476,7 +479,8 @@ py_test( srcs_version = "PY2AND3", tags = [ "manual", - "no_oss", # b/68785503 + "no_oss", + "notap", ], deps = [ "//tensorflow/contrib/data/python/ops:prefetching_ops", @@ -493,6 +497,23 @@ py_test( ], ) +tf_py_test( + name = "slide_dataset_op_test", + size = "small", + srcs = ["slide_dataset_op_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 71dc1c1172c9d515d4c85f85257c952135098329..75482f67da11401305b7b342cd5c971da71a4f3c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -311,10 +311,10 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> - # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) @@ -381,11 +381,51 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testBatchAndMapDataset(self): - return self._testBatchAndMapDatasetHelper() + def testMapAndBatchDataset(self): + return self._testMapAndBatchDatasetHelper() - def testBatchAndMapDatasetWithParallelBatching(self): - return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchDatasetWithParallelBatching(self): + return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + iterator = ( + dataset_ops.Dataset.range(10).apply( + batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), + batch_size=4, + drop_remainder=drop_remainder)).make_one_shot_iterator()) + if drop_remainder: + self.assertEqual([4, 1], iterator.output_shapes.as_list()) + else: + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + if not drop_remainder: + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchPartialBatch(self): + return self._testMapAndBatchPartialBatchHelper() + + def testMapAndBatchPartialBatchDropRemainder(self): + return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) + + def testMapAndBatchYieldsPartialBatch(self): + iterator = (dataset_ops.Dataset.range(10) + .apply(batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), 4)) + .make_one_shot_iterator()) + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) def testMapAndBatchSparse(self): @@ -411,7 +451,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testBatchAndMapDatasetFails(self): + def testMapAndBatchDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -425,7 +465,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testBatchAndMapDatasetShapeMismatch(self): + def testMapAndBatchDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index f1b494e1a620992365ed75613b508e32f94b40a4..d0131896a1a5986cfc5ed37785a0d0090ae6600c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base @@ -379,5 +381,118 @@ class BucketTest(test.TestCase): self.assertEqual(batches, 15) +class BucketBySequenceLength(test.TestCase): + + def testBucket(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25, 35] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + def testPadToBoundary(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes[:-1], lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + for _ in range(batch_sizes[-1]): + el = [1] * (boundaries[-1] + 5) + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(3): + batches.append(sess.run(batch)) + with self.assertRaisesOpError("bucket_boundaries"): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + batch_sizes = batch_sizes[:-1] + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(boundaries), sorted(lengths_val)) + + def testTupleElements(self): + + def elements_gen(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (x, y) + + def element_length_fn(x, y): + del y + return array_ops.shape(x)[0] + + dataset = dataset_ops.Dataset.from_generator( + generator=elements_gen, + output_shapes=(tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([])), + output_types=(dtypes.int32, dtypes.int32)) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8])) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index 32ea44f7c7ba329dc253bb9fbbcac0a1ed16aec7..87b7c6ddb7afcbaaf8fe97cd8be87e6f5af8cd4d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -22,6 +22,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -33,17 +34,25 @@ class GetSingleElementTest(test.TestCase): take_value = array_ops.placeholder_with_default( constant_op.constant(1, dtype=dtypes.int64), shape=[]) + def make_sparse(x): + x_1d = array_ops.reshape(x, [1]) + x_2d = array_ops.reshape(x, [1, 1]) + return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) + dataset = (dataset_ops.Dataset.range(100) .skip(skip_value) - .map(lambda x: x * x) + .map(lambda x: (x * x, make_sparse(x))) .take(take_value)) element = get_single_element.get_single_element(dataset) with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + for x in [0, 5, 10]: + dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x}) + self.assertEqual(x * x, dense_val) + self.assertAllEqual([[x]], sparse_val.indices) + self.assertAllEqual([x], sparse_val.values) + self.assertAllEqual([x], sparse_val.dense_shape) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Dataset was empty."): diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index dc3e38db59301bf1819999f479171af35930e9d2..a14736ac09c9174d1536677ad05db76dc8887913 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import threading from tensorflow.contrib.data.python.ops import prefetching_ops @@ -26,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -38,25 +38,29 @@ class StagingAreaOpsTest(test.TestCase): def setUp(self): self._event = threading.Event() - def _prefetch_fn_helper(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + def _create_ds_and_iterator(self, device0, initializable=False): def gen(): - for i in itertools.count(start=1, step=1): - yield [i + 0.0] + for i in range(1, 10): + yield [float(i)] if i == 6: self._event.set() with ops.device(device0): - dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() + ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) + if initializable: + ds_iterator = ds.make_initializable_iterator() + else: + ds_iterator = ds.make_one_shot_iterator() + return (ds, ds_iterator) + + def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1): + ds_iterator_handle = ds_iterator.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) + h, ds.output_types, ds.output_shapes) return remote_iterator.get_next() target = constant_op.constant(device0) @@ -64,7 +68,7 @@ class StagingAreaOpsTest(test.TestCase): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, target_device=target, - string_arg=iterator_3_handle, + string_arg=ds_iterator_handle, buffer_size=3, thread_pool_size=2, shared_name=buffer_name) @@ -73,6 +77,20 @@ class StagingAreaOpsTest(test.TestCase): prefetch_op = prefetching_ops.function_buffering_resource_get_next( function_buffer_resource=buffer_resource_handle, output_types=[dtypes.float32]) + reset_op = prefetching_ops.function_buffering_resource_reset( + function_buffer_resource=buffer_resource_handle) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + return (prefetch_op, reset_op, destroy_op) + + def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) + prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, + device0, device1) with self.test_session(config=worker_config) as sess: elem = sess.run(prefetch_op) @@ -86,27 +104,150 @@ class StagingAreaOpsTest(test.TestCase): self._event.wait() elem = sess.run(prefetch_op) self.assertEqual(elem, [5.0]) - sess.run( - resource_variable_ops.destroy_resource_op( - buffer_resource_handle, ignore_lookup_error=True)) + sess.run(destroy_op) def testSameDeviceCPU(self): - self._prefetch_fn_helper("same_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:0") + self._prefetch_fn_helper_one_shot("same_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:0") def testDifferentDeviceCPU(self): - self._prefetch_fn_helper("diff_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:1") + self._prefetch_fn_helper_one_shot("diff_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:1") def testDifferentDeviceCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") - self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/gpu:0") + self._prefetch_fn_helper_one_shot("cpu_gpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/gpu:0") + + def testReinitialization(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + # Lets reset the function buffering resource and reinitialize the + # iterator. Should be able to go through this again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + sess.run(destroy_op) + + def testReinitializationOutOfRange(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + # Now reset everything and try it out again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + + def testPrefetchToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 6efe97444a375febc550ff3a3ea04bcd9330a3a5..6ee1b572f121a9a40dfd638f7a858d5f1176ea3c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,8 @@ import gzip import os import zlib +import numpy as np + from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 @@ -262,12 +264,19 @@ class ReadBatchFeaturesTest(test.TestCase): self._num_records = 7 self.test_filenames = self._createFiles() - def _read_batch_features(self, filenames, num_epochs, batch_size): + def _read_batch_features(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None): self.filenames = filenames self.num_epochs = num_epochs self.batch_size = batch_size - return readers.read_batch_features( + return readers.make_batched_features_dataset( file_pattern=self.filenames, batch_size=self.batch_size, features={ @@ -276,8 +285,12 @@ class ReadBatchFeaturesTest(test.TestCase): "keywords": parsing_ops.VarLenFeature(dtypes.string) }, reader=core_readers.TFRecordDataset, - randomize_input=False, - num_epochs=self.num_epochs) + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() def _record(self, f, r): example = example_pb2.Example(features=feature_pb2.Features( @@ -312,24 +325,35 @@ class ReadBatchFeaturesTest(test.TestCase): writer.close() return filenames - def _next_actual_batch(self, sess): - file_op = self.outputs["file"] - keywords_indices_op = self.outputs["keywords"].indices - keywords_values_op = self.outputs["keywords"].values - keywords_dense_shape_op = self.outputs["keywords"].dense_shape - record_op = self.outputs["record"] + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] return sess.run([ file_op, keywords_indices_op, keywords_values_op, keywords_dense_shape_op, record_op ]) - def _next_expected_batch(self, file_indices, batch_size, num_epochs): + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): def _next_record(file_indices): for j in file_indices: for i in range(self._num_records): yield j, i + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + file_batch = [] keywords_batch_indices = [] keywords_batch_values = [] @@ -337,7 +361,11 @@ class ReadBatchFeaturesTest(test.TestCase): record_batch = [] batch_index = 0 for _ in range(num_epochs): - for record in _next_record(file_indices): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: f = record[0] r = record[1] file_batch.append(f) @@ -365,14 +393,41 @@ class ReadBatchFeaturesTest(test.TestCase): [len(file_batch), keywords_batch_max_len], record_batch ] - def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1): + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): if file_index is not None: file_indices = [file_index] else: file_indices = range(self._num_files) - for expected_batch in self._next_expected_batch(file_indices, batch_size, - num_epochs): + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): actual_batch = self._next_actual_batch(sess) for i in range(len(expected_batch)): self.assertAllEqual(expected_batch[i], actual_batch[i]) @@ -435,6 +490,484 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testReadWithFusedShuffleRepeatDataset(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + for batch_size in [1, 2]: + # Test that shuffling with same seed produces the same result. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs2 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + # Test that shuffling with different seeds produces a different order. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs2 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=15) + all_equal = True + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + def testParallelReadersAndParsers(self): + num_epochs = 5 + for batch_size in [1, 2]: + for reader_num_threads in [2, 4]: + for parser_num_threads in [2, 4]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + self.outputs = self._read_batch_features( + filenames=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads) + self._verify_records( + sess, + batch_size, + num_epochs=num_epochs, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess) + + +class MakeCsvDatasetTest(test.TestCase): + + COLUMN_TYPES = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] + COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))] + DEFAULT_VALS = [[], [], [], [], ["NULL"]] + DEFAULTS = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.float64), + constant_op.constant(["NULL"], dtype=dtypes.string) + ] + LABEL = COLUMNS[0] + + def setUp(self): + super(MakeCsvDatasetTest, self).setUp() + self._num_files = 2 + self._num_records = 11 + self._test_filenames = self._create_files() + + def _csv_values(self, fileno, recordno): + return [ + fileno, + recordno, + fileno * recordno * 0.5, + fileno * recordno + 0.5, + "record %d" % recordno if recordno % 2 == 1 else "", + ] + + def _csv_record(self, fileno, recordno): + return ",".join(str(v) for v in self._csv_values(fileno, recordno)) + + def _create_file(self, fileno, header=True, comment=True): + fn = os.path.join(self.get_temp_dir(), "csv_file%d.csv" % fileno) + f = open(fn, "w") + if header: + f.write(",".join(self.COLUMNS) + "\n") + for recno in range(self._num_records): + f.write(self._csv_record(fileno, recno) + "\n") + if comment: + f.write("# Some comment goes here. Should be ignored!\n") + f.close() + return fn + + def _create_files(self): + filenames = [] + for i in range(self._num_files): + filenames.append(self._create_file(i)) + return filenames + + def _make_csv_dataset( + self, + filenames, + defaults, + column_names=COLUMNS, + label_name=LABEL, + batch_size=1, + num_epochs=1, + shuffle=False, + shuffle_seed=None, + header=True, + comment="#", + na_value="", + default_float_type=dtypes.float32, + ): + return readers.make_csv_dataset( + filenames, + batch_size=batch_size, + column_names=column_names, + column_defaults=defaults, + label_name=label_name, + num_epochs=num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + header=header, + comment=comment, + na_value=na_value, + default_float_type=default_float_type, + ) + + def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults): + features = {col: list() for col in self.COLUMNS} + for _ in range(num_epochs): + for i in file_indices: + for j in range(self._num_records): + values = self._csv_values(i, j) + for n, v in enumerate(values): + if v == "": # pylint: disable=g-explicit-bool-comparison + values[n] = defaults[n][0] + values[-1] = values[-1].encode("utf-8") + + # Regroup lists by column instead of row + for n, col in enumerate(self.COLUMNS): + features[col].append(values[n]) + if len(list(features.values())[0]) == batch_size: + yield features + features = {col: list() for col in self.COLUMNS} + + def _run_actual_batch(self, outputs, sess): + features, labels = sess.run(outputs) + batch = [features[k] for k in self.COLUMNS if k != self.LABEL] + batch.append(labels) + return batch + + def _verify_records( + self, + sess, + dataset, + file_indices, + defaults=tuple(DEFAULT_VALS), + label_name=LABEL, + batch_size=1, + num_epochs=1, + ): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + for expected_features in self._next_actual_batch(file_indices, batch_size, + num_epochs, defaults): + actual_features = sess.run(get_next) + + if label_name is not None: + expected_labels = expected_features.pop(label_name) + # Compare labels + self.assertAllEqual(expected_labels, actual_features[1]) + actual_features = actual_features[0] # Extract features dict from tuple + + for k in expected_features.keys(): + # Compare features + self.assertAllEqual(expected_features[k], actual_features[k]) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_make_csv_dataset(self): + defaults = self.DEFAULTS + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Basic test: read from file 0. + dataset = self._make_csv_dataset(self._test_filenames[0], defaults) + self._verify_records(sess, dataset, [0]) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Basic test: read from file 1. + dataset = self._make_csv_dataset(self._test_filenames[1], defaults) + self._verify_records(sess, dataset, [1]) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. + dataset = self._make_csv_dataset(self._test_filenames, defaults) + self._verify_records(sess, dataset, range(self._num_files)) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, defaults, batch_size=2, num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def test_make_csv_dataset_with_bad_columns(self): + """Tests that exception is raised when input is malformed. + """ + dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1] + defaults = self.DEFAULTS + + # Duplicate column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, column_names=dupe_columns) + + # Label key not one of column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, label_name="not_a_real_label") + + def test_make_csv_dataset_with_no_label(self): + """Tests that CSV datasets can be created when no label is specified. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Make sure this works with no label key supplied. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=2, + num_epochs=10, + label_name=None) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + label_name=None) + + def test_make_csv_dataset_with_no_comments(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), comment=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + comment=None, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) + + def test_make_csv_dataset_with_no_header(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), header=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + header=False, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) + + def test_make_csv_dataset_with_types(self): + """Tests that defaults can be a dtype instead of a Tensor for required vals. + """ + defaults = [d for d in self.COLUMN_TYPES[:-1]] + defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset(self._test_filenames, defaults) + self._verify_records(sess, dataset, range(self._num_files)) + + def test_make_csv_dataset_with_no_col_names(self): + """Tests that datasets can be created when column names are not specified. + + In that case, we should infer the column names from the header lines. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + column_names=None, + batch_size=2, + num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def test_make_csv_dataset_type_inference(self): + """Tests that datasets can be created when no defaults are specified. + + In that case, we should infer the types from the first N records. + """ + # Test that it works with standard test files (with comments, header, etc) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + self._test_filenames, defaults=None, batch_size=2, num_epochs=10) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + defaults=[[], [], [], [], [""]]) + + # Test on a deliberately tricky file + fn = os.path.join(self.get_temp_dir(), "file.csv") + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, + dtypes.string, dtypes.string + ] + rows = [[0, 0, 0, "NAN", "", "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""], + ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']] + expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""], + [123, 2, 2**64, 123.4, "", "cd,efg"]] + for row in expected: + row[-1] = row[-1].encode("utf-8") # py3 expects byte strings + row[-2] = row[-2].encode("utf-8") # py3 expects byte strings + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + with open(fn, "w") as f: + f.write(",".join(col_names)) + f.write("\n") + for row in rows: + f.write(",".join([str(v) if v else "" for v in row]) + "\n") + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + batch_size=1, + num_epochs=1, + label_name=None, + na_value="NAN", + default_float_type=dtypes.float32, + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + assert features["col%d" % i].dtype == expected_dtypes[i] + for i in range(len(rows)): + assert sess.run(features) == dict(zip(col_names, expected[i])) + + # With float64 as default type for floats + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.string, dtypes.string + ] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + batch_size=1, + num_epochs=1, + label_name=None, + na_value="NAN", + default_float_type=dtypes.float64, + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + assert features["col%d" % i].dtype == expected_dtypes[i] + for i in range(len(rows)): + assert sess.run(features) == dict(zip(col_names, expected[i])) + + def test_make_csv_dataset_with_shuffle(self): + total_records = self._num_files * self._num_records + defaults = self.DEFAULTS + for batch_size in [1, 2]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Test that shuffling with the same seed produces the same result + dataset1 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + dataset2 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Test that shuffling with a different seed produces different results + dataset1 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + dataset2 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=6) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + all_equal = False + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 3c7b46629edb13459766b5ef3f392e8d00ad4db8..5f47dcb33999119a690bd633f0c97a12a1ae1c84 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -21,7 +21,10 @@ import numpy as np from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -45,12 +48,10 @@ class ResampleTest(test.TestCase): target_dist=target_dist, initial_dist=initial_dist, class_func=lambda c, _: c, - seed=27)).make_initializable_iterator()) - init_op = iterator.initializer + seed=27)).make_one_shot_iterator()) get_next = iterator.get_next() with self.test_session() as sess: - sess.run(init_op) returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -70,6 +71,43 @@ class ResampleTest(test.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) + def testRandomClasses(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Apply a random mapping that preserves the data distribution. + def _remap_fn(_): + return math_ops.cast(random_ops.random_uniform([1]) * num_classes, + dtypes.int32)[0] + dataset = dataset.map(_remap_fn) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + classes, _ = zip(*returned) + bincount = np.bincount( + np.array(classes), + minlength=num_classes).astype(np.float32) / len(classes) + + self.assertAllClose(target_dist, bincount, atol=1e-2) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33c48e20bea53b88d69a59e715af38b22dd2cbd4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -0,0 +1,242 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import sliding +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class SlideDatasetTest(test.TestCase): + + def testSlideDataset(self): + """Test an dataset that maps a TF function across its input elements.""" + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + window_size = array_ops.placeholder(dtypes.int64, shape=[]) + stride = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(count) -> _SlideDataset(window_size, stride). + iterator = (dataset_ops.Dataset.from_tensor_slices(components) + .map(_map_fn) + .repeat(count) + .apply(sliding.sliding_window_batch(window_size, stride)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + # Slide over a finite input, where the window_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) + # Same formula with convolution layer. + num_batches = (20 * 7 - 14) // 7 + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*7 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over a finite input, where the window_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) + + num_batches = (20 * 7 - 17) // 9 + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(17): + self.assertAllEqual(component[(i*9 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over a finite input, which is less than window_size, + # should fail straight away. + sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty window_size should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0}) + + # Invalid stride should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) + + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSlideSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSlideSparseWithDifferentDenseShapes(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=array_ops.expand_dims( + math_ops.range(i, dtype=dtypes.int64), 1), + values=array_ops.fill([math_ops.to_int32(i)], i), + dense_shape=[i]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected_indices = [] + expected_values = [] + for j in range(5): + for k in range(i * 3 + j): + expected_indices.append([j, k]) + expected_values.append(i * 3 + j) + expected = sparse_tensor.SparseTensorValue( + indices=expected_indices, + values=expected_values, + dense_shape=[5, i * 3 + 5 - 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedSlideSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .apply(sliding.sliding_window_batch(4, 2)) + .apply(sliding.sliding_window_batch(3, 1)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + # Slide: 1st batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], + [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + # Slide: 2nd batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], + [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSlideShapeError(self): + + def generator(): + yield [1.0, 2.0, 3.0] + yield [4.0, 5.0, 6.0] + yield [7.0, 8.0, 9.0, 10.0] + + iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, + output_shapes=[None]) + .apply(sliding.sliding_window_batch(3, 1)) + .make_initializable_iterator()) + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot batch tensors with different shapes in component 0. " + r"First element had shape \[3\] and element 2 had shape \[4\]."): + sess.run(next_element) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 789cb9c99a6bba06a1e3bd3371d1378065f49f46..647620eb849268abd679d0f4ff9149ab46c30e9a 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -67,17 +67,23 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", + ":shuffle_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", ], ) @@ -104,6 +110,7 @@ py_library( "interleave_ops.py", "resampling.py", "scan_ops.py", + "sliding.py", "stats_ops.py", "threadpool.py", "unique.py", @@ -126,6 +133,7 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", @@ -169,6 +177,10 @@ py_library( srcs = ["prefetching_ops.py"], deps = [ ":contrib_op_loader", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6eb512dec67cb7b9c8c4518d03aee0b436205f9a..a212adf6cf580267f9f1e6959bef95f04a4ad782 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -348,13 +348,19 @@ class _RestructuredDataset(dataset_ops.Dataset): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size = ops.convert_to_tensor( + self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches = ops.convert_to_tensor( + self._num_parallel_batches_t = ops.convert_to_tensor( num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._drop_remainder_t = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + + self._batch_size = batch_size + self._drop_remainder = drop_remainder def _as_variant_tensor(self): # pylint: disable=protected-access @@ -363,8 +369,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): input_resource, self._map_func.captured_inputs, f=self._map_func, - batch_size=self._batch_size, - num_parallel_batches=self._num_parallel_batches, + batch_size=self._batch_size_t, + num_parallel_batches=self._num_parallel_batches_t, + drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( @@ -373,9 +380,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): @property def output_shapes(self): + dim = self._batch_size if self._drop_remainder else None return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(tensor_util.constant_value( - self._batch_size)).concatenate(s) + tensor_shape.vector(dim).concatenate(s) for s in nest.flatten(self._output_shapes) ]) @@ -384,7 +391,10 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): return self._output_types -def map_and_batch(map_func, batch_size, num_parallel_batches=1): +def map_and_batch(map_func, + batch_size, + num_parallel_batches=1, + drop_remainder=False): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -404,6 +414,9 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values can increase contention if CPU is scarce. + drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the + last batch should be dropped in case its size is smaller than desired; + the default behavior is not to drop the smaller batch. Returns: A `Dataset` transformation function, which can be passed to @@ -412,6 +425,6 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches) + num_parallel_batches, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py index 63226fe78163c59025623a362d17c400fbe57c67..6ef65f9624601286691505a795a86dd6226eead1 100644 --- a/tensorflow/contrib/data/python/ops/counter.py +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import ops def Counter(start=0, step=1, dtype=dtypes.int64): - """Creates a `Dataset` of a `step`-separated count startin from `start`. + """Creates a `Dataset` that counts from `start` in steps of size `step`. For example: @@ -38,12 +38,13 @@ def Counter(start=0, step=1, dtype=dtypes.int64): ``` Args: - start: starting value for count. - step: step size. - dtype: counter data type. + start: (Optional.) The starting value for the counter. Defaults to 0. + step: (Optional.) The step size for the counter. Defaults to 1. + dtype: (Optional.) The data type for counter elements. Defaults to + `tf.int64`. Returns: - A `Dataset` of scalar elements. + A `Dataset` of scalar `dtype` elements. """ with ops.name_scope("counter"): start = ops.convert_to_tensor(start, dtype=dtype, name="start") diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index a817b45b71b608810a9d7536ec123ab84f7cdc3b..3a07df572748e464284f580d67e3a664e71acdfe 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.ops import gen_dataset_ops @@ -59,9 +60,14 @@ def get_single_element(dataset): """ if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( + + nested_ret = nest.pack_sequence_as( + dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) + output_types=nest.flatten(sparse.as_dense_types( + dataset.output_types, dataset.output_classes)), + output_shapes=nest.flatten(sparse.as_dense_shapes( + dataset.output_shapes, dataset.output_classes)))) + return sparse.deserialize_sparse_tensors( + nested_ret, dataset.output_types, dataset.output_shapes, + dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 67b085002aa7797d858837fea4646fb968ad5d97..36591c055ae8f2c54981525ffcc3df128a990a61 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,13 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops def group_by_window(key_func, @@ -35,7 +42,7 @@ def group_by_window(key_func, This transformation maps each consecutive element in a dataset to a key using `key_func` and groups the elements by key. It then applies `reduce_func` to at most `window_size_func(key)` elements matching the same - key. All execpt the final window for each key will contain + key. All except the final window for each key will contain `window_size_func(key)` elements; the final window may be smaller. You may provide either a constant `window_size` or a window size determined by @@ -85,6 +92,114 @@ def group_by_window(key_func, return _apply_fn +def bucket_by_sequence_length(element_length_func, + bucket_boundaries, + bucket_batch_sizes, + padded_shapes=None, + padding_values=None, + pad_to_bucket_boundary=False): + """A transformation that buckets elements in a `Dataset` by length. + + Elements of the `Dataset` are grouped together by length and then are padded + and batched. + + This is useful for sequence tasks in which the elements have variable length. + Grouping together elements that have similar lengths reduces the total + fraction of padding in a batch which increases training step efficiency. + + Args: + element_length_func: function from element in `Dataset` to `tf.int64`, + determines the length of the element, which will determine the bucket it + goes into. + bucket_boundaries: `list`, upper length boundaries of the buckets. + bucket_batch_sizes: `list`, batch size per bucket. Length should be + `len(bucket_boundaries) + 1`. + padded_shapes: Nested structure of `tf.TensorShape` to pass to + @{tf.data.Dataset.padded_batch}. If not provided, will use + `dataset.output_shapes`, which will result in variable length dimensions + being padded out to the maximum length in each batch. + padding_values: Values to pad with, passed to + @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. + pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown + size to maximum length in batch. If `True`, will pad dimensions with + unknown size to bucket boundary, and caller must ensure that the source + `Dataset` does not contain any elements with length longer than + `max(bucket_boundaries)`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + + Raises: + ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. + """ + with ops.name_scope("bucket_by_seq_length"): + if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): + raise ValueError( + "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") + + batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) + + def element_to_bucket_id(*args): + """Return int64 id of the length bucket for this element.""" + seq_length = element_length_func(*args) + + boundaries = list(bucket_boundaries) + buckets_min = [np.iinfo(np.int32).min] + boundaries + buckets_max = boundaries + [np.iinfo(np.int32).max] + conditions_c = math_ops.logical_and( + math_ops.less_equal(buckets_min, seq_length), + math_ops.less(seq_length, buckets_max)) + bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) + + return bucket_id + + def window_size_fn(bucket_id): + # The window size is set to the batch size for this bucket + window_size = batch_sizes[bucket_id] + return window_size + + def make_padded_shapes(shapes, none_filler=None): + padded = [] + for shape in nest.flatten(shapes): + shape = tensor_shape.TensorShape(shape) + shape = [ + none_filler if d.value is None else d + for d in shape + ] + padded.append(shape) + return nest.pack_sequence_as(shapes, padded) + + def batching_fn(bucket_id, grouped_dataset): + """Batch elements in dataset.""" + batch_size = batch_sizes[bucket_id] + none_filler = None + if pad_to_bucket_boundary: + err_msg = ("When pad_to_bucket_boundary=True, elements must have " + "length <= max(bucket_boundaries).") + check = check_ops.assert_less( + bucket_id, + constant_op.constant(len(bucket_batch_sizes) - 1, + dtype=dtypes.int64), + message=err_msg) + with ops.control_dependencies([check]): + boundaries = constant_op.constant(bucket_boundaries, + dtype=dtypes.int64) + bucket_boundary = boundaries[bucket_id] + none_filler = bucket_boundary + shapes = make_padded_shapes( + padded_shapes or grouped_dataset.output_shapes, + none_filler=none_filler) + return grouped_dataset.padded_batch(batch_size, shapes, padding_values) + + def _apply_fn(dataset): + return dataset.apply( + group_by_window(element_to_bucket_id, batching_fn, + window_size_func=window_size_fn)) + + return _apply_fn + + class _VariantDataset(dataset_ops.Dataset): """A Dataset wrapper for a tf.variant-typed function argument.""" diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 3124ca1d1540e12d949dded88ce1c66181be3595..91f19da02d4a479820782822475d9121125fc38e 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,101 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import readers from tensorflow.python.util import deprecation -class ParallelInterleaveDataset(dataset_ops.Dataset): - """A `Dataset` that maps a function over its input and flattens the result.""" - - def __init__(self, input_dataset, map_func, cycle_length, block_length, - sloppy, buffer_output_elements, prefetch_input_elements): - """See `tf.contrib.data.parallel_interleave()` for details.""" - super(ParallelInterleaveDataset, self).__init__() - self._input_dataset = input_dataset - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access - dataset = map_func(*nested_args) - else: - dataset = map_func(nested_args) - - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - - return dataset._as_variant_tensor() # pylint: disable=protected-access - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) - - self._cycle_length = ops.convert_to_tensor( - cycle_length, dtype=dtypes.int64, name="cycle_length") - self._block_length = ops.convert_to_tensor( - block_length, dtype=dtypes.int64, name="block_length") - self._sloppy = ops.convert_to_tensor( - sloppy, dtype=dtypes.bool, name="sloppy") - self._buffer_output_elements = convert.optional_param_to_tensor( - "buffer_output_elements", - buffer_output_elements, - argument_default=2 * block_length) - self._prefetch_input_elements = convert.optional_param_to_tensor( - "prefetch_input_elements", - prefetch_input_elements, - argument_default=2 * cycle_length) - - def _as_variant_tensor(self): - return gen_dataset_ops.parallel_interleave_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._map_func.captured_inputs, - self._cycle_length, - self._block_length, - self._sloppy, - self._buffer_output_elements, - self._prefetch_input_elements, - f=self._map_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def parallel_interleave(map_func, cycle_length, block_length=1, @@ -162,7 +71,7 @@ def parallel_interleave(map_func, @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return ParallelInterleaveDataset( + return readers.ParallelInterleaveDataset( dataset, map_func, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements) @@ -221,7 +130,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return ParallelInterleaveDataset( + return readers.ParallelInterleaveDataset( dataset, map_func, cycle_length, diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 96a9e9ed6649444dac5e56d7dd2fcdb62fc56459..1438b5426f7a5df7eb6dcc6769d049538ff59267 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,20 +17,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops # TODO(rohanj): Add a python class that constructs resource in the __init__ # method and provides a get_next() that calls the prefetch op. def function_buffering_resource(string_arg, target_device, - shared_name, f, buffer_size, - thread_pool_size=1, + thread_pool_size=0, container="", + shared_name=None, name=None): + if shared_name is None: + shared_name = "" return gen_dataset_ops.function_buffering_resource( string_arg=string_arg, target_device=target_device, @@ -49,3 +60,129 @@ def function_buffering_resource_get_next(function_buffer_resource, function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) + + +def function_buffering_resource_reset(function_buffer_resource, name=None): + return gen_dataset_ops.function_buffering_resource_reset( + function_buffer_resource=function_buffer_resource, name=name) + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device.""" + + def __init__(self, input_dataset, device, buffer_size): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + input_iterator = input_dataset.make_one_shot_iterator() + input_iterator_handle = input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, input_iterator.output_types, input_iterator.output_shapes, + input_iterator.output_classes) + return remote_iterator.get_next() + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=gen_dataset_ops.iterator_get_device( + input_iterator._iterator_resource), + string_arg=input_iterator_handle, + buffer_size=buffer_size, + thread_pool_size=0) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + self._buffering_resource, + output_types=nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + nest.flatten(ret), nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + + return ret + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to another device.""" + + def __init__(self, input_dataset, device, buffer_size): + self._input_dataset = input_dataset + self._device = device + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator(self._input_dataset, self._device, + self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + raise NotImplementedError("`prefetch_to_device()` is not currently " + "compatible with initializable iterators. Use " + "`make_one_shot_iterator()` instead.") + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_device()` must be the last " + "transformation in a dataset pipeline.") + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_device(device, buffer_size=None): + """A transformation that prefetches dataset values to the given `device`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + device: A string. The name of a device to which elements will be prefetched. + buffer_size: (Optional.) The number of elements to buffer on `device`. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, device, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 7d727165feabb101549567f28a2dfa07083de244..28ef5e50f39dd7d1b6f124e58e068fc968ddd6dc 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -19,11 +19,10 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops @@ -34,16 +33,7 @@ class RandomDataset(dataset_ops.Dataset): def __init__(self, seed=None): """A `Dataset` of pseudorandom values.""" super(RandomDataset, self).__init__() - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) def _as_variant_tensor(self): return gen_dataset_ops.random_dataset( diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 57f30102778f3bac47580f9bdf94e411dfe1b621..95edca6cdd2e22ca5c2ed4b10ebe6462f9446811 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,20 +17,459 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import csv + +import numpy as np + +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation + +_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64, dtypes.string) + + +def _is_valid_int32(str_val): + try: + # Checks equality to prevent int32 overflow + return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( + str_val) + except (ValueError, OverflowError): + return False + + +def _is_valid_int64(str_val): + try: + dtypes.int64.as_numpy_dtype(str_val) + return True + except (ValueError, OverflowError): + return False + + +def _is_valid_float(str_val, float_dtype): + try: + return float_dtype.as_numpy_dtype(str_val) < np.inf + except ValueError: + return False + + +def _infer_type(str_val, na_value, prev_type, float_dtype): + """Given a string, infers its tensor type. + + Infers the type of a value by picking the least 'permissive' type possible, + while still allowing the previous type inference for this column to be valid. + + Args: + str_val: String value to infer the type of. + na_value: Additional string to recognize as a NA/NaN CSV value. + prev_type: Type previously inferred based on values of this column that + we've seen up till now. + float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type + to parse float strings as. + Returns: + Inferred dtype. + """ + if str_val in ("", na_value): + return prev_type + + if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): + return dtypes.int32 + + if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, + dtypes.int64): + return dtypes.int64 + + if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: + return float_dtype + + return dtypes.string + + +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, + comment): + for fn in filenames: + with file_io.FileIO(fn, "r") as f: + rdr = csv.reader( + f, + delimiter=field_delim, + quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) + if header: + next(rdr) # Skip header lines + + for csv_row in rdr: + if comment is not None and csv_row[0].startswith(comment): + continue # Skip comment lines + + if len(csv_row) != num_cols: + raise ValueError( + "Problem inferring types: CSV row has different number of fields " + "than expected.") + yield csv_row + + +def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, + na_value, header, comment, float_dtype, + rows_for_inference): + """Infers column types from the first N valid CSV records of files.""" + inferred_types = [None] * num_cols + + for rows_read, csv_row in enumerate( + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, + comment)): + if rows_for_inference is not None and rows_read >= rows_for_inference: + break + for i, str_val in enumerate(csv_row): + inferred_types[i] = _infer_type(str_val, na_value, inferred_types[i], + float_dtype) + + # Replace None's with a default type + inferred_types = [t or dtypes.string for t in inferred_types] + # Default to 0 or '' for null values + return [ + constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) + for t in inferred_types + ] + + +def _infer_column_names(filenames, field_delim, use_quote_delim): + """Infers column names from first rows of files.""" + csv_kwargs = { + "delimiter": field_delim, + "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE + } + with file_io.FileIO(filenames[0], "r") as f: + column_names = next(csv.reader(f, **csv_kwargs)) + + for name in filenames[1:]: + with file_io.FileIO(name, "r") as f: + if next(csv.reader(f, **csv_kwargs)) != column_names: + raise ValueError("Files have different column names in the header row.") + return column_names + + +def make_csv_dataset( + file_pattern, + batch_size, + column_names=None, + column_defaults=None, + label_name=None, + field_delim=",", + use_quote_delim=True, + na_value="", + header=True, + comment=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + default_float_type=dtypes.float32, + num_rows_for_inference=100, +): + """Reads CSV files into a dataset. + + Reads CSV files into a dataset, where each element is a (features, labels) + tuple that corresponds to a batch of CSV rows. The features dictionary + maps feature column names to `Tensor`s containing the corresponding + feature data, and labels is a `Tensor` containing the batch's label data. + + Args: + file_pattern: List of files or patterns of file paths containing CSV + records. See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of consecutive elements of this + dataset to combine in a single batch. + column_names: An optional list of strings that corresponds to the CSV + columns, in order. One per column of the input record. If this is not + provided, infers the column names from the first row of the records. + These names will be the keys of the features dict of each dataset element. + column_defaults: A optional list of default values for the CSV fields. One + item per column of the input record. Each item in the list is either a + valid CSV dtype (float32, float64, int32, int64, or string), or a + `Tensor` with one of the aforementioned types. The tensor can either be + a scalar default value (if the column is optional), or an empty tensor (if + the column is required). If a dtype is provided instead of a tensor, the + column is also treated as required. If this list is not provided, tries + to infer types based on reading the first num_rows_for_inference rows of + files specified, and assumes all columns are optional, defaulting to `0` + for numeric values and `""` for string values. + label_name: A optional string corresponding to the label column. If + provided, the data for this column is returned as a separate `Tensor` from + the features dictionary, so that the dataset complies with the format + expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input + function. + field_delim: An optional `string`. Defaults to `","`. Char delimiter to + separate fields in a record. + use_quote_delim: An optional bool. Defaults to `True`. If false, treats + double quotation marks as regular characters inside of the string fields. + na_value: Additional string to recognize as NA/NaN. + header: A bool that indicates whether the first rows of provided CSV files + correspond to header lines with column names, and should not be included + in the data. + comment: An optional character string that marks lines that should not be + parsed as csv records. If this is provided, all lines that start with + this character will not be parsed. + num_epochs: An int specifying the number of times this dataset is repeated. + If None, cycles through the dataset forever. + shuffle: A bool that indicates whether the input should be shuffled. + shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size + ensures better shuffling, but would increase memory usage and startup + time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: An int specifying the number of feature batches to + prefetch for performance improvement. Recommended value is the number of + batches consumed per training step. + default_float_type: Either `tf.float32` or `tf.float64`. If defaults are + not provided, float-like strings are interpreted to be this type. + num_rows_for_inference: Number of rows of a file to use for type inference + if record_defaults is not provided. If None, reads all the rows of all + the files. Defaults to 100. + + Returns: + A dataset, where each element is a (features, labels) tuple that corresponds + to a batch of `batch_size` CSV rows. The features dictionary maps feature + column names to `Tensor`s containing the corresponding column data, and + labels is a `Tensor` containing the column data for the label column + specified by `label_name`. + + Raises: + ValueError: If any of the arguments is malformed. + """ + filenames = _get_file_names(file_pattern, shuffle) + if comment is not None and len(comment) != 1: + raise ValueError("`comment` arg must be a single-character string or None") + + # Clean arguments; figure out column names and defaults + if column_names is None: + if not header: + raise ValueError("Cannot infer column names without a header line.") + # If column names are not provided, infer from the header lines + column_names = _infer_column_names(filenames, field_delim, use_quote_delim) + if len(column_names) != len(set(column_names)): + raise ValueError("Cannot have duplicate column names.") + + if column_defaults is not None: + column_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in column_defaults + ] + else: + # If column defaults are not provided, infer from records at graph + # construction time + column_defaults = _infer_column_defaults( + filenames, len(column_names), field_delim, use_quote_delim, na_value, + header, comment, default_float_type, num_rows_for_inference) + + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if label_name is not None and label_name not in column_names: + raise ValueError("`label_name` provided must be one of the columns.") + + # Define map and filter functions + def filter_fn(line): + return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) + + def filename_to_dataset(filename): + ds = core_readers.TextLineDataset(filename) + if header: + ds = ds.skip(1) + if comment is not None: + ds = ds.filter(filter_fn) + return ds + + def decode_csv(line): + """Decodes CSV line into features. + + Args: + line: String tensor corresponding to one csv record. + Returns: + A dictionary of feature names to values for that particular record. If + label_name is provided, extracts the label feature to be returned as the + second element of the tuple. + """ + columns = parsing_ops.decode_csv( + line, + column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + ) + features = dict(zip(column_names, columns)) + if label_name is not None: + label = features.pop(label_name) + return features, label + return features + + # TODO(rachelim): interleave records from files for better shuffling + dataset = dataset.flat_map(filename_to_dataset) + # TODO(rachelim): use fused shuffle_and_repeat for perf + if shuffle: + dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + if num_epochs != 1: + dataset = dataset.repeat(num_epochs) + + dataset = dataset.batch(batch_size) + dataset = dataset.map(decode_csv) + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +def make_batched_features_dataset(file_pattern, + batch_size, + features, + reader=core_readers.TFRecordDataset, + reader_args=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + reader_num_threads=1, + parser_num_threads=2, + sloppy_ordering=False): + """Returns a `Dataset` of feature dictionaries from `Example` protos. + + Example: + + ``` + serialized_examples = [ + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } + }, + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } + } + ] + ``` + + We can use arguments: + + ``` + features: { + "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), + "gender": FixedLenFeature([], dtype=tf.string), + "kws": VarLenFeature(dtype=tf.string), + } + ``` + And the expected output is: + + ```python + { + "age": [[0], [-1]], + "gender": [["f"], ["f"]], + "kws": SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["code", "art", "sports"] + dense_shape=[2, 2]), + } + ``` + Args: + file_pattern: List of files or patterns of file paths containing + `Example` records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of consecutive elements of this + dataset to combine in a single batch. + features: A `dict` mapping feature keys to `FixedLenFeature` or + `VarLenFeature` values. See `tf.parse_example`. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. + reader_args: Additional arguments to pass to the reader class. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. Defaults to `None`. + shuffle: A boolean, indicates whether the input should be shuffled. Defaults + to `True`. + shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity + ensures better shuffling but would increase memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: Number of feature batches to prefetch in order to + improve performance. Recommended value is the number of batches consumed + per training step (default is 1). + reader_num_threads: Number of threads used to read `Example` records. If >1, + the results will be interleaved. + parser_num_threads: Number of threads to use for parsing `Example` tensors + into a dictionary of `Feature` tensors. + sloppy_ordering: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + + Returns: + A dataset of `dict` elements. Each `dict` maps feature keys to + `Tensor` or `SparseTensor` objects. + """ + # Create dataset of all matching filenames + if shuffle: + dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=True) + else: + # TODO(b/73959787): Use Dataset.list_files() once ordering is deterministic. + filenames = _get_file_names(file_pattern, shuffle) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + + # Read `Example` records from files as tensor objects. + if reader_args is None: + reader_args = [] + + # Read files sequentially (if reader_num_threads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + lambda filename: reader(filename, *reader_args), + cycle_length=reader_num_threads, + sloppy=sloppy_ordering)) + + # Extract values if the `Example` tensors are stored as key-value tuples. + if dataset.output_types == (dtypes.string, dtypes.string): + dataset = dataset.map(lambda _, v: v) + + # Apply dataset repeat and shuffle transformations. + repeat_dataset = (num_epochs != 1) + if repeat_dataset and shuffle: + # Used fused shuffle_and_repeat operation for better performance + dataset = dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif repeat_dataset: + dataset = dataset.repeat(num_epochs) + elif shuffle: + dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + + dataset = dataset.batch(batch_size) + + # Parse `Example` tensors to a dictionary of `Feature` tensors. + dataset = dataset.map( + lambda x: parsing_ops.parse_example(x, features), + num_parallel_calls=parser_num_threads) + + # TODO(rachelim): Add an optional label_name argument for extracting the label + # from the features dictionary, to comply with the type expected by the + # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function. + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +@deprecation.deprecated(None, + "Use `tf.contrib.data.make_batched_features_dataset`") def read_batch_features(file_pattern, batch_size, features, - reader, + reader=core_readers.TFRecordDataset, reader_args=None, randomize_input=True, num_epochs=None, @@ -84,43 +523,38 @@ def read_batch_features(file_pattern, dataset to combine in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. - reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of Examples. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. - capacity: Capacity of the ShuffleDataset. A large capacity ensures better + capacity: Buffer size of the ShuffleDataset. A large capacity ensures better shuffling but would increase memory usage and startup time. - Returns: A dict from keys in features to `Tensor` or `SparseTensor` objects. """ - filenames = _get_file_names(file_pattern, randomize_input) - if reader_args: - dataset = reader(filenames, *reader_args) - else: - dataset = reader(filenames) - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) - if num_epochs != 1: - dataset = dataset.repeat(num_epochs) - if randomize_input: - dataset = dataset.shuffle(capacity) - dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) - dataset = dataset.prefetch(1) + dataset = make_batched_features_dataset( + file_pattern, + batch_size, + features, + reader=reader, + reader_args=reader_args, + shuffle=randomize_input, + num_epochs=num_epochs, + shuffle_buffer_size=capacity) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs -def _get_file_names(file_pattern, randomize_input): +def _get_file_names(file_pattern, shuffle): """Parse list of file names from pattern, optionally shuffled. Args: file_pattern: File glob pattern, or list of glob patterns. - randomize_input: Whether to shuffle the order of file names. + shuffle: Whether to shuffle the order of file names. Returns: List of file names matching `file_pattern`. @@ -141,7 +575,7 @@ def _get_file_names(file_pattern, randomize_input): raise ValueError("No files match %s." % file_pattern) # Sort files so it will be deterministic for unit tests. - if not randomize_input: + if not shuffle: file_names = sorted(file_names) return file_names diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 56f526a330bfbea7305b0754bfd114c5e97db506..b465397437adbdfaf865efb8ed2f80e57f48fcab 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -54,7 +54,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" dist_estimation_batch_size = 32 - target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") + target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") class_values_ds = dataset.map(class_func) if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") @@ -101,14 +101,16 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): initial_dist_ds)) .map(maybe_warn_on_large_rejection)) - current_probabilities_ds = dataset_ops.Dataset.zip( - (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) + def _gather_and_copy(class_val, acceptance_prob, data): + return (class_val, array_ops.gather(acceptance_prob, class_val), data) + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) filtered_ds = ( - dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds, - dataset)) + current_probabilities_and_class_and_data_ds .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + return _apply_fn @@ -151,7 +153,7 @@ def _calculate_acceptance_probs(initial_probs, target_probs): ``` - A solution for a_i in terms of the other variabes is the following: + A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` """ # Add tiny to initial_probs to avoid divide by zero. diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 99bb79bc06a421f811869ca9169aaa11deaca2f3..f35795abd38000b13cec0f08596e2ff66e86286c 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -19,11 +19,11 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.ops import gen_dataset_ops @@ -45,17 +45,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): else: self._count = ops.convert_to_tensor( count, dtype=dtypes.int64, name="count") - - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) def _as_variant_tensor(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py new file mode 100644 index 0000000000000000000000000000000000000000..19cc3cb89fc5c494f79ce1d25ed57c92099c8bd2 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sliding dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class _SlideDataset(dataset_ops.Dataset): + """A `Dataset` that passes a sliding window over its input.""" + + def __init__(self, input_dataset, window_size, stride=1): + """See `sliding_window_batch` for details.""" + super(_SlideDataset, self).__init__() + self._input_dataset = input_dataset + self._window_size = ops.convert_to_tensor( + window_size, dtype=dtypes.int64, name="window_size") + self._stride = ops.convert_to_tensor( + stride, dtype=dtypes.int64, name="stride") + + def _as_variant_tensor(self): + return gen_dataset_ops.slide_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + window_size=self._window_size, + stride=self._stride, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + input_shapes = self._input_dataset.output_shapes + return nest.pack_sequence_as(input_shapes, [ + tensor_shape.vector(None).concatenate(s) + for s in nest.flatten(self._input_dataset.output_shapes) + ]) + + @property + def output_types(self): + return self._input_dataset.output_types + + +def sliding_window_batch(window_size, stride=1): + """A sliding window with size of `window_size` and step of `stride`. + + This transformation passes a sliding window over this dataset. The + window size is `window_size` and step size is `stride`. If the left + elements cannot fill up the sliding window, this transformation will + drop the final smaller element. For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { [1], [2], [3], [4], [5], [6] } + + a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) == + { + [[1], [2], [3]], + [[3], [4], [5]], + } + ``` + + Args: + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + elements in the sliding window. + stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + steps moving the sliding window forward for one iteration. The default + is `1`. It must be in `[1, window_size)`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _SlideDataset(dataset, window_size, stride) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 9cd1701c397b5a0bf5cc47c1bcab033704794d80..b5cf0fcfe91ebc22444302fca5d488a278ef2994 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -47,7 +47,7 @@ class StatsAggregator(object): dataset = ... iterator = dataset.make_one_shot_iterator() stats_aggregator = stats_ops.StatsAggregator() - set_op = stats_op.set_stats_aggregator_op(iterator, stats_aggregator) + set_op = stats_aggregator.subscribe(iterator) with tf.Session() as sess: # Running `set_op` will associate `iterator` with `stats_aggregator`. diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 3f85aa84cd53fcf5e21480aac96e067766ad1b65..56f67e1766bbaff680bdff6b939df0c3ba68c679 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -44,7 +44,7 @@ class PrivateThreadPool(object): def __init__(self, num_threads, display_name=None): """Creates a `PrivateThreadPool` with the given number of threads.""" - if context.in_eager_mode(): + if context.executing_eagerly(): shared_name = _generate_shared_name("privatethreadpool") self._resource = gen_dataset_ops.thread_pool_handle( num_threads=num_threads, diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index f6de5998d73a4869d2444cd90c9b64d1a2c889ac..ae3847b8b62452b1afbe472fcb6369181ec60b73 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -25,7 +25,6 @@ tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], cc_api_version = 2, - go_api_version = 2, java_api_version = 2, visibility = ["//visibility:public"], ) @@ -34,7 +33,6 @@ tf_proto_library( name = "generic_tree_model_extensions", srcs = ["generic_tree_model_extensions.proto"], cc_api_version = 2, - go_api_version = 2, protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 35dd2ee439be401bdc0ce971a9dd71ed0cb411d0..1c381cc354fa4e5a630cfb5025dfd4bddf04a71c 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -251,6 +251,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_test", + srcs = ["python/kernel_tests/kumaraswamy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "moving_stats_test", size = "small", @@ -335,6 +350,7 @@ cuda_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", ], + tags = ["nomsan"], ) cuda_py_test( @@ -459,6 +475,25 @@ cuda_py_test( tags = ["nomsan"], # disable to avoid false positives from scipy. ) +cuda_py_test( + name = "statistical_testing_test", + size = "medium", + srcs = [ + "python/kernel_tests/statistical_testing_test.py", + ], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + ], + tags = [ + "manual", + "noasan", + "noguitar", + "optonly", + ], +) + cuda_py_test( name = "vector_sinh_arcsinh_diag_test", size = "medium", @@ -782,6 +817,25 @@ cuda_py_test( tags = ["noasan"], # times out b/63678675 ) +cuda_py_test( + name = "affine_scalar_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/affine_scalar_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "affine_linear_operator_test", size = "small", @@ -801,6 +855,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "batch_normalization_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "chain_test", size = "small", @@ -915,6 +985,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_bijector_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "masked_autoregressive_test", size = "small", @@ -1017,10 +1106,12 @@ cuda_py_test( ], ) +# Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to +# avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( - name = "sigmoid_centered_test", + name = "sinh_arcsinh_bijector_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sigmoid_centered_test.py"], + srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1036,12 +1127,10 @@ cuda_py_test( ], ) -# Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to -# avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( - name = "sinh_arcsinh_bijector_test", + name = "softmax_centered_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py"], + srcs = ["python/kernel_tests/bijectors/softmax_centered_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1058,9 +1147,9 @@ cuda_py_test( ) cuda_py_test( - name = "softmax_centered_test", + name = "softplus_test", size = "small", - srcs = ["python/kernel_tests/bijectors/softmax_centered_test.py"], + srcs = ["python/kernel_tests/bijectors/softplus_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1077,9 +1166,9 @@ cuda_py_test( ) cuda_py_test( - name = "softplus_test", + name = "square_test", size = "small", - srcs = ["python/kernel_tests/bijectors/softplus_test.py"], + srcs = ["python/kernel_tests/bijectors/square_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py new file mode 100644 index 0000000000000000000000000000000000000000..16173a166fd943413345036df12245c2a4ab8343 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py @@ -0,0 +1,153 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Affine Scalar Tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import AffineScalar +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class AffineScalarBijectorTest(test.TestCase): + """Tests correctness of the Y = scale @ x + shift transformation.""" + + def testProperties(self): + with self.test_session(): + mu = -1. + # scale corresponds to 1. + bijector = AffineScalar(shift=mu) + self.assertEqual("affine_scalar", bijector.name) + + def testNoBatchScalar(self): + with self.test_session() as sess: + + def static_run(fun, x): + return fun(x).eval() + + def dynamic_run(fun, x_value): + x_value = np.array(x_value) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = -1. + # Corresponds to scale = 2 + bijector = AffineScalar(shift=mu, scale=2.) + x = [1., 2, 3] # Three scalar samples (no batches). + self.assertAllClose([1., 3, 5], run(bijector.forward, x)) + self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) + self.assertAllClose([-np.log(2.)] * 3, + run(bijector.inverse_log_det_jacobian, x)) + + def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): + with self.test_session() as sess: + + def static_run(fun, x): + return fun(x).eval() + + def dynamic_run(fun, x_value): + x_value = np.array(x_value).astype(np.float64) + x = array_ops.placeholder(dtypes.float64, name="x") + return sess.run(fun(x), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = np.float64([1.]) + # One batch, scalar. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu) + x = np.float64([1.]) # One sample from one batches. + self.assertAllClose([2.], run(bijector.forward, x)) + self.assertAllClose([0.], run(bijector.inverse, x)) + self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x)) + + def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self): + with self.test_session() as sess: + + def static_run(fun, x): + return fun(x).eval() + + def dynamic_run(fun, x_value): + x_value = np.array(x_value).astype(np.float64) + x = array_ops.placeholder(dtypes.float64, name="x") + return sess.run(fun(x), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + multiplier = np.float64([2.]) + # One batch, scalar. + # Corresponds to scale = 2, shift = 0. + bijector = AffineScalar(scale=multiplier) + x = np.float64([1.]) # One sample from one batches. + self.assertAllClose([2.], run(bijector.forward, x)) + self.assertAllClose([0.5], run(bijector.inverse, x)) + self.assertAllClose([np.log(0.5)], + run(bijector.inverse_log_det_jacobian, x)) + + def testTwoBatchScalarIdentityViaIdentity(self): + with self.test_session() as sess: + + def static_run(fun, x): + return fun(x).eval() + + def dynamic_run(fun, x_value): + x_value = np.array(x_value) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = [1., -1] + # Univariate, two batches. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu) + x = [1., 1] # One sample from each of two batches. + self.assertAllClose([2., 0], run(bijector.forward, x)) + self.assertAllClose([0., 2], run(bijector.inverse, x)) + self.assertAllClose([0., 0.], run(bijector.inverse_log_det_jacobian, x)) + + def testTwoBatchScalarIdentityViaScale(self): + with self.test_session() as sess: + + def static_run(fun, x): + return fun(x).eval() + + def dynamic_run(fun, x_value): + x_value = np.array(x_value) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = [1., -1] + # Univariate, two batches. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu, scale=[2., 1]) + x = [1., 1] # One sample from each of two batches. + self.assertAllClose([3., 0], run(bijector.forward, x)) + self.assertAllClose([0., 2], run(bijector.inverse, x)) + self.assertAllClose( + [-np.log(2), 0.], run(bijector.inverse_log_det_jacobian, x)) + + def testScalarCongruency(self): + with self.test_session(): + bijector = AffineScalar(shift=3.6, scale=0.42) + assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index c9158117f7a982e37047e8dd2b534a30040a87d9..077e6176b4e7aecb28369d49edad6d1367cc7259 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -25,7 +25,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -36,192 +35,9 @@ class AffineBijectorTest(test.TestCase): with self.test_session(): mu = -1. # scale corresponds to 1. - bijector = Affine(shift=mu, event_ndims=0) + bijector = Affine(shift=mu) self.assertEqual("affine", bijector.name) - def testNoBatchScalarViaIdentity(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2 - bijector = Affine( - shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 2, 3] # Three scalar samples (no batches). - self.assertAllClose([1., 3, 5], run(bijector.forward, x)) - self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testNoBatchScalarViaDiag(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2 - bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 2, 3] # Three scalar samples (no batches). - self.assertAllClose([1., 3, 5], run(bijector.forward, x)) - self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testWeirdSampleNoBatchScalarViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2. - bijector = Affine( - shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [[1., 2, 3], [4, 5, 6]] # Weird sample shape. - self.assertAllClose([[1., 3, 5], - [7, 9, 11]], - run(bijector.forward, x)) - self.assertAllClose([[1., 1.5, 2.], - [2.5, 3, 3.5]], - run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value).astype(np.float64) - x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = np.float64([1.]) - # One batch, scalar. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = np.float64([1.]) # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaIdentityIn64BitUserProvidesMultiplierOnly(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value).astype(np.float64) - x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - multiplier = np.float64([2.]) - # One batch, scalar. - # Corresponds to scale = 2, shift = 0. - bijector = Affine(scale_identity_multiplier=multiplier, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = np.float64([1.]) # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.5], run(bijector.inverse, x)) - self.assertAllClose([np.log(0.5)], - run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1.] - # One batch, scalar. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1.] # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testTwoBatchScalarIdentityViaIdentity(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1., -1] - # Univariate, two batches. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 1] # One sample from each of two batches. - self.assertAllClose([2., 0], run(bijector.forward, x)) - self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testTwoBatchScalarIdentityViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1., -1] - # Univariate, two batches. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 1] # One sample from each of two batches. - self.assertAllClose([2., 0], run(bijector.forward, x)) - self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - def testNoBatchMultivariateIdentity(self): with self.test_session() as sess: @@ -238,7 +54,6 @@ class AffineBijectorTest(test.TestCase): # Multivariate # Corresponds to scale = [[1., 0], [0, 1.]] bijector = Affine(shift=mu) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 1] # matmul(sigma, x) + shift # = [-1, -1] + [1, -1] @@ -269,7 +84,6 @@ class AffineBijectorTest(test.TestCase): # Multivariate # Corresponds to scale = [[2., 0], [0, 1.]] bijector = Affine(shift=mu, scale_diag=[2., 1]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 1] # matmul(sigma, x) + shift # = [-1, -1] + [1, -1] @@ -297,22 +111,17 @@ class AffineBijectorTest(test.TestCase): x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") - event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims") x_value = np.array([[1., 1]], dtype=np.float32) mu_value = np.array([1., -1], dtype=np.float32) scale_diag_value = np.array([2., 2], dtype=np.float32) - event_ndims_value = np.array(1, dtype=np.int32) feed_dict = { x: x_value, mu: mu_value, scale_diag: scale_diag_value, - event_ndims: event_ndims_value } - bijector = Affine( - shift=mu, scale_diag=scale_diag, event_ndims=event_ndims) - self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict)) + bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict)) self.assertAllClose( @@ -335,7 +144,6 @@ class AffineBijectorTest(test.TestCase): # Corresponds to 1 2x2 matrix, with twos on the diagonal. scale = 2. bijector = Affine(shift=mu, scale_identity_multiplier=scale) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) @@ -358,7 +166,6 @@ class AffineBijectorTest(test.TestCase): # Corresponds to 1 2x2 matrix, with twos on the diagonal. scale_diag = [[2., 2]] bijector = Affine(shift=mu, scale_diag=scale_diag) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) @@ -370,23 +177,18 @@ class AffineBijectorTest(test.TestCase): x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") - event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims") x_value = np.array([[[1., 1]]], dtype=np.float32) mu_value = np.array([[1., -1]], dtype=np.float32) scale_diag_value = np.array([[2., 2]], dtype=np.float32) - event_ndims_value = 1 feed_dict = { x: x_value, mu: mu_value, scale_diag: scale_diag_value, - event_ndims: event_ndims_value } - bijector = Affine( - shift=mu, scale_diag=scale_diag, event_ndims=event_ndims) - self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict)) + bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict)) self.assertAllClose([-np.log(4)], @@ -410,9 +212,7 @@ class AffineBijectorTest(test.TestCase): bijector = Affine( shift=mu, scale_identity_multiplier=1., - scale_diag=[1., 1., 1.], - event_ndims=1) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" + scale_diag=[1., 1., 1.]) x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) @@ -437,7 +237,6 @@ class AffineBijectorTest(test.TestCase): shift=mu, scale_identity_multiplier=1., scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 5]], run(bijector.forward, x)) self.assertAllClose([[1., 0.5]], run(bijector.inverse, x)) @@ -460,7 +259,6 @@ class AffineBijectorTest(test.TestCase): # scale = [[2., 0], [2, 3]] bijector = Affine( shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 7]], run(bijector.forward, x)) self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x)) @@ -486,7 +284,6 @@ class AffineBijectorTest(test.TestCase): scale_identity_multiplier=1.0, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[2., 9]], run(bijector.forward, x)) self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x)) @@ -514,7 +311,6 @@ class AffineBijectorTest(test.TestCase): scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = Affine(shift=mu, scale_diag=[10., 2, 3]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 3, 8], run(bijector.forward, x)) self.assertAllClose( @@ -550,7 +346,6 @@ class AffineBijectorTest(test.TestCase): scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = Affine(shift=mu, scale_diag=[10., 3, 5]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 5, 14], run(bijector.forward, x)) self.assertAllClose( @@ -586,7 +381,6 @@ class AffineBijectorTest(test.TestCase): bijector_ref = Affine( shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 6, 22], run(bijector.forward, x)) self.assertAllClose( @@ -622,7 +416,6 @@ class AffineBijectorTest(test.TestCase): bijector_ref = Affine( shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([5., 6, 22], run(bijector.forward, x)) self.assertAllClose( @@ -647,38 +440,6 @@ class AffineBijectorTest(test.TestCase): with self.assertRaisesOpError("diagonal part must be non-zero"): bijector.forward([1., 1.]).eval() - def testEventNdimsLargerThanOneRaises(self): - with self.test_session(): - mu = [1., -1] - with self.assertRaisesRegexp( - ValueError, (r"event_ndims\(2\) was not 0 or 1")): - # Scale corresponds to 2x2 identity matrix. - bijector = Affine(shift=mu, event_ndims=2, validate_args=True) - bijector.forward([1., 1.]).eval() - - def testScaleZeroScalarRaises(self): - with self.test_session(): - mu = -1. - # Check Identity matrix with zero scaling. - bijector = Affine( - shift=mu, - scale_identity_multiplier=0., - event_ndims=0, - validate_args=True) - with self.assertRaisesOpError("identity_multiplier should be non-zero"): - bijector.forward(1.).eval() - - def testScaleDiagAndEventNdimsZeroRaises(self): - # Check Diag matrix with zero scaling. - with self.assertRaisesRegexp(ValueError, "only scale argument"): - Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True) - - def testScalarCongruency(self): - with self.test_session(): - bijector = Affine( - shift=3.6, scale_identity_multiplier=0.42, event_ndims=0) - assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) - def _makeScale(self, x, scale_identity_multiplier=None, @@ -747,14 +508,12 @@ class AffineBijectorTest(test.TestCase): scale_args = dict({"x": x}, **args) scale = self._makeScale(**scale_args) - bijector_args = dict({"event_ndims": 1}, **args) - # We haven't specified enough information for the scale. if scale is None: with self.assertRaisesRegexp(ValueError, ("must be specified.")): - bijector = Affine(shift=shift, **bijector_args) + bijector = Affine(shift=shift, **args) else: - bijector = Affine(shift=shift, **bijector_args) + bijector = Affine(shift=shift, **args) np_x = x # For the case a vector is passed in, we need to make the shape # match the matrix for matmul to work. @@ -829,15 +588,5 @@ class AffineBijectorTest(test.TestCase): x=np.array( [1., 2], dtype=np.float32)) - def testScalarEventIdentityScale(self): - with self.test_session() as sess: - doubler = Affine( - scale_identity_multiplier=2., - event_ndims=0) - doubler2 = doubler.inverse_log_det_jacobian(2.) - doubler2_ildj_ = sess.run([doubler2]) - self.assertAllClose([-np.log(2.)], doubler2_ildj_) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a215a4a2b1ffbea7951bdb9b4352ed567e0b1e41 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py @@ -0,0 +1,236 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BatchNorm Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import distributions +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import BatchNormalization +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +class BatchNormTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + def _reduction_axes(self, input_shape, event_dims): + if isinstance(event_dims, int): + event_dims = [event_dims] + ndims = len(input_shape) + # Convert event_dims to non-negative indexing. + event_dims = list(event_dims) + for idx, x in enumerate(event_dims): + if x < 0: + event_dims[idx] = ndims + x + return tuple(i for i in range(ndims) if i not in event_dims) + + def testForwardInverse(self): + """Tests forward and backward passes with different event shapes. + + input_shape: Tuple of shapes for input tensor. + event_dims: Tuple of dimension indices that will be normalized. + training: Boolean of whether bijector runs in training or inference mode. + """ + params = [ + ((5*2, 4), [-1], False), + ((5, 2, 4), [-1], False), + ((5, 2, 4), [1, 2], False), + ((5, 2, 4), [0, 1], False), + ((5*2, 4), [-1], True), + ((5, 2, 4), [-1], True), + ((5, 2, 4), [1, 2], True), + ((5, 2, 4), [0, 1], True) + ] + for input_shape, event_dims, training in params: + x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape) + with self.test_session() as sess: + x = constant_op.constant(x_) + # When training, memorize the exact mean of the last + # minibatch that it normalized (instead of moving average assignment). + layer = normalization.BatchNormalization( + axis=event_dims, momentum=0., epsilon=0.) + batch_norm = BatchNormalization( + batchnorm_layer=layer, training=training) + # Minibatch statistics are saved only after norm_x has been computed. + norm_x = batch_norm.inverse(x) + with ops.control_dependencies(batch_norm.batchnorm.updates): + moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) + moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) + denorm_x = batch_norm.forward(array_ops.identity(norm_x)) + fldj = batch_norm.forward_log_det_jacobian(x) + # Use identity to invalidate cache. + ildj = batch_norm.inverse_log_det_jacobian( + array_ops.identity(denorm_x)) + variables.global_variables_initializer().run() + # Update variables. + norm_x_ = sess.run(norm_x) + [ + norm_x_, + moving_mean_, + moving_var_, + denorm_x_, + ildj_, + fldj_, + ] = sess.run([ + norm_x, + moving_mean, + moving_var, + denorm_x, + ildj, + fldj, + ]) + self.assertEqual("batch_normalization", batch_norm.name) + + reduction_axes = self._reduction_axes(input_shape, event_dims) + keepdims = len(event_dims) > 1 + + expected_batch_mean = np.mean( + x_, axis=reduction_axes, keepdims=keepdims) + expected_batch_var = np.var(x_, axis=reduction_axes, keepdims=keepdims) + + if training: + # When training=True, values become normalized across batch dim and + # original values are recovered after de-normalizing. + zeros = np.zeros_like(norm_x_) + self.assertAllClose(np.mean(zeros, axis=reduction_axes), + np.mean(norm_x_, axis=reduction_axes)) + + self.assertAllClose(expected_batch_mean, moving_mean_) + self.assertAllClose(expected_batch_var, moving_var_) + self.assertAllClose(x_, denorm_x_, atol=1e-5) + # Since moving statistics are set to batch statistics after + # normalization, ildj and -fldj should match. + self.assertAllClose(ildj_, -fldj_) + # ildj is computed with minibatch statistics. + expected_ildj = np.sum(np.log(1.) - .5 * np.log( + expected_batch_var + batch_norm.batchnorm.epsilon)) + self.assertAllClose(expected_ildj, ildj_) + else: + # When training=False, moving_mean, moving_var remain at their + # initialized values (0., 1.), resulting in no scale/shift (a small + # shift occurs if epsilon > 0.) + self.assertAllClose(x_, norm_x_) + self.assertAllClose(x_, denorm_x_, atol=1e-5) + # ildj is computed with saved statistics. + expected_ildj = np.sum( + np.log(1.) - .5 * np.log(1. + batch_norm.batchnorm.epsilon)) + self.assertAllClose(expected_ildj, ildj_) + + def testMaximumLikelihoodTraining(self): + # Test Maximum Likelihood training with default bijector. + with self.test_session() as sess: + base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) + batch_norm = BatchNormalization(training=True) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=base_dist, + bijector=batch_norm) + target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.]) + target_samples = target_dist.sample(100) + dist_samples = dist.sample(3000) + loss = -math_ops.reduce_mean(dist.log_prob(target_samples)) + with ops.control_dependencies(batch_norm.batchnorm.updates): + train_op = adam.AdamOptimizer(1e-2).minimize(loss) + moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) + moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) + variables.global_variables_initializer().run() + for _ in range(3000): + sess.run(train_op) + [ + dist_samples_, + moving_mean_, + moving_var_ + ] = sess.run([ + dist_samples, + moving_mean, + moving_var + ]) + self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2) + self.assertAllClose([1., 2.], moving_mean_, atol=5e-2) + self.assertAllClose([1., 1.], moving_var_, atol=5e-2) + + def testLogProb(self): + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) + base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=base_dist, + bijector=batch_norm, + validate_args=True) + samples = dist.sample(int(1e5)) + # No volume distortion since training=False, bijector is initialized + # to the identity transformation. + base_log_prob = base_dist.log_prob(samples) + dist_log_prob = dist.log_prob(samples) + variables.global_variables_initializer().run() + base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob]) + self.assertAllClose(base_log_prob_, dist_log_prob_) + + def testMutuallyConsistent(self): + # BatchNorm bijector is only mutually consistent when training=False. + dims = 4 + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=batch_norm, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=2., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + # BatchNorm bijector is only mutually consistent when training=False. + dims = 4 + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = Invert( + BatchNormalization(batchnorm_layer=layer, training=False)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=batch_norm, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=2., + center=0., + rtol=0.02) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index 20e754308449af3f0399101f4ea1bb47b3356424..a748acd667e58f9b527bab11d8bc4d086996e9f3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -66,12 +66,10 @@ class ChainBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): bijector = Chain([ - SoftmaxCentered( - event_ndims=1, validate_args=True), - SoftmaxCentered( - event_ndims=0, validate_args=True) + SoftmaxCentered(validate_args=True), + SoftmaxCentered(validate_args=True), ]) - x = tensor_shape.TensorShape([]) + x = tensor_shape.TensorShape([1]) y = tensor_shape.TensorShape([2 + 1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) self.assertAllEqual( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index 0ff35304283fce9ce3f9e5d31b1258394e384d7b..f392e83d2c3da9dac43c2e87070e952ae2060b34 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -18,70 +18,111 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib -from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test -class InvertBijectorTest(test.TestCase): - """Tests the correctness of the Y = Invert(bij) transformation.""" +class CholeskyOuterProductBijectorTest(test.TestCase): + """Tests the correctness of the Y = X @ X.T transformation.""" - def testBijector(self): + def testBijectorMatrix(self): with self.test_session(): - for fwd in [ - bijectors.Identity(), - bijectors.Exp(event_ndims=1), - bijectors.Affine( - shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1), - bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), - ]: - rev = bijectors.Invert(fwd) - self.assertEqual("_".join(["invert", fwd.name]), rev.name) - x = [[[1., 2.], - [2., 3.]]] - self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval()) - self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval()) - self.assertAllClose( - fwd.forward_log_det_jacobian(x).eval(), - rev.inverse_log_det_jacobian(x).eval()) - self.assertAllClose( - fwd.inverse_log_det_jacobian(x).eval(), - rev.forward_log_det_jacobian(x).eval()) + bijector = bijectors.CholeskyOuterProduct(validate_args=True) + self.assertEqual("cholesky_outer_product", bijector.name) + x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]] + y = np.matmul(x, np.transpose(x, axes=(0, 2, 1))) + # Fairly easy to compute differentials since we have 2x2. + dx_dy = [[[2. * 1, 0, 0], + [2, 1, 0], + [0, 2 * 2, 2 * 1]], + [[2 * np.sqrt(2.), 0, 0], + [np.sqrt(8.), np.sqrt(2.), 0], + [0, 2 * np.sqrt(8.), 2 * 1]]] + ildj = -np.sum( + np.log(np.asarray(dx_dy).diagonal( + offset=0, axis1=1, axis2=2)), + axis=1) + self.assertAllEqual((2, 2, 2), bijector.forward(x).get_shape()) + self.assertAllEqual((2, 2, 2), bijector.inverse(y).get_shape()) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y).eval(), + bijector.forward_log_det_jacobian(x).eval(), + atol=0., + rtol=1e-7) - def testScalarCongruency(self): - with self.test_session(): - bijector = bijectors.Invert(bijectors.Exp()) - assert_scalar_congruency( - bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + def testNoBatchStatic(self): + x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) + y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) + with self.test_session() as sess: + y_actual = bijectors.CholeskyOuterProduct().forward(x=x) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) + self.assertAllEqual([2, 2], y_actual.get_shape()) + self.assertAllEqual([2, 2], x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) - def testShapeGetters(self): - with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) - x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) - self.assertAllEqual(y, bijector.forward_event_shape(x)) - self.assertAllEqual( - y.as_list(), - bijector.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) - self.assertAllEqual( - x.as_list(), - bijector.inverse_event_shape_tensor(y.as_list()).eval()) + def testNoBatchDeferred(self): + x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) + y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtypes.float32) + y_pl = array_ops.placeholder(dtypes.float32) + y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual], + feed_dict={x_pl: x, y_pl: y}) + self.assertEqual(None, y_actual.get_shape()) + self.assertEqual(None, x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) - def testDocstringExample(self): - with self.test_session(): - exp_gamma_distribution = ( - transformed_distribution_lib.TransformedDistribution( - distribution=gamma_lib.Gamma(concentration=1., rate=2.), - bijector=bijectors.Invert(bijectors.Exp()))) - self.assertAllEqual( - [], array_ops.shape(exp_gamma_distribution.sample()).eval()) + def testBatchStatic(self): + x = np.array([[[1., 0], + [2, 1]], + [[3., 0], + [1, 2]]]) # np.linalg.cholesky(y) + y = np.array([[[1., 2], + [2, 5]], + [[9., 3], + [3, 5]]]) # np.matmul(x, x.T) + with self.test_session() as sess: + y_actual = bijectors.CholeskyOuterProduct().forward(x=x) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) + self.assertEqual([2, 2, 2], y_actual.get_shape()) + self.assertEqual([2, 2, 2], x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) + + def testBatchDeferred(self): + x = np.array([[[1., 0], + [2, 1]], + [[3., 0], + [1, 2]]]) # np.linalg.cholesky(y) + y = np.array([[[1., 2], + [2, 5]], + [[9., 3], + [3, 5]]]) # np.matmul(x, x.T) + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtypes.float32) + y_pl = array_ops.placeholder(dtypes.float32) + y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual], + feed_dict={x_pl: x, y_pl: y}) + self.assertEqual(None, y_actual.get_shape()) + self.assertEqual(None, x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index 0ff35304283fce9ce3f9e5d31b1258394e384d7b..58ba9cedb1437df4e000ce32fe39664afa76c3b5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -35,11 +35,9 @@ class InvertBijectorTest(test.TestCase): for fwd in [ bijectors.Identity(), bijectors.Exp(event_ndims=1), - bijectors.Affine( - shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1), + bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]), bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), + bijectors.SoftmaxCentered(), ]: rev = bijectors.Invert(fwd) self.assertEqual("_".join(["invert", fwd.name]), rev.name) @@ -62,9 +60,9 @@ class InvertBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) + bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True)) x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) + y = tensor_shape.TensorShape([1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) self.assertAllEqual( y.as_list(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ad11d9f2484c4b08c67c5f82aec1320475d1d983 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Kumaraswamy Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class KumaraswamyBijectorTest(test.TestCase): + """Tests correctness of the Kumaraswamy bijector.""" + + def testBijector(self): + with self.test_session(): + a = 2. + b = 0.3 + bijector = Kumaraswamy( + concentration1=a, concentration0=b, + event_ndims=0, validate_args=True) + self.assertEqual("kumaraswamy", bijector.name) + x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32) + # Kumaraswamy cdf. This is the same as inverse(x). + y = 1. - (1. - x ** a) ** b + self.assertAllClose(y, bijector.inverse(x).eval()) + self.assertAllClose(x, bijector.forward(y).eval()) + kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) + + (b - 1) * np.log1p(-x ** a)) + + self.assertAllClose( + # We should lose a dimension from calculating the determinant of the + # jacobian. + kumaraswamy_log_pdf, + bijector.inverse_log_det_jacobian(x).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(x).eval(), + bijector.forward_log_det_jacobian(y).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Kumaraswamy(concentration1=0.5, concentration0=1.1), + lower_x=0., upper_x=1., n=int(10e3), rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + concentration1 = 1.2 + concentration0 = 2. + bijector = Kumaraswamy( + concentration1=concentration1, + concentration0=concentration0, validate_args=True) + # Omitting the endpoints 0 and 1, since idlj will be inifinity at these + # endpoints. + y = np.linspace(.01, 0.99, num=10).astype(np.float32) + x = 1 - (1 - y ** concentration1) ** concentration0 + assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py deleted file mode 100644 index 4ff3f334ccb59f1c117b3d35032d9e799cfd79bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import SigmoidCentered -from tensorflow.python.platform import test - - -class SigmoidCenteredBijectorTest(test.TestCase): - """Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation.""" - - def testBijector(self): - with self.test_session(): - sigmoid = SigmoidCentered() - self.assertEqual("sigmoid_centered", sigmoid.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] - self.assertAllClose(y, sigmoid.forward(x).eval()) - self.assertAllClose(x, sigmoid.inverse(y).eval()) - self.assertAllClose( - -np.sum(np.log(y), axis=2), - sigmoid.inverse_log_det_jacobian(y).eval(), - atol=0., - rtol=1e-7) - self.assertAllClose( - -sigmoid.inverse_log_det_jacobian(y).eval(), - sigmoid.forward_log_det_jacobian(x).eval(), - atol=0., - rtol=1e-7) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 62e3869db090e9c9327bc552d10234ff76ba28fd..cad4dd1ac8de0da6405aacb9047714b37eec73e3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test @@ -32,22 +34,16 @@ rng = np.random.RandomState(42) class SoftmaxCenteredBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation.""" - def testBijectorScalar(self): + def testBijectorVector(self): with self.test_session(): - softmax = SoftmaxCentered() # scalar by default + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] + x = np.log([[2., 3, 4], [4., 8, 12]]) + y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] self.assertAllClose(y, softmax.forward(x).eval()) self.assertAllClose(x, softmax.inverse(y).eval()) self.assertAllClose( - -np.sum(np.log(y), axis=2), + -np.sum(np.log(y), axis=1), softmax.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7) @@ -57,45 +53,49 @@ class SoftmaxCenteredBijectorTest(test.TestCase): atol=0., rtol=1e-7) - def testBijectorVector(self): + def testBijectorUnknownShape(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], [4., 8, 12]]) - y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] - self.assertAllClose(y, softmax.forward(x).eval()) - self.assertAllClose(x, softmax.inverse(y).eval()) + x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_x = np.log([[2., 3, 4], [4., 8, 12]]) + y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] + self.assertAllClose(real_y, softmax.forward(x).eval( + feed_dict={x: real_x})) + self.assertAllClose(real_x, softmax.inverse(y).eval( + feed_dict={y: real_y})) self.assertAllClose( - -np.sum(np.log(y), axis=1), - softmax.inverse_log_det_jacobian(y).eval(), + -np.sum(np.log(real_y), axis=1), + softmax.inverse_log_det_jacobian(y).eval( + feed_dict={y: real_y}), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), + -softmax.inverse_log_det_jacobian(y).eval( + feed_dict={y: real_y}), + softmax.forward_log_det_jacobian(x).eval( + feed_dict={x: real_x}), atol=0., rtol=1e-7) def testShapeGetters(self): with self.test_session(): - for x, y, b in ((tensor_shape.TensorShape([]), - tensor_shape.TensorShape([2]), - SoftmaxCentered( - event_ndims=0, validate_args=True)), - (tensor_shape.TensorShape([4]), - tensor_shape.TensorShape([5]), - SoftmaxCentered( - event_ndims=1, validate_args=True))): - self.assertAllEqual(y, b.forward_event_shape(x)) - self.assertAllEqual(y.as_list(), - b.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, b.inverse_event_shape(y)) - self.assertAllEqual(x.as_list(), - b.inverse_event_shape_tensor(y.as_list()).eval()) + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([5]) + bijector = SoftmaxCentered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + bijector.forward_event_shape_tensor( + x.as_list()).eval()) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + bijector.inverse_event_shape_tensor( + y.as_list()).eval()) def testBijectiveAndFinite(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32) # Make y values on the simplex with a wide range. y_0 = np.ones(5).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f03d6f1343a11ae4517f9034ceb0c99ca6fe7fa2 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py @@ -0,0 +1,58 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class SquareBijectorTest(test.TestCase): + """Tests the correctness of the Y = X ** 2 transformation.""" + + def testBijectorScalar(self): + with self.test_session(): + bijector = bijectors.Square(validate_args=True) + self.assertEqual("square", bijector.name) + x = [[[1., 5], + [2, 1]], + [[np.sqrt(2.), 3], + [np.sqrt(8.), 1]]] + y = np.square(x) + ildj = -np.log(2.) - np.log(x) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y).eval(), + bijector.forward_log_det_jacobian(x).eval(), + atol=0., + rtol=1e-7) + + def testScalarCongruency(self): + with self.test_session(): + bijector = bijectors.Square(validate_args=True) + assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 507ceb35853ebe0a996d789b3bdf8a5f2284549c..68e0d9cb8277f3953039963fec0da499db7a16d1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import distributions from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -25,23 +27,23 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -ds = distributions +tfd = distributions class DistributionTest(test.TestCase): def testParamShapesAndFromParams(self): classes = [ - ds.Normal, - ds.Bernoulli, - ds.Beta, - ds.Chi2, - ds.Exponential, - ds.Gamma, - ds.InverseGamma, - ds.Laplace, - ds.StudentT, - ds.Uniform, + tfd.Normal, + tfd.Bernoulli, + tfd.Beta, + tfd.Chi2, + tfd.Exponential, + tfd.Gamma, + tfd.InverseGamma, + tfd.Laplace, + tfd.StudentT, + tfd.Uniform, ] sample_shapes = [(), (10,), (10, 20, 30)] @@ -63,15 +65,15 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) - wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], - validate_args=True) + wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]], + validate_args=True) self.assertEqual(wishart.parameters, wishart.copy().parameters) def testCopyOverride(self): with self.test_session(): - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) unused_normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() @@ -84,19 +86,19 @@ class DistributionTest(test.TestCase): mu = 1. sigma = 2. - normal = ds.Normal(mu, sigma, validate_args=True) + normal = tfd.Normal(mu, sigma, validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch())) - normal = ds.Normal([mu], [sigma], validate_args=True) + normal = tfd.Normal([mu], [sigma], validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True) + mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) + mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch())) @@ -126,7 +128,7 @@ class DistributionTest(test.TestCase): self.assertFalse(is_scalar.eval(feed_dict={x: [1]})) def _GetFakeDistribution(self): - class FakeDistribution(ds.Distribution): + class FakeDistribution(tfd.Distribution): """Fake Distribution for testing _set_sample_static_shape.""" def __init__(self, batch_shape=None, event_shape=None): @@ -188,6 +190,105 @@ class DistributionTest(test.TestCase): y = dist._set_sample_static_shape(x, sample_shape) self.assertTrue(y.get_shape().ndims is None) + def testStrWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + ("tf.distributions.Normal(" + "\"Normal\", " + "batch_shape=(), " + "event_shape=(), " + "dtype=float16)"), # Got the dtype right. + str(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + ("tf.distributions.Chi2(" + "\"silly\", " # What a silly name that is! + "batch_shape=(2,), " + "event_shape=(), " + "dtype=float32)"), + str(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("tf.distributions.Exponential(\"Exponential\", " + # No batch shape. + "event_shape=(), " + "dtype=float32)"), + str(exp)) + + def testStrWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN\", " + "batch_shape=(2,), " + "event_shape=(2,), " + "dtype=float64)"), + str(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN2\", " + "batch_shape=(?,), " # Partially known. + "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) + + def testReprWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + (""), # Got the dtype right. + repr(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + (""), + repr(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("" + " event_shape=()" + " dtype=float32>"), + repr(exp)) + + def testReprWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + (""), + repr(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + (""), + repr(mvn_dynamic)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py index 06318ca09dec851cf025fa35c83732b85824cbee..6a69f9e60b99a17c657f074597a075890265a93b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -126,6 +127,100 @@ class ProductDistributionTest(test.TestCase): self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.) + def testKLRaises(self): + ind1 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=1) + ind2 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32(-1), + scale=np.float32(0.5)), + reinterpreted_batch_ndims=0) + + with self.assertRaisesRegexp( + ValueError, "Event shapes do not match"): + kullback_leibler.kl_divergence(ind1, ind2) + + ind1 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=1) + ind2 = independent_lib.Independent( + distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([-1., 1]), + scale_diag=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=0) + + with self.assertRaisesRegexp( + NotImplementedError, "different event shapes"): + kullback_leibler.kl_divergence(ind1, ind2) + + def testKLScalarToMultivariate(self): + normal1 = normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])) + ind1 = independent_lib.Independent( + distribution=normal1, reinterpreted_batch_ndims=1) + + normal2 = normal_lib.Normal( + loc=np.float32([-3., 3]), + scale=np.float32([0.3, 0.3])) + ind2 = independent_lib.Independent( + distribution=normal2, reinterpreted_batch_ndims=1) + + normal_kl = kullback_leibler.kl_divergence(normal1, normal2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(math_ops.reduce_sum(normal_kl, axis=-1)), + self.evaluate(ind_kl)) + + def testKLIdentity(self): + normal1 = normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])) + # This is functionally just a wrapper around normal1, + # and doesn't change any outputs. + ind1 = independent_lib.Independent( + distribution=normal1, reinterpreted_batch_ndims=0) + + normal2 = normal_lib.Normal( + loc=np.float32([-3., 3]), + scale=np.float32([0.3, 0.3])) + # This is functionally just a wrapper around normal2, + # and doesn't change any outputs. + ind2 = independent_lib.Independent( + distribution=normal2, reinterpreted_batch_ndims=0) + + normal_kl = kullback_leibler.kl_divergence(normal1, normal2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(normal_kl), self.evaluate(ind_kl)) + + def testKLMultivariateToMultivariate(self): + # (1, 1, 2) batch of MVNDiag + mvn1 = mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]), + scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]])) + ind1 = independent_lib.Independent( + distribution=mvn1, reinterpreted_batch_ndims=2) + + # (1, 1, 2) batch of MVNDiag + mvn2 = mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]), + scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]])) + + ind2 = independent_lib.Independent( + distribution=mvn2, reinterpreted_batch_ndims=2) + + mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])), + self.evaluate(ind_kl)) + def _testMnistLike(self, static_shape): sample_shape = [4, 5] batch_shape = [10] diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py index ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2..2980e2bfe93b2e2aa01d38fc9fa4650a015efc06 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -130,10 +130,8 @@ class KumaraswamyTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): + with self.assertRaisesOpError("sample must be non-negative"): dist.prob([-1., 0.1, 0.5]).eval() - with self.assertRaisesOpError("sample must be positive"): - dist.prob([0., 0.1, 0.5]).eval() with self.assertRaisesOpError("sample must be no larger than `1`"): dist.prob([.1, .2, 1.2]).eval() @@ -249,13 +247,13 @@ class KumaraswamyTest(test.TestCase): a = np.array([1., 2, 3]) b = np.array([2., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."): dist.mode().eval() a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."): dist.mode().eval() def testKumaraswamyModeEnableAllowNanStats(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3548ac18078a0b40f117c2bf9e2b34d20cee163b --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -0,0 +1,166 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the statistical testing library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import errors +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +class StatisticalTestingTest(test.TestCase): + + def test_dkwm_design_mean_one_sample_soundness(self): + numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] + with self.test_session() as sess: + for ff in rates: + for fp in rates: + sufficient_n = st.min_num_samples_for_dkwm_mean_test( + numbers, 0., 1., false_fail_rate=ff, false_pass_rate=fp) + detectable_d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + sufficient_n, 0., 1., false_fail_rate=ff, false_pass_rate=fp) + sess.run(check_ops.assert_less_equal(detectable_d, numbers)) + + def test_dkwm_design_mean_two_sample_soundness(self): + numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] + with self.test_session() as sess: + for ff in rates: + for fp in rates: + (sufficient_n1, + sufficient_n2) = st.min_num_samples_for_dkwm_mean_two_sample_test( + numbers, 0., 1., 0., 1., + false_fail_rate=ff, false_pass_rate=fp) + d_fn = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample + detectable_d = d_fn( + sufficient_n1, 0., 1., sufficient_n2, 0., 1., + false_fail_rate=ff, false_pass_rate=fp) + sess.run(check_ops.assert_less_equal(detectable_d, numbers)) + + def test_true_mean_confidence_interval_by_dkwm_one_sample(self): + rng = np.random.RandomState(seed=0) + + num_samples = 5000 + # 5000 samples is chosen to be enough to find discrepancies of + # size 0.1 or more with assurance 1e-6, as confirmed here: + with self.test_session() as sess: + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = sess.run(d) + self.assertLess(d, 0.1) + + # Test that the confidence interval computed for the mean includes + # 0.5 and excludes 0.4 and 0.6. + with self.test_session() as sess: + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = sess.run([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) + + def test_dkwm_mean_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is 0.5. + samples = rng.uniform(size=num_samples).astype(np.float32) + with self.test_session() as sess: + sess.run(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 15000 + + # 15000 samples is chosen to be enough to find discrepancies of + # size 0.1 or more with assurance 1e-6, as confirmed here: + with self.test_session() as sess: + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = sess.run(d) + self.assertLess(d, 0.1) + + # Test that the test assertion agrees that the standard + # uniform distribution has the same mean as itself. + samples1 = rng.uniform(size=num_samples).astype(np.float32) + samples2 = rng.uniform(size=num_samples).astype(np.float32) + with self.test_session() as sess: + sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). + beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) + + def test_dkwm_argument_validity_checking(self): + rng = np.random.RandomState(seed=0) + samples = rng.uniform(size=5000).astype(np.float32) + + # Test that the test library complains if the given samples fall + # outside the purported bounds. + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(st.true_mean_confidence_interval_by_dkwm( + samples, 0., 0.5, error_rate=0.5)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(st.true_mean_confidence_interval_by_dkwm( + samples, 0.5, 1., error_rate=0.5)) + + # But doesn't complain if they don't. + op = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=0.5) + _ = sess.run(op) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index cbaf74d3f66253ae5727e1ba579e2d49235b748e..f0ba1ec3eb57c67c1a0edb15639e91916a4509b7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -186,12 +186,14 @@ class TransformedDistributionTest(test.TestCase): standard_normal = ds.Normal(loc=0., scale=1.) multi_logit_normal = self._cls()( distribution=standard_normal, - bijector=softmax) - x = [[-np.log(3.), 0.], - [np.log(3), np.log(5)]] + bijector=softmax, + event_shape=[1]) + x = [[[-np.log(3.)], [0.]], + [[np.log(3)], [np.log(5)]]] y = softmax.forward(x).eval() - expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) - - np.sum(np.log(y), axis=-1)) + expected_log_pdf = ( + np.squeeze(stats.norm(loc=0., scale=1.).logpdf(x)) - + np.sum(np.log(y), axis=-1)) self.assertAllClose(expected_log_pdf, multi_logit_normal.log_prob(y).eval()) self.assertAllClose( @@ -245,9 +247,8 @@ class TransformedDistributionTest(test.TestCase): with self.test_session() as sess: exp2 = self._cls()( ds.Exponential(rate=0.25), - bijector=ds.bijectors.Affine( - scale_identity_multiplier=2., - event_ndims=0)) + bijector=ds.bijectors.AffineScalar(scale=2.) + ) log_prob = exp2.log_prob(1.) log_prob_ = sess.run(log_prob) base_log_prob = -0.5 * 0.25 + np.log(0.25) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index 9044aa2850ae35f29cd48b0c5f54aa948bea0408..dcecce981f16a2d9e772d4e40062ff250725c3ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -390,6 +390,26 @@ class WishartCholeskyTest(test.TestCase): chol_scale, dtype=np.int32), validate_args=False) + def testSampleBroadcasts(self): + dims = 2 + batch_shape = [2, 3] + sample_shape = [2, 1] + scale = np.float32([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale, scale], axis=0), + batch_shape + [dims, dims]) + wishart = distributions.WishartFull(df=5, scale=scale) + x = wishart.sample(sample_shape, seed=42) + with self.test_session() as sess: + x_ = sess.run(x) + expected_shape = sample_shape + batch_shape + [dims, dims] + self.assertAllEqual(expected_shape, x.shape) + self.assertAllEqual(expected_shape, x_.shape) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 852298bf334666db003353d5fc8e172ffb738668..69f3d57ff000d6c9acc8aa9e3d0ad8d9cbb6bb3c 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -36,7 +36,8 @@ class Autoregressive(distribution_lib.Distribution): "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." [(Papamakarios et + al., 2016)][1] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -45,17 +46,18 @@ class Autoregressive(distribution_lib.Distribution): Practically speaking the autoregressive property means that there exists a permutation of the event coordinates such that each coordinate is a - diffeomorphic function of only preceding coordinates. [2] + diffeomorphic function of only preceding coordinates + [(van den Oord et al., 2016)][2]. #### Mathematical Details - The probability function is, + The probability function is ```none prob(x; fn, n) = fn(x).prob(x) ``` - And a sample is generated by, + And a sample is generated by ```none x = fn(...fn(fn(x0).sample()).sample()).sample() @@ -93,13 +95,15 @@ class Autoregressive(distribution_lib.Distribution): ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "Conditional Image Generation with PixelCNN Decoders." - Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex - Graves, Koray Kavukcuoglu. Arxiv, 2016. + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + + [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, + Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with + PixelCNN Decoders. In _Neural Information Processing Systems_, 2016. https://arxiv.org/abs/1606.05328 """ diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 93923c3f083c7f5136b55e9021cbd6323684b976..bc6b02542ebf3b83d58f888509dafb86351de8a7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -17,7 +17,9 @@ @@AbsoluteValue @@Affine @@AffineLinearOperator +@@AffineScalar @@Bijector +@@BatchNormalization @@Chain @@CholeskyOuterProduct @@ConditionalBijector @@ -26,16 +28,17 @@ @@Identity @@Inline @@Invert +@@Kumaraswamy @@MaskedAutoregressiveFlow @@Permute @@PowerTransform @@RealNVP @@Reshape @@Sigmoid -@@SigmoidCentered @@SinhArcsinh @@SoftmaxCentered @@Softplus +@@Square @@Weibull @@masked_autoregressive_default_template @@ -52,6 +55,8 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import * from tensorflow.contrib.distributions.python.ops.bijectors.affine import * from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import * +from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import * +from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import * from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * @@ -59,16 +64,17 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * +from tensorflow.contrib.distributions.python.ops.bijectors.square import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 05bb9c2f9bdf35e222c94db3491157893da64ebd..bef7bbb49b715497695f7513e19ecab4fa56c47e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -62,7 +62,7 @@ class Affine(bijector.Bijector): matrices, i.e., the matmul is [matrix-free]( https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - Examples: + #### Examples ```python # Y = X @@ -104,7 +104,6 @@ class Affine(bijector.Bijector): scale_tril=None, scale_perturb_factor=None, scale_perturb_diag=None, - event_ndims=1, validate_args=False, name="affine"): """Instantiates the `Affine` bijector. @@ -157,8 +156,6 @@ class Affine(bijector.Bijector): matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which represents an `r x r` diagonal matrix. When `None` low rank updates will take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -187,23 +184,6 @@ class Affine(bijector.Bijector): with self._name_scope("init", values=[ shift, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -251,12 +231,11 @@ class Affine(bijector.Bijector): self._scale = scale self._shaper = _DistributionShape( batch_ndims=batch_ndims, - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args) super(Affine, self).__init__( - event_ndims=event_ndims, + event_ndims=1, graph_parents=( - [event_ndims] + [self._scale] if tensor_util.is_tensor(self._scale) else self._scale.graph_parents + [self._shift] if self._shift is not None else []), @@ -388,9 +367,7 @@ class Affine(bijector.Bijector): if self._is_only_identity_multiplier: # We don't pad in this case and instead let the fldj be applied # via broadcast. - event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] + event_size = array_ops.shape(x)[-1] event_size = math_ops.cast(event_size, dtype=self._scale.dtype) return math_ops.log(math_ops.abs(self._scale)) * event_size return self.scale.log_abs_determinant() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..8adaa54c843d1b243a02967402a37b7c63fabbdf --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -0,0 +1,138 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Affine bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "AffineScalar", +] + + +class AffineScalar(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale * X + shift`. + + Examples: + + ```python + # Y = X + b = AffineScalar() + + # Y = X + shift + b = AffineScalar(shift=[1., 2, 3]) + + # Y = 2 * X + shift + b = AffineScalar( + shift=[1., 2, 3], + scale=2.) + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + validate_args=False, + name="affine_scalar"): + """Instantiates the `AffineScalar` bijector. + + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale * X + shift + ``` + + if `scale` is not specified, then the bijector has the semantics of + `scale = 1.`. Similarly, if `shift` is not specified, then the bijector + has the semantics of `shift = 0.`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale: Floating-point `Tensor`. If this is set to `None`, no scale is + applied. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + with self._name_scope("init", values=[scale, shift]): + self._shift = shift + self._scale = scale + + if self._shift is not None: + self._shift = ops.convert_to_tensor(shift, name="shift") + + if self._scale is not None: + self._scale = ops.convert_to_tensor(self._scale, name="scale") + if validate_args: + self._scale = control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + self._scale, + array_ops.zeros([], dtype=self._scale.dtype))], + self._scale) + + super(AffineScalar, self).__init__( + event_ndims=0, + is_constant_jacobian=True, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = array_ops.identity(x) + if self.scale is not None: + y *= self.scale + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = array_ops.identity(y) + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x /= self.scale + return x + + def _forward_log_det_jacobian(self, x): + log_det_jacobian = array_ops.zeros_like(x) + if self.scale is None: + return log_det_jacobian + log_det_jacobian += math_ops.log(math_ops.abs(self.scale)) + return log_det_jacobian diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..33fdd32d7a0a01685690e598c69adca2c95972e9 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -0,0 +1,259 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Batch Norm bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "BatchNormalization", +] + + +def _undo_batch_normalization(x, + mean, + variance, + offset, + scale, + variance_epsilon, + name=None): + r"""Inverse of tf.nn.batch_normalization. + + Args: + x: Input `Tensor` of arbitrary dimensionality. + mean: A mean `Tensor`. + variance: A variance `Tensor`. + offset: An offset `Tensor`, often denoted `beta` in equations, or + None. If present, will be added to the normalized tensor. + scale: A scale `Tensor`, often denoted `gamma` in equations, or + `None`. If present, the scale is applied to the normalized tensor. + variance_epsilon: A small `float` added to the minibatch `variance` to + prevent dividing by zero. + name: A name for this operation (optional). + + Returns: + batch_unnormalized: The de-normalized, de-scaled, de-offset `Tensor`. + """ + with ops.name_scope( + name, "undo_batchnorm", [x, mean, variance, scale, offset]): + # inv = math_ops.rsqrt(variance + variance_epsilon) + # if scale is not None: + # inv *= scale + # return x * inv + ( + # offset - mean * inv if offset is not None else -mean * inv) + rescale = math_ops.sqrt(variance + variance_epsilon) + if scale is not None: + rescale /= scale + batch_unnormalized = x * rescale + ( + mean - offset * rescale if offset is not None else mean) + return batch_unnormalized + + +class BatchNormalization(bijector.Bijector): + """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`. + + Applies Batch Normalization [(Ioffe and Szegedy, 2015)][1] to samples from a + data distribution. This can be used to stabilize training of normalizing + flows ([Papamakarios et al., 2016][3]; [Dinh et al., 2017][2]) + + When training Deep Neural Networks (DNNs), it is common practice to + normalize or whiten features by shifting them to have zero mean and + scaling them to have unit variance. + + The `inverse()` method of the `BatchNormalization` bijector, which is used in + the log-likelihood computation of data samples, implements the normalization + procedure (shift-and-scale) using the mean and standard deviation of the + current minibatch. + + Conversely, the `forward()` method of the bijector de-normalizes samples (e.g. + `X*std(Y) + mean(Y)` with the running-average mean and standard deviation + computed at training-time. De-normalization is useful for sampling. + + ```python + + dist = tfd.TransformedDistribution( + distribution=tfd.Normal()), + bijector=tfb.BatchNorm()) + + y = tfd.MultivariateNormalDiag(loc=1., scale=2.).sample(100) # ~ N(1, 2) + x = dist.bijector.inverse(y) # ~ N(0, 1) + y = dist.sample() # ~ N(1, 2) + ``` + + During training time, `BatchNorm.inverse` and `BatchNorm.forward` are not + guaranteed to be inverses of each other because `inverse(y)` uses statistics + of the current minibatch, while `forward(x)` uses running-average statistics + accumulated from training. In other words, + `BatchNorm.inverse(BatchNorm.forward(...))` and + `BatchNorm.forward(BatchNorm.inverse(...))` will be identical when + `training=False` but may be different when `training=True`. + + #### References + + [1]: Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating + Deep Network Training by Reducing Internal Covariate Shift. In + _International Conference on Machine Learning_, 2015. + https://arxiv.org/abs/1502.03167 + + [2]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + """ + + def __init__(self, + batchnorm_layer=None, + training=True, + validate_args=False, + name="batch_normalization"): + """Instantiates the `BatchNorm` bijector. + + Args: + batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`, + defaults to + `tf.layers.BatchNormalization(gamma_constraint=nn_ops.relu(x) + 1e-6)`. + This ensures positivity of the scale variable. + + training: If True, updates running-average statistics during call to + `inverse()`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + Raises: + ValueError: If bn_layer is not an instance of + `tf.layers.BatchNormalization`, or if it is specified with `renorm=True` + or a virtual batch size. + """ + # Scale must be positive. + g_constraint = lambda x: nn.relu(x) + 1e-6 + self.batchnorm = batchnorm_layer or normalization.BatchNormalization( + gamma_constraint=g_constraint) + self._validate_bn_layer(self.batchnorm) + self._training = training + super(BatchNormalization, self).__init__( + validate_args=validate_args, name=name) + + def _validate_bn_layer(self, layer): + """Check for valid BatchNormalization layer. + + Args: + layer: Instance of `tf.layers.BatchNormalization`. + Raises: + ValueError: If batchnorm_layer argument is not an instance of + `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True` or + if `batchnorm_layer.virtual_batch_size` is specified. + """ + if not isinstance(layer, normalization.BatchNormalization): + raise ValueError( + "batchnorm_layer must be an instance of BatchNormalization layer.") + if layer.renorm: + raise ValueError("BatchNorm Bijector does not support renormalization.") + if layer.virtual_batch_size: + raise ValueError( + "BatchNorm Bijector does not support virtual batch sizes.") + + def _get_broadcast_fn(self, x): + # Compute shape to broadcast scale/shift parameters to. + if not x.shape.is_fully_defined(): + raise ValueError("Input must have shape known at graph construction.") + input_shape = np.int32(x.shape.as_list()) + + ndims = len(input_shape) + # event_dims = self._compute_event_dims(x) + reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis] + # Broadcasting only necessary for single-axis batch norm where the axis is + # not the last dimension + broadcast_shape = [1] * ndims + broadcast_shape[self.batchnorm.axis[0]] = ( + input_shape[self.batchnorm.axis[0]]) + def _broadcast(v): + if (v is not None and + len(v.get_shape()) != ndims and + reduction_axes != list(range(ndims - 1))): + return array_ops.reshape(v, broadcast_shape) + return v + return _broadcast + + def _normalize(self, y): + return self.batchnorm.apply(y, training=self._training) + + def _de_normalize(self, x): + # Uses the saved statistics. + if not self.batchnorm.built: + input_shape = x.get_shape() + self.batchnorm.build(input_shape) + broadcast_fn = self._get_broadcast_fn(x) + mean = broadcast_fn(self.batchnorm.moving_mean) + variance = broadcast_fn(self.batchnorm.moving_variance) + beta = broadcast_fn(self.batchnorm.beta) if self.batchnorm.center else None + gamma = broadcast_fn(self.batchnorm.gamma) if self.batchnorm.scale else None + return _undo_batch_normalization( + x, mean, variance, beta, gamma, self.batchnorm.epsilon) + + def _forward(self, x): + return self._de_normalize(x) + + def _inverse(self, y): + return self._normalize(y) + + def _forward_log_det_jacobian(self, x): + # Uses saved statistics to compute volume distortion. + return -self._inverse_log_det_jacobian(x, use_saved_statistics=True) + + def _inverse_log_det_jacobian(self, y, use_saved_statistics=False): + if not y.shape.is_fully_defined(): + raise ValueError("Input must have shape known at graph construction.") + input_shape = np.int32(y.shape.as_list()) + + if not self.batchnorm.built: + # Create variables. + self.batchnorm.build(input_shape) + + event_dims = self.batchnorm.axis + reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims] + + if use_saved_statistics or not self._training: + log_variance = math_ops.log( + self.batchnorm.moving_variance + self.batchnorm.epsilon) + else: + # At training-time, ildj is computed from the mean and log-variance across + # the current minibatch. + _, v = nn.moments(y, axes=reduction_axes, keep_dims=True) + log_variance = math_ops.log(v + self.batchnorm.epsilon) + + # `gamma` and `log Var(y)` reductions over event_dims. + # Log(total change in area from gamma term). + log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma)) + + # Log(total change in area from log-variance term). + log_total_variance = math_ops.reduce_sum(log_variance) + # The ildj is scalar, as it does not depend on the values of x and are + # constant across minibatch elements. + return log_total_gamma - 0.5 * log_total_variance diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..8f09e16058b766c788ab3acced6940fd0026b521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -20,8 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -39,8 +37,6 @@ __all__ = [ class CholeskyOuterProduct(bijector.Bijector): """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - Note: the upper-triangular part of X is ignored (whether or not its zero). The surjectivity of g as a map from the set of n x n positive-diagonal @@ -61,49 +57,34 @@ class CholeskyOuterProduct(bijector.Bijector): that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - Examples: + #### Examples ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + bijector.CholeskyOuterProduct().forward(x=[[1., 0], [2, 1]]) # Result: [[1., 2], [2, 5]], i.e., x @ x.T - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + bijector.CholeskyOuterProduct().inverse(y=[[1., 2], [2, 5]]) # Result: [[1., 0], [2, 1]], i.e., cholesky(y). ``` """ - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): + def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. """ self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, + event_ndims=2, validate_args=validate_args, name=name) def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) if self.validate_args: is_matrix = check_ops.assert_rank_at_least(x, 2) shape = array_ops.shape(x) @@ -114,11 +95,7 @@ class CholeskyOuterProduct(bijector.Bijector): return math_ops.matmul(x, x, adjoint_b=True) def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) + return linalg_ops.cholesky(y) def _forward_log_det_jacobian(self, x): # Let Y be a symmetric, positive definite matrix and write: @@ -161,13 +138,6 @@ class CholeskyOuterProduct(bijector.Bijector): # Since there is a 2 X[j,j] term for every lower-triangular element of X we # conclude: # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - diag = array_ops.matrix_diag_part(x) # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..f5de052c9ed18b1ebf4c174aeea3a951b1ddcd9d --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -0,0 +1,153 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kumaraswamy bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Kumaraswamy", +] + + +class Kumaraswamy(bijector.Bijector): + """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`. + + This bijector maps inputs from `[0, 1]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the [Kumaraswamy distribution]( + https://en.wikipedia.org/wiki/Kumaraswamy_distribution): + + ```none + Y ~ Kumaraswamy(a, b) + pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1) + ``` + """ + + def __init__(self, + concentration1=None, + concentration0=None, + event_ndims=0, + validate_args=False, + name="kumaraswamy"): + """Instantiates the `Kumaraswamy` bijector. + + Args: + concentration1: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is + `concentration1`. + concentration0: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is + `concentration0`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init", values=[concentration1, concentration0]): + concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args=validate_args) + concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args=validate_args) + + self._concentration1 = concentration1 + self._concentration0 = concentration0 + super(Kumaraswamy, self).__init__( + event_ndims=0, + validate_args=validate_args, + name=name) + + @property + def concentration1(self): + """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration1 + + @property + def concentration0(self): + """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration0 + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.exp( + math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0)) + / self.concentration1) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.exp(math_ops.log1p( + -(1 - y**self.concentration1)**self.concentration0)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.concentration1) + math_ops.log(self.concentration0) + + (self.concentration1 - 1) * math_ops.log(y) + + (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1), + axis=event_dims) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid(self, x): + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + x, + message="sample must be non-negative"), + check_ops.assert_less_equal( + x, array_ops.ones([], self.concentration0.dtype), + message="sample must be no larger than `1`."), + ], x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 5251dbcb5748f75688aa43ce6e4e9dbd76be78bb..84b2340c75514c3d2c12bf4d775ba74450a0dc26 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -45,14 +45,15 @@ __all__ = [ class MaskedAutoregressiveFlow(bijector_lib.Bijector): """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, + The affine autoregressive flow [(Papamakarios et al., 2016)][3] provides a + relatively simple framework for user-specified (deep) architectures to learn + a distribution over vector-valued events. Regarding terminology, "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." + [(Papamakarios et al., 2016)][3] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -75,26 +76,26 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): Given a `shift_and_log_scale_fn`, the forward and inverse transformations are (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. + must compute each `shift` (aka `loc` or "mu" in [Germain et al. (2015)][1]) + and `log(scale)` (aka "alpha" in [Germain et al. (2015)][1]) such that each + are broadcastable with the arguments to `forward` and `inverse`, i.e., such + that the calculations in `forward`, `inverse` [below] are possible. For convenience, `masked_autoregressive_default_template` is offered as a possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. + architecture [(Germain et al., 2015)][1]. MADE is a feed-forward network that + computes a `shift` and `log(scale)` using `masked_dense` layers in a deep + neural network. Weights are masked to ensure the autoregressive property. It + is possible that this architecture is suboptimal for your task. To build + alternative networks, either change the arguments to + `masked_autoregressive_default_template`, use the `masked_dense` function to + roll-out your own, or use some other architecture, e.g., using `tf.layers`. Warning: no attempt is made to validate that the `shift_and_log_scale_fn` enforces the "autoregressive property". Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, + semantics, the forward transformation is ```python def forward(x): @@ -106,7 +107,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): return y ``` - and the inverse transformation is, + and the inverse transformation is ```python def inverse(y): @@ -121,7 +122,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this also proves the transform is bijective.) - #### Example Use + #### Examples ```python tfd = tf.contrib.distributions @@ -142,7 +143,8 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): maf.log_prob(x) # Almost free; uses Bijector caching. maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - # [1] also describes an "Inverse Autoregressive Flow", e.g., + # [Papamakarios et al. (2016)][3] also describe an Inverse Autoregressive + # Flow [(Kingma et al., 2016)][2]: iaf = tfd.TransformedDistribution( distribution=tfd.Normal(loc=0., scale=1.), bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( @@ -168,14 +170,20 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): event_shape=[dims]) ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 + [2]: Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya + Sutskever, and Max Welling. Improving Variational Inference with Inverse + Autoregressive Flow. In _Neural Information Processing Systems_, 2016. + https://arxiv.org/abs/1606.04934 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, @@ -329,11 +337,7 @@ def masked_dense(inputs, **kwargs): """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + See [Germain et al. (2015)][1] for detailed explanation. Arguments: inputs: Tensor input. @@ -358,6 +362,12 @@ def masked_dense(inputs, Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ # TODO(b/67594795): Better support of dynamic shape. input_depth = inputs.shape.with_rank_at_least(1)[-1].value @@ -398,23 +408,24 @@ def masked_autoregressive_default_template( name=None, *args, **kwargs): - """Build the MADE Model [1]. + """Build the Masked Autoregressive Density Estimator (Germain et al., 2015). This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. + created once. It takes the input and returns the `loc` ("mu" in [Germain et + al. (2015)][1]) and `log_scale` ("alpha" in [Germain et al. (2015)][1]) from + the MADE network. Warning: This function uses `masked_dense` to create randomly initialized `tf.Variables`. It is presumed that these will be fit, just as you would any other neural architecture which uses `tf.layers.dense`. - #### About Hidden Layers: + #### About Hidden Layers Each element of `hidden_layers` should be greater than the `input_depth` (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the neural network). This is necessary to ensure the autoregressivity property. - #### About Clipping: + #### About Clipping This function also optionally clips the `log_scale` (but possibly not its gradient). This is useful because if `log_scale` is too small/large it might @@ -427,11 +438,7 @@ def masked_autoregressive_default_template( `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual `grad[clip(x)] exp(clip(x))`. - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: + Args: hidden_layers: Python `list`-like of non-negative integer, scalars indicating the number of units in each hidden layer. Default: `[512, 512]. shift_only: Python `bool` indicating if only the `shift` term shall be @@ -450,12 +457,20 @@ def masked_autoregressive_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms (the "mu" in + [Germain et al. (2015)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in + [Germain et al. (2015)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ with ops.name_scope(name, "masked_autoregressive_default_template", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 2840f52e742eac5e9e37a576bf7f6d6f05a07a35..71ab369d01aafc33854a2c2437f96bbb493cc6fb 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -38,7 +38,7 @@ class RealNVP(bijector_lib.Bijector): """RealNVP "affine coupling layer" for vector-valued events. Real NVP models a normalizing flow on a `D`-dimensional distribution via a - single `D-d`-dimensional conditional distribution [1]: + single `D-d`-dimensional conditional distribution [(Dinh et al., 2017)][1]: `y[d:D] = y[d:D] * math_ops.exp(log_scale_fn(y[d:D])) + shift_fn(y[d:D])` `y[0:d] = x[0:d]` @@ -51,31 +51,34 @@ class RealNVP(bijector_lib.Bijector): Masking is currently only supported for base distributions with `event_ndims=1`. For more sophisticated masking schemes like checkerboard or - channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired - masked units into the first `d` units. For base distributions with - `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape. - - Recall that the MAF bijector [2] implements a normalizing flow via an - autoregressive transformation. MAF and IAF have opposite computational - tradeoffs - MAF can train all units in parallel but must sample units - sequentially, while IAF must train units sequentially but can sample in - parallel. In contrast, Real NVP can compute both forward and inverse - computations in parallel. However, the lack of an autoregressive + channel-wise masking [(Papamakarios et al., 2016)[4], use the `tfb.Permute` + bijector to re-order desired masked units into the first `d` units. For base + distributions with `event_ndims > 1`, use the `tfb.Reshape` bijector to + flatten the event shape. + + Recall that the MAF bijector [(Papamakarios et al., 2016)][4] implements a + normalizing flow via an autoregressive transformation. MAF and IAF have + opposite computational tradeoffs - MAF can train all units in parallel but + must sample units sequentially, while IAF must train units sequentially but + can sample in parallel. In contrast, Real NVP can compute both forward and + inverse computations in parallel. However, the lack of an autoregressive transformations makes it less expressive on a per-bijector basis. A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or - "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable - with the arguments to `forward` and `inverse`, i.e., such that the - calculations in `forward`, `inverse` [below] are possible. For convenience, + "mu" in [Papamakarios et al. (2016)][4]) and `log(scale)` (aka "alpha" in + [Papamakarios et al. (2016)][4]) such that each are broadcastable with the + arguments to `forward` and `inverse`, i.e., such that the calculations in + `forward`, `inverse` [below] are possible. For convenience, `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn` function. - NICE [3] is a special case of the Real NVP bijector which discards the scale - transformation, resulting in a constant-time inverse-log-determinant-Jacobian. - To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should - return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in - the `RealNVP` constructor. Calling `real_nvp_default_template` with - `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`. + NICE [(Dinh et al., 2014)][2] is a special case of the Real NVP bijector + which discards the scale transformation, resulting in a constant-time + inverse-log-determinant-Jacobian. To use a NICE bijector instead of Real + NVP, `shift_and_log_scale_fn` should return `(shift, None)`, and + `is_constant_jacobian` should be set to `True` in the `RealNVP` constructor. + Calling `real_nvp_default_template` with `shift_only=True` returns one such + NICE-compatible `shift_and_log_scale_fn`. Caching: the scalar input depth `D` of the base distribution is not known at construction time. The first call to any of `forward(x)`, `inverse(x)`, @@ -103,23 +106,24 @@ class RealNVP(bijector_lib.Bijector): nvp.log_prob(0.) ``` - For more examples, see [4]. + For more examples, see [Jang (2018)][3]. - [1]: "Density Estimation using Real NVP." - Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. - https://arxiv.org/abs/1605.08803 + #### References - [2]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + [1]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 - [3]: "NICE: Non-linear Independent Components Estimation." - Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015. - https://arxiv.org/abs/1410.8516 + [2]: Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear + Independent Components Estimation. _arXiv preprint arXiv:1410.8516_, + 2014. https://arxiv.org/abs/1410.8516 - [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows." - Eric Jang. Blog post. January 2018. - http://blog.evjang.com/2018/01/nf2.html + [3]: Eric Jang. Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows. + _Technical Report_, 2018. http://blog.evjang.com/2018/01/nf2.html + + [4]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, @@ -250,12 +254,20 @@ def real_nvp_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms ("mu" in + [Papamakarios et al. (2016)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms ("alpha" in + [Papamakarios et al. (2016)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ with ops.name_scope(name, "real_nvp_default_template"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py deleted file mode 100644 index 223bc9d042c69be05b0e578835a31ed6e83c0c97..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SigmoidCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered - - -__all__ = [ - "SigmoidCentered", -] - - -class SigmoidCentered(softmax_centered.SoftmaxCentered): - """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). - - Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. - - See `bijector.SoftmaxCentered` for more details. - """ - - def __init__(self, validate_args=False, name="sigmoid_centered"): - super(SigmoidCentered, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index a9dcce6c526600f3b26c6bceb730417000917ce7..dc94fd0a38de29f5a7ee6ca826aab0ecf8712966 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,13 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -47,17 +42,14 @@ class SoftmaxCentered(bijector.Bijector): e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last coordinate. - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - Example Use: ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + bijector.SoftmaxCentered().forward(tf.log([2, 3, 4])) # Result: [0.2, 0.3, 0.4, 0.1] # Extra result: 0.1 - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + bijector.SoftmaxCentered().inverse([0.2, 0.3, 0.4, 0.1]) # Result: tf.log([2, 3, 4]) # Extra coordinate removed. ``` @@ -69,82 +61,47 @@ class SoftmaxCentered(bijector.Bijector): """ def __init__(self, - event_ndims=0, validate_args=False, name="softmax_centered"): self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args, name=name) def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: + if input_shape.ndims is None or input_shape[-1] is None: return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + return tensor_shape.TensorShape([input_shape[-1] + 1]) def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 + return (input_shape[-1] + 1)[..., array_ops.newaxis] def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: + if output_shape.ndims is None or output_shape[-1] is None: return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) + if output_shape[-1] <= 1: + raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1]) + return tensor_shape.TensorShape([output_shape[-1] - 1]) def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] if self.validate_args: # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) + is_greater_one = check_ops.assert_greater( + output_shape[-1], 1, message="Need last dimension greater than 1.") + output_shape = control_flow_ops.with_dependencies( + [is_greater_one], output_shape) + return (output_shape[-1] - 1)[..., array_ops.newaxis] def _forward(self, x): # Pad the last dim with a zeros vector. We need this because it lets us # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - y = distribution_util.pad(y, axis=-1, back=True) + y = distribution_util.pad(x, axis=-1, back=True) # Set shape hints. if x.shape.ndims is not None: - shape = x.shape.as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) + shape = x.shape[:-1].concatenate(x.shape[-1] + 1) y.shape.assert_is_compatible_with(shape) y.set_shape(shape) @@ -161,42 +118,17 @@ class SoftmaxCentered(bijector.Bijector): # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) # = log(exp(x[i])/normalization) - log(y[end]) # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.shape.as_list(), dtype=np.int32) - if y.shape.is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = distribution_util.prefer_static_rank(y) # Do this first to make sure CSE catches that it'll happen again in # _inverse_log_det_jacobian. x = math_ops.log(y) - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + log_normalization = (-x[..., -1])[..., array_ops.newaxis] + x = x[..., :-1] + log_normalization # Set shape hints. if y.shape.ndims is not None: - shape = y.shape.as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) + shape = y.shape[:-1].concatenate(y.shape[-1] - 1) x.shape.assert_is_compatible_with(shape) x.set_shape(shape) @@ -222,19 +154,16 @@ class SoftmaxCentered(bijector.Bijector): return -math_ops.reduce_sum(math_ops.log(y), axis=-1) def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9dbf35091fe51f2478dc085c394a77295ca4ee --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -0,0 +1,84 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Square bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "Square", +] + + +class Square(bijector.Bijector): + """Compute `g(X) = X^2`; X is a positive real number. + + g is a bijection between the non-negative real numbers (R_+) and the + non-negative real numbers. + + #### Examples + + ```python + bijector.Square().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [4, 1]], i.e., x^2 + + bijector.Square().inverse(y=[[1., 4], [9, 1]]) + # Result: [[1., 2], [3, 1]], i.e., sqrt(y). + ``` + + """ + + def __init__(self, validate_args=False, name="square"): + """Instantiates the `Square` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._name = name + super(Square, self).__init__( + event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.square(x) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.sqrt(y) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid(x) + return np.log(2.) + math_ops.log(x) + + def _maybe_assert_valid(self, t): + if not self.validate_args: + return t + is_valid = check_ops.assert_non_negative( + t, message="All elements must be non-negative.") + return control_flow_ops.with_dependencies([is_valid], t) + diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index bdd5571c966a74e58e4f9f8eed2628f131a1b92e..e610f469e5d5f446b75c734cc39811de30a8cb9a 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -21,6 +21,8 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma @@ -87,7 +89,11 @@ class Chi2(gamma.Gamma): # allow_nan_stats=True # through to the parent class results in unnecessary asserts. with ops.name_scope(name, values=[df]): - self._df = ops.convert_to_tensor(df, name="df") + with ops.control_dependencies([ + check_ops.assert_positive(df), + ] if validate_args else []): + self._df = array_ops.identity(df, name="df") + super(Chi2, self).__init__( concentration=0.5 * self._df, rate=constant_op.constant(0.5, dtype=self._df.dtype), diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5..8d05ad6b8032fb8bada99389959091fb1c28beda 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -190,9 +190,6 @@ class _Gumbel(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return -math_ops.exp(-self._z(x)) diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index cbce005013281ff3c58c94d525d5ce7a865d725a..7dcb3e3ac4db1855adacb7ec0fa8554c45d9c859 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import kullback_leibler class Independent(distribution_lib.Distribution): @@ -254,3 +255,58 @@ class Independent(distribution_lib.Distribution): else: which_maximum = np.maximum return which_maximum(0, ndims - 1) + + +@kullback_leibler.RegisterKL(Independent, Independent) +def _kl_independent(a, b, name="kl_independent"): + """Batched KL divergence `KL(a || b)` for Independent distributions. + + We can leverage the fact that + ``` + KL(Independent(a) || Independent(b)) = sum(KL(a || b)) + ``` + where the sum is over the `reinterpreted_batch_ndims`. + + Args: + a: Instance of `Independent`. + b: Instance of `Independent`. + name: (optional) name to use for created ops. Default "kl_independent". + + Returns: + Batchwise `KL(a || b)`. + + Raises: + ValueError: If the event space for `a` and `b`, or their underlying + distributions don't match. + """ + p = a.distribution + q = b.distribution + + # The KL between any two (non)-batched distributions is a scalar. + # Given that the KL between two factored distributions is the sum, i.e. + # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(q1 || q2), we compute + # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions. + if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined(): + if a.event_shape == b.event_shape: + if p.event_shape == q.event_shape: + num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims + reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)] + + return math_ops.reduce_sum( + kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) + else: + raise NotImplementedError("KL between Independents with different " + "event shapes not supported.") + else: + raise ValueError("Event shapes do not match.") + else: + with ops.control_dependencies([ + check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()), + check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor()) + ]): + num_reduce_dims = ( + array_ops.shape(a.event_shape_tensor()[0]) - + array_ops.shape(p.event_shape_tensor()[0])) + reduce_dims = math_ops.range(-num_reduce_dims - 1, -1, 1) + return math_ops.reduce_sum( + kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index ee4d86867d48b20e97757bcec57d452085814b80..51ac61dcf640ca89f22c47127bda71316a179ca4 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -192,12 +192,6 @@ class InverseGamma(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - - def _log_cdf(self, x): - return math_ops.log(self._cdf(x)) - def _cdf(self, x): x = self._maybe_assert_valid_sample(x) # Note that igammac returns the upper regularized incomplete gamma diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 74d5d8773cf3e69a52554c87d656fea2835c8354..192dede6ff1d4de8d4be9965c414e7453d7b5d4b 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -20,15 +20,17 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.distributions import beta from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.distributions import uniform from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -42,25 +44,23 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. - Derivation from [1] and Euler's constant [2]. - [1] - - https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers - [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant - + Derivation from [here]( + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers) + and [Euler's constant]( + https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant). Args: x: input float. Returns: z: The analytic continuation of the harmonic number for the input. - """ one = array_ops.ones([], dtype=x.dtype) return math_ops.digamma(x + one) - math_ops.digamma(one) @tf_export("distributions.Kumaraswamy") -class Kumaraswamy(beta.Beta): +class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. The Kumaraswamy distribution is defined over the `(0, 1)` interval using @@ -151,59 +151,32 @@ class Kumaraswamy(beta.Beta): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ + concentration1 = ops.convert_to_tensor( + concentration1, name="concentration1") + concentration0 = ops.convert_to_tensor( + concentration0, name="concentration0") super(Kumaraswamy, self).__init__( - concentration1=concentration1, - concentration0=concentration0, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, + distribution=uniform.Uniform( + low=array_ops.zeros([], dtype=concentration1.dtype), + high=array_ops.ones([], dtype=concentration1.dtype), + allow_nan_stats=allow_nan_stats), + bijector=bijectors.Kumaraswamy( + concentration1=concentration1, concentration0=concentration0, + validate_args=validate_args), + batch_shape=distribution_util.get_broadcast_shape( + concentration1, concentration0), name=name) self._reparameterization_type = distribution.FULLY_REPARAMETERIZED - def _sample_n(self, n, seed=None): - expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 - expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 - shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) - uniform_sample = random_ops.random_uniform( - shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) - - kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( - 1. / expanded_concentration1) - return kumaraswamy_sample - - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _log_cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return math_ops.log1p(-(1 - x**a)**b) + @property + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self.bijector.concentration1 - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return 1 - (1 - x**a)**b - - def _survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return (1 - x**a)**b - - def _log_survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return b * math_ops.log1p(-x**a) - - def _log_unnormalized_prob(self, x): - x = self._maybe_assert_valid_sample(x) - a = self.concentration1 - b = self.concentration0 - return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) - - def _log_normalization(self): - a = self.concentration1 - b = self.concentration0 - return -(math_ops.log(a) + math_ops.log(b)) + @property + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self.bijector.concentration0 def _entropy(self): a = self.concentration1 @@ -213,10 +186,11 @@ class Kumaraswamy(beta.Beta): def _moment(self, n): """Compute the n'th (uncentered) moment.""" + total_concentration = self.concentration1 + self.concentration0 expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 + total_concentration, dtype=self.dtype) * self.concentration1 expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 + total_concentration, dtype=self.dtype) * self.concentration0 beta_arg0 = 1 + n / expanded_concentration1 beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( @@ -246,13 +220,14 @@ class Kumaraswamy(beta.Beta): name="nan") is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration1.dtype), self.concentration1, message="Mode undefined for concentration1 <= 1."), check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration0.dtype), self.concentration0, message="Mode undefined for concentration0 <= 1.") ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 473677f8d91b184e029f345bb05f5c5d63df7a40..68e6bca5a554b29a450911073eb5c4fe55f313c6 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -185,9 +185,6 @@ class Logistic(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return -nn_ops.softplus(-self._z(x)) diff --git a/tensorflow/contrib/distributions/python/ops/moving_stats.py b/tensorflow/contrib/distributions/python/ops/moving_stats.py index 20f85643b9e7db61b4786dffe4115c7d3c00b046..87d40805a3c7a9c2871305af7f7182b7e2923530 100644 --- a/tensorflow/contrib/distributions/python/ops/moving_stats.py +++ b/tensorflow/contrib/distributions/python/ops/moving_stats.py @@ -47,9 +47,7 @@ def assign_moving_mean_variance( Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-1 mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Args: mean_var: `float`-like `Variable` representing the exponentially weighted @@ -72,6 +70,12 @@ def assign_moving_mean_variance( TypeError: if `mean_var` does not have float type `dtype`. TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ with ops.name_scope(name, "assign_moving_mean_variance", [variance_var, mean_var, value, decay]): @@ -183,9 +187,7 @@ def moving_mean_variance(value, decay, collections=None, name=None): Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-`1` mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Unlike `assign_moving_mean_variance`, this function handles variable creation. @@ -208,6 +210,12 @@ def moving_mean_variance(value, decay, collections=None, name=None): Raises: TypeError: if `value_var` does not have float type `dtype`. TypeError: if `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index b76cebf79fad09ebec68f2459c6fe80794ea81c0..46c2cc8b7a8c536a90176fbb2b2d52fed61e4705 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -203,9 +203,6 @@ class OneHotCategorical(distribution.Distribution): ret = array_ops.reshape(ret, logits_shape) return ret - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _entropy(self): return -math_ops.reduce_sum( nn_ops.log_softmax(self.logits) * self.probs, axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 2aa771a71efe52c8d86d459f090ea8ee137c4487..ff33f327c7a77597e516208cacad8c4aed65d1c9 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -285,9 +285,6 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): ret = array_ops.reshape(log_prob, logits_shape) return ret - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _assert_valid_sample(self, x): if not self.validate_args: return x diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 5fb6f0c7eaa8c4734ea4c161b0eee6f24d4c9850..bac0b79d5908712f4e64259768fb6f3b4558f620 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -32,45 +32,50 @@ from tensorflow.python.ops.distributions import util as distribution_util class _DistributionShape(object): """Manage and manipulate `Distribution` shape. - Terminology: - Recall that a `Tensor` has: - - `shape`: size of `Tensor` dimensions, - - `ndims`: size of `shape`; number of `Tensor` dimensions, - - `dims`: indexes into `shape`; useful for transpose, reduce. - - `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, - `batch_dims`, and `event_dims`. To understand the semantics of these - dimensions, consider when two of the three are fixed and the remaining - is varied: - - `sample_dims`: indexes independent draws from identical - parameterizations of the `Distribution`. - - `batch_dims`: indexes independent draws from non-identical - parameterizations of the `Distribution`. - - `event_dims`: indexes event coordinates from one sample. - - The `sample`, `batch`, and `event` dimensions constitute the entirety of a - `Distribution` `Tensor`'s shape. - - The dimensions are always in `sample`, `batch`, `event` order. - - Purpose: - This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into - `Distribution` notions of `sample,` `batch,` and `event` dimensions. That - is, it computes any of: + #### Terminology - ``` - sample_shape batch_shape event_shape - sample_dims batch_dims event_dims - sample_ndims batch_ndims event_ndims - ``` + Recall that a `Tensor` has: + - `shape`: size of `Tensor` dimensions, + - `ndims`: size of `shape`; number of `Tensor` dimensions, + - `dims`: indexes into `shape`; useful for transpose, reduce. + + `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, + `batch_dims`, and `event_dims`. To understand the semantics of these + dimensions, consider when two of the three are fixed and the remaining + is varied: + - `sample_dims`: indexes independent draws from identical + parameterizations of the `Distribution`. + - `batch_dims`: indexes independent draws from non-identical + parameterizations of the `Distribution`. + - `event_dims`: indexes event coordinates from one sample. + + The `sample`, `batch`, and `event` dimensions constitute the entirety of a + `Distribution` `Tensor`'s shape. + + The dimensions are always in `sample`, `batch`, `event` order. + + #### Purpose + + This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into + `Distribution` notions of `sample,` `batch,` and `event` dimensions. That + is, it computes any of: + + ``` + sample_shape batch_shape event_shape + sample_dims batch_dims event_dims + sample_ndims batch_ndims event_ndims + ``` - for a given `Tensor`, e.g., the result of - `Distribution.sample(sample_shape=...)`. + for a given `Tensor`, e.g., the result of + `Distribution.sample(sample_shape=...)`. - For a given `Tensor`, this class computes the above table using minimal - information: `batch_ndims` and `event_ndims`. + For a given `Tensor`, this class computes the above table using minimal + information: `batch_ndims` and `event_ndims`. + + #### Examples + + We show examples of distribution shape semantics. - Examples of `Distribution` `shape` semantics: - Sample dimensions: Computing summary statistics, i.e., the average is a reduction over sample dimensions. @@ -111,52 +116,54 @@ class _DistributionShape(object): tf.div(1., tf.reduce_prod(x, event_dims)) ``` - Examples using this class: - Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. - - ```python - # 150 iid samples from one multivariate Normal with two degrees of freedom. - mu = [0., 0] - sigma = [[1., 0], - [0, 1]] - mvn = MultivariateNormal(mu, sigma) - rand_mvn = mvn.sample(sample_shape=[3, 50]) - shaper = DistributionShape(batch_ndims=0, event_ndims=1) - S, B, E = shaper.get_shape(rand_mvn) - # S = [3, 50] - # B = [] - # E = [2] - - # 12 iid samples from one Wishart with 2x2 events. - sigma = [[1., 0], - [2, 1]] - wishart = Wishart(df=5, scale=sigma) - rand_wishart = wishart.sample(sample_shape=[3, 4]) - shaper = DistributionShape(batch_ndims=0, event_ndims=2) - S, B, E = shaper.get_shape(rand_wishart) - # S = [3, 4] - # B = [] - # E = [2, 2] - - # 100 iid samples from two, non-identical trivariate Normal distributions. - mu = ... # shape(2, 3) - sigma = ... # shape(2, 3, 3) - X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) - # S = [4, 25] - # B = [2] - # E = [3] - ``` - - Argument Validation: - When `validate_args=False`, checks that cannot be done during - graph construction are performed at graph execution. This may result in a - performance degradation because data must be switched from GPU to CPU. - - For example, when `validate_args=False` and `event_ndims` is a - non-constant `Tensor`, it is checked to be a non-negative integer at graph - execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` - arguments are always checked for correctness since this can be done for - "free," i.e., during graph construction. + We show examples using this class. + + Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. + + ```python + # 150 iid samples from one multivariate Normal with two degrees of freedom. + mu = [0., 0] + sigma = [[1., 0], + [0, 1]] + mvn = MultivariateNormal(mu, sigma) + rand_mvn = mvn.sample(sample_shape=[3, 50]) + shaper = DistributionShape(batch_ndims=0, event_ndims=1) + S, B, E = shaper.get_shape(rand_mvn) + # S = [3, 50] + # B = [] + # E = [2] + + # 12 iid samples from one Wishart with 2x2 events. + sigma = [[1., 0], + [2, 1]] + wishart = Wishart(df=5, scale=sigma) + rand_wishart = wishart.sample(sample_shape=[3, 4]) + shaper = DistributionShape(batch_ndims=0, event_ndims=2) + S, B, E = shaper.get_shape(rand_wishart) + # S = [3, 4] + # B = [] + # E = [2, 2] + + # 100 iid samples from two, non-identical trivariate Normal distributions. + mu = ... # shape(2, 3) + sigma = ... # shape(2, 3, 3) + X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) + # S = [4, 25] + # B = [2] + # E = [3] + ``` + + #### Argument Validation + + When `validate_args=False`, checks that cannot be done during + graph construction are performed at graph execution. This may result in a + performance degradation because data must be switched from GPU to CPU. + + For example, when `validate_args=False` and `event_ndims` is a + non-constant `Tensor`, it is checked to be a non-negative integer at graph + execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` + arguments are always checked for correctness since this can be done for + "free," i.e., during graph construction. """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index c4b8f055b7fbc3f0835b503eddd7617610326d8c..0d8a1926913766da374cb65767dccfa28bf75579 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -174,13 +174,12 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): skewness=skewness.dtype.as_numpy_dtype(0.), tailweight=tailweight, event_ndims=0) - # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2)) + # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) - affine = bijectors.Affine( + affine = bijectors.AffineScalar( shift=loc, - scale_identity_multiplier=c, - validate_args=validate_args, - event_ndims=0) + scale=c, + validate_args=validate_args) bijector = bijectors.Chain([affine, f]) diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..d66c34cc1a45cc09da5138a5f72ae3817690db49 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -0,0 +1,728 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Statistical test assertions calibrated for their error rates.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops + +__all__ = [ + "true_mean_confidence_interval_by_dkwm", + "assert_true_mean_equal_by_dkwm", + "min_discrepancy_of_true_means_detectable_by_dkwm", + "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_equal_by_dkwm_two_sample", + "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", + "min_num_samples_for_dkwm_mean_two_sample_test", +] + + +def _batch_sort_vector(x, ascending=True, name=None): + with ops.name_scope(name, "sort_each_row", [x]): + x = ops.convert_to_tensor(x, name="x") + n = array_ops.shape(x)[-1] + if ascending: + y, _ = nn_ops.top_k(-x, k=n, sorted=True) + y = -y + else: + y, _ = nn_ops.top_k(x, k=n, sorted=True) + y.set_shape(x.shape) + return y + + +def _do_maximum_mean(samples, envelope, high, name=None): + """Common code between maximum_mean and minimum_mean.""" + with ops.name_scope(name, "do_maximum_mean", [samples, envelope, high]): + n = array_ops.rank(samples) + # Move the batch dimension of `samples` to the rightmost position, + # where the _batch_sort_vector function wants it. + perm = array_ops.concat([math_ops.range(1, n), [0]], axis=0) + samples = array_ops.transpose(samples, perm) + + samples = _batch_sort_vector(samples) + batch_shape = array_ops.shape(samples)[:-1] + n = array_ops.shape(samples)[-1] + step = 1. / math_ops.cast(n, dtype=samples.dtype.base_dtype) + + def _loop_body(iter_, total, to_skip): + total = array_ops.where( + step <= to_skip, + total, + array_ops.where( + to_skip > 0., + total + (step - to_skip) * samples[..., iter_], + total + step * samples[..., iter_])) + to_skip = array_ops.where(step <= to_skip, to_skip - step, 0.) + return [iter_ + 1, total, to_skip] + + _, total, _ = control_flow_ops.while_loop( + cond=lambda iter_, *args: iter_ < n, + body=_loop_body, + loop_vars=[ + 0, + array_ops.zeros(batch_shape, dtype=samples.dtype.base_dtype), + envelope, # to_skip + ]) + + return total + envelope * high + + +def _maximum_mean(samples, envelope, high, name=None): + """Returns a stochastic upper bound on the mean of a scalar distribution. + + The idea is that if the true CDF is within an `eps`-envelope of the + empirical CDF of the samples, and the support is bounded above, then + the mean is bounded above as well. In symbols, + + ```none + sup_x(|F_n(x) - F(x)|) < eps + ``` + + The 0th dimension of `samples` is interpreted as independent and + identically distributed samples. The remaining dimensions are + broadcast together with `envelope` and `high`, and operated on + separately. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `envelope` and `high`. + envelope: Floating-point tensor of sizes of admissible CDF + envelopes (i.e., the `eps` above). + high: Floating-point tensor of upper bounds on the distributions' + supports. + name: A name for this operation (optional). + + Returns: + bound: Floating-point tensor of upper bounds on the true means. + + Raises: + InvalidArgumentError: If some `sample` is found to be larger than + the corresponding `high`. + """ + with ops.name_scope(name, "maximum_mean", [samples, envelope, high]): + samples = ops.convert_to_tensor(samples, name="samples") + envelope = ops.convert_to_tensor(envelope, name="envelope") + high = ops.convert_to_tensor(high, name="high") + + xmax = math_ops.reduce_max(samples, axis=[-1]) + msg = "Given sample maximum value exceeds expectations" + check_op = check_ops.assert_less_equal(xmax, high, message=msg) + with ops.control_dependencies([check_op]): + return array_ops.identity(_do_maximum_mean(samples, envelope, high)) + + +def _minimum_mean(samples, envelope, low, name=None): + """Returns a stochastic lower bound on the mean of a scalar distribution. + + The idea is that if the true CDF is within an `eps`-envelope of the + empirical CDF of the samples, and the support is bounded below, then + the mean is bounded below as well. In symbols, + + ```none + sup_x(|F_n(x) - F(x)|) < eps + ``` + + The 0th dimension of `samples` is interpreted as independent and + identically distributed samples. The remaining dimensions are + broadcast together with `envelope` and `low`, and operated on + separately. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `envelope` and `low`. + envelope: Floating-point tensor of sizes of admissible CDF + envelopes (i.e., the `eps` above). + low: Floating-point tensor of lower bounds on the distributions' + supports. + name: A name for this operation (optional). + + Returns: + bound: Floating-point tensor of lower bounds on the true means. + + Raises: + InvalidArgumentError: If some `sample` is found to be smaller than + the corresponding `low`. + """ + with ops.name_scope(name, "minimum_mean", [samples, envelope, low]): + samples = ops.convert_to_tensor(samples, name="samples") + envelope = ops.convert_to_tensor(envelope, name="envelope") + low = ops.convert_to_tensor(low, name="low") + + xmin = math_ops.reduce_min(samples, axis=[-1]) + msg = "Given sample minimum value falls below expectations" + check_op = check_ops.assert_greater_equal(xmin, low, message=msg) + with ops.control_dependencies([check_op]): + return - _do_maximum_mean(-samples, envelope, -low) + + +def _dkwm_cdf_envelope(n, error_rate, name=None): + """Computes the CDF envelope that the DKWM inequality licenses. + + The [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + gives a stochastic bound on the distance between the true cumulative + distribution function (CDF) of any distribution and its empirical + CDF. To wit, for `n` iid samples from any distribution with CDF F, + + ```none + P(sup_x |F_n(x) - F(x)| > eps) < 2exp(-2n eps^2) + ``` + + This function computes the envelope size `eps` as a function of the + number of samples `n` and the desired limit on the left-hand + probability above. + + Args: + n: Tensor of numbers of samples drawn. + error_rate: Floating-point tensor of admissible rates of mistakes. + name: A name for this operation (optional). + + Returns: + eps: Tensor of maximum distances the true CDF can be from the + empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and + as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and + `error_rate`. + """ + with ops.name_scope(name, "dkwm_cdf_envelope", [n, error_rate]): + n = math_ops.cast(n, dtype=error_rate.dtype) + return math_ops.sqrt(-gen_math_ops.log(error_rate / 2.) / (2. * n)) + + +def _check_shape_dominates(tensor, tensors): + """Check that broadcasting `tensor` against `tensors` does not expand it. + + Why? Because I want to be very sure that the samples tensor is not + accidentally enlarged by broadcasting against tensors that are + supposed to be describing the distribution(s) sampled from, lest the + sample counts end up inflated. + + Args: + tensor: A Tensor whose shape is to be protected against broadcasting. + tensors: A list of Tensors to check + + Returns: + tensor: `tf.identity(tensor)` with control dependencies attached; + be sure to use that downstream. + """ + def check(t): + target = array_ops.shape(tensor)[1:] + result = array_ops.broadcast_dynamic_shape(target, array_ops.shape(t)) + # This rank check ensures that I don't get a wrong answer from the + # _shapes_ broadcasting against each other. + gt = check_ops.assert_greater(array_ops.rank(target), array_ops.rank(t)) + eq = check_ops.assert_equal(target, result) + return gt, eq + checks = list(itertools.chain(*[check(t) for t in tensors])) + with ops.control_dependencies(checks): + return array_ops.identity(array_ops.identity(tensor)) + + +def true_mean_confidence_interval_by_dkwm( + samples, low, high, error_rate=1e-6, name=None): + """Computes a confidence interval for the mean of a scalar distribution. + + In batch mode, computes confidence intervals for all distributions + in the batch (which need not be identically distributed). + + Relies on the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + The probability (over the randomness of drawing the given samples) + that any true mean is outside the corresponding returned interval is + no more than the given `error_rate`. The size of the intervals + scale as + `O(1 / sqrt(#samples))`, as `O(high - low)`, and as `O(-log(error_rate))`. + + Note that `error_rate` is a total error rate for all the confidence + intervals in the batch. As such, if the batch is nontrivial, the + error rate is not broadcast but divided (evenly) among the batch + members. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + low: Floating-point tensor of lower bounds on the distributions' + supports. + high: Floating-point tensor of upper bounds on the distributions' + supports. + error_rate: *Scalar* admissible total rate of mistakes. + name: A name for this operation (optional). + + Returns: + low: A floating-point tensor of stochastic lower bounds on the true means. + high: A floating-point tensor of stochastic upper bounds on the true means. + """ + with ops.name_scope( + name, "true_mean_confidence_interval_by_dkwm", + [samples, low, high, error_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + error_rate = ops.convert_to_tensor(error_rate, name="error_rate") + samples = _check_shape_dominates(samples, [low, high]) + check_ops.assert_scalar(error_rate) # Static shape + error_rate = _itemwise_error_rate(error_rate, [low, high], samples) + n = array_ops.shape(samples)[0] + envelope = _dkwm_cdf_envelope(n, error_rate) + min_mean = _minimum_mean(samples, envelope, low) + max_mean = _maximum_mean(samples, envelope, high) + return min_mean, max_mean + + +def _itemwise_error_rate( + total_error_rate, param_tensors, sample_tensor=None, name=None): + with ops.name_scope( + name, "itemwise_error_rate", + [total_error_rate, param_tensors, sample_tensor]): + result_shape = [1] + for p_tensor in param_tensors: + result_shape = array_ops.broadcast_dynamic_shape( + array_ops.shape(p_tensor), result_shape) + if sample_tensor is not None: + result_shape = array_ops.broadcast_dynamic_shape( + array_ops.shape(sample_tensor)[1:], result_shape) + num_items = math_ops.reduce_prod(result_shape) + return total_error_rate / math_ops.cast( + num_items, dtype=total_error_rate.dtype) + + +def assert_true_mean_equal_by_dkwm( + samples, low, high, expected, false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is as expected. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the true mean of some distribution from which the given samples are + drawn is _not_ the given expected mean with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want to + check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + low: Floating-point tensor of lower bounds on the distributions' + supports. + high: Floating-point tensor of upper bounds on the distributions' + supports. + expected: Floating-point tensor of expected true means. + false_fail_rate: *Scalar* admissible total rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean is + outside the corresponding confidence interval. + """ + with ops.name_scope( + name, "assert_true_mean_equal_by_dkwm", + [samples, low, high, expected, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected = ops.convert_to_tensor(expected, name="expected") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates(samples, [low, high, expected]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, error_rate=false_fail_rate) + less_op = check_ops.assert_less( + min_mean, expected, message="Mean confidence interval too high") + with ops.control_dependencies([less_op]): + return check_ops.assert_greater( + max_mean, expected, message="Mean confidence interval too low") + + +def min_discrepancy_of_true_means_detectable_by_dkwm( + n, low, high, false_fail_rate, false_pass_rate, name=None): + """Returns the minimum mean discrepancy that a DKWM-based test can detect. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Note that `false_fail_rate` is a total false failure rate for all + the tests in the batch. As such, if the batch is nontrivial, each + member will demand more samples. The `false_pass_rate` is also + interpreted as a total, but is treated asymmetrically: If each test + in the batch detects its corresponding discrepancy with probability + at least `1 - false_pass_rate`, then running all those tests and + failing if any one fails will jointly detect all those discrepancies + with the same `false_pass_rate`. + + Args: + n: Tensor of numbers of samples to be drawn from the distributions + of interest. + low: Floating-point tensor of lower bounds on the distributions' + supports. + high: Floating-point tensor of upper bounds on the distributions' + supports. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + discr: Tensor of lower bounds on the distances between true + means detectable by a DKWM-based test. + + For each batch member `i`, of `K` total, drawing `n[i]` samples from + some scalar distribution supported on `[low[i], high[i]]` is enough + to detect a difference in means of size `discr[i]` or more. + Specifically, we guarantee that (a) if the true mean is the expected + mean, `assert_true_mean_equal_by_dkwm` will fail with probability at + most `false_fail_rate / K` (which amounts to `false_fail_rate` if + applied to the whole batch at once), and (b) if the true mean + differs from the expected mean by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` will pass with probability at most + `false_pass_rate`. + + The detectable discrepancy scales as + + - `O(high[i] - low[i])`, + - `O(1 / sqrt(n[i]))`, + - `O(-log(false_fail_rate/K))`, and + - `O(-log(false_pass_rate))`. + """ + with ops.name_scope( + name, "min_discrepancy_of_true_means_detectable_by_dkwm", + [n, low, high, false_fail_rate, false_pass_rate]): + n = ops.convert_to_tensor(n, name="n") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Algorithm: Assume a true CDF F. The DKWM inequality gives a + # stochastic bound on how far the observed empirical CDF F_n can be. + # Then, using the DKWM inequality again gives a stochastic bound on + # the farthest candidate true CDF F' that + # true_mean_confidence_interval_by_dkwm might consider. At worst, these + # errors may go in the same direction, so the distance between F and + # F' is bounded by the sum. + # On batching: false fail rates sum, so I need to reduce + # the input to account for the batching. False pass rates + # max, so I don't. + sampling_envelope = _dkwm_cdf_envelope(n, false_pass_rate) + false_fail_rate = _itemwise_error_rate(false_fail_rate, [n, low, high]) + analysis_envelope = _dkwm_cdf_envelope(n, false_fail_rate) + return (high - low) * (sampling_envelope + analysis_envelope) + + +def min_num_samples_for_dkwm_mean_test( + discrepancy, low, high, + false_fail_rate=1e-6, false_pass_rate=1e-6, name=None): + """Returns how many samples suffice for a one-sample DKWM mean test. + + To wit, returns an upper bound on the number of samples necessary to + guarantee detecting a mean difference of at least the given + `discrepancy`, with the given `false_fail_rate` and `false_pass_rate`, + using the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + on a scalar distribution supported on `[low, high]`. + + Args: + discrepancy: Floating-point tensor of desired upper limits on mean + differences that may go undetected with probability higher than + `1 - false_pass_rate`. + low: Tensor of lower bounds on the distributions' support. + high: Tensor of upper bounds on the distributions' support. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + n: Tensor of numbers of samples to be drawn from the distributions + of interest. + + The `discrepancy`, `low`, and `high` tensors must have + broadcast-compatible shapes. + + For each batch member `i`, of `K` total, drawing `n[i]` samples from + some scalar distribution supported on `[low[i], high[i]]` is enough + to detect a difference in means of size `discrepancy[i]` or more. + Specifically, we guarantee that (a) if the true mean is the expected + mean, `assert_true_mean_equal_by_dkwm` will fail with probability at + most `false_fail_rate / K` (which amounts to `false_fail_rate` if + applied to the whole batch at once), and (b) if the true mean + differs from the expected mean by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` will pass with probability at most + `false_pass_rate`. + + The required number of samples scales + as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, + `O(-log(false_pass_rate))`, and `O(1 / discrepancy[i]**2)`. + """ + with ops.name_scope( + name, "min_num_samples_for_dkwm_mean_test", + [low, high, false_fail_rate, false_pass_rate, discrepancy]): + discrepancy = ops.convert_to_tensor( + discrepancy, name="discrepancy") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Could choose to cleverly allocate envelopes, but this is sound. + envelope1 = discrepancy / (2. * (high - low)) + envelope2 = envelope1 + false_fail_rate = _itemwise_error_rate( + false_fail_rate, [low, high, discrepancy]) + n1 = -math_ops.log(false_fail_rate / 2.) / (2. * envelope1**2) + n2 = -math_ops.log(false_pass_rate / 2.) / (2. * envelope2**2) + return math_ops.maximum(n1, n2) + + +def assert_true_mean_equal_by_dkwm_two_sample( + samples1, low1, high1, samples2, low2, high2, + false_fail_rate=1e-6, name=None): + """Asserts the means of the given distributions are equal. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the means of the distributions from which the given samples are + drawn are _not_ equal with statistical significance `false_fail_rate` + or stronger, otherwise passes. If you also want to check that you + are gathering enough evidence that a pass is not spurious, see + `min_num_samples_for_dkwm_mean_two_sample_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm_two_sample`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples1: Floating-point tensor of samples from the + distribution(s) A. Entries are assumed IID across the 0th + dimension. The other dimensions must broadcast with `low1`, + `high1`, `low2`, and `high2`. + low1: Floating-point tensor of lower bounds on the supports of the + distributions A. + high1: Floating-point tensor of upper bounds on the supports of + the distributions A. + samples2: Floating-point tensor of samples from the + distribution(s) B. Entries are assumed IID across the 0th + dimension. The other dimensions must broadcast with `low1`, + `high1`, `low2`, and `high2`. + low2: Floating-point tensor of lower bounds on the supports of the + distributions B. + high2: Floating-point tensor of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* admissible total rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any pair of confidence + intervals true for corresponding true means do not overlap. + """ + with ops.name_scope( + name, "assert_true_mean_equal_by_dkwm_two_sample", + [samples1, low1, high1, samples2, low2, high2, false_fail_rate]): + samples1 = ops.convert_to_tensor(samples1, name="samples1") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + samples2 = ops.convert_to_tensor(samples2, name="samples2") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples1 = _check_shape_dominates(samples1, [low1, high1]) + samples2 = _check_shape_dominates(samples2, [low2, high2]) + compatible_samples = check_ops.assert_equal( + array_ops.shape(samples1)[1:], array_ops.shape(samples2)[1:]) + with ops.control_dependencies([compatible_samples]): + # Could in principle play games with cleverly allocating + # significance instead of the even split below. It may be possible + # to get tighter intervals, in order to obtain a higher power test. + # Any allocation strategy that depends only on the support bounds + # and sample counts should be valid; however, because the intervals + # scale as O(-log(false_fail_rate)), there doesn't seem to be much + # room to win. + min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm( + samples1, low1, high1, false_fail_rate / 2.) + min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( + samples2, low2, high2, false_fail_rate / 2.) + # I want to assert + # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2), + # but I think I only have and-combination of asserts, so use DeMorgan. + clause1_op = check_ops.assert_greater_equal(max_mean_1, min_mean_2) + with ops.control_dependencies([clause1_op]): + return check_ops.assert_less_equal(min_mean_1, max_mean_2) + + +def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + n1, low1, high1, n2, low2, high2, + false_fail_rate, false_pass_rate, name=None): + """Returns the minimum mean discrepancy for a two-sample DKWM-based test. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Note that `false_fail_rate` is a total false failure rate for all + the tests in the batch. As such, if the batch is nontrivial, each + member will demand more samples. The `false_pass_rate` is also + interpreted as a total, but is treated asymmetrically: If each test + in the batch detects its corresponding discrepancy with probability + at least `1 - false_pass_rate`, then running all those tests and + failing if any one fails will jointly detect all those discrepancies + with the same `false_pass_rate`. + + Args: + n1: Tensor of numbers of samples to be drawn from the distributions A. + low1: Floating-point tensor of lower bounds on the supports of the + distributions A. + high1: Floating-point tensor of upper bounds on the supports of + the distributions A. + n2: Tensor of numbers of samples to be drawn from the distributions B. + low2: Floating-point tensor of lower bounds on the supports of the + distributions B. + high2: Floating-point tensor of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + discr: Tensor of lower bounds on the distances between true means + detectable by a two-sample DKWM-based test. + + For each batch member `i`, of `K` total, drawing `n1[i]` samples + from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` + samples from scalar distribution B supported on `[low2[i], high2[i]]` + is enough to detect a difference in their true means of size + `discr[i]` or more. Specifically, we guarantee that (a) if their + true means are equal, `assert_true_mean_equal_by_dkwm_two_sample` + will fail with probability at most `false_fail_rate/K` (which + amounts to `false_fail_rate` if applied to the whole batch at once), + and (b) if their true means differ by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm_two_sample` will pass with + probability at most `false_pass_rate`. + + The detectable distribution scales as + + - `O(high1[i] - low1[i])`, `O(high2[i] - low2[i])`, + - `O(1 / sqrt(n1[i]))`, `O(1 / sqrt(n2[i]))`, + - `O(-log(false_fail_rate/K))`, and + - `O(-log(false_pass_rate))`. + """ + with ops.name_scope( + name, "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", + [n1, low1, high1, n2, low2, high2, false_fail_rate, false_pass_rate]): + n1 = ops.convert_to_tensor(n1, name="n1") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + n2 = ops.convert_to_tensor(n2, name="n2") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + det_disc1 = min_discrepancy_of_true_means_detectable_by_dkwm( + n1, low1, high1, false_fail_rate / 2., false_pass_rate / 2.) + det_disc2 = min_discrepancy_of_true_means_detectable_by_dkwm( + n2, low2, high2, false_fail_rate / 2., false_pass_rate / 2.) + return det_disc1 + det_disc2 + + +def min_num_samples_for_dkwm_mean_two_sample_test( + discrepancy, low1, high1, low2, high2, + false_fail_rate=1e-6, false_pass_rate=1e-6, name=None): + """Returns how many samples suffice for a two-sample DKWM mean test. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Args: + discrepancy: Floating-point tensor of desired upper limits on mean + differences that may go undetected with probability higher than + `1 - false_pass_rate`. + low1: Floating-point tensor of lower bounds on the supports of the + distributions A. + high1: Floating-point tensor of upper bounds on the supports of + the distributions A. + low2: Floating-point tensor of lower bounds on the supports of the + distributions B. + high2: Floating-point tensor of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + n1: Tensor of numbers of samples to be drawn from the distributions A. + n2: Tensor of numbers of samples to be drawn from the distributions B. + + For each batch member `i`, of `K` total, drawing `n1[i]` samples + from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` + samples from scalar distribution B supported on `[low2[i], high2[i]]` + is enough to detect a difference in their true means of size + `discr[i]` or more. Specifically, we guarantee that (a) if their + true means are equal, `assert_true_mean_equal_by_dkwm_two_sample` + will fail with probability at most `false_fail_rate/K` (which + amounts to `false_fail_rate` if applied to the whole batch at once), + and (b) if their true means differ by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm_two_sample` will pass with + probability at most `false_pass_rate`. + + The required number of samples scales as + + - `O((high1[i] - low1[i])**2)`, `O((high2[i] - low2[i])**2)`, + - `O(-log(false_fail_rate/K))`, + - `O(-log(false_pass_rate))`, and + - `O(1 / discrepancy[i]**2)`. + """ + with ops.name_scope( + name, "min_num_samples_for_dkwm_mean_two_sample_test", + [low1, high1, low2, high2, + false_fail_rate, false_pass_rate, discrepancy]): + discrepancy = ops.convert_to_tensor(discrepancy, name="discrepancy") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Could choose to cleverly allocate discrepancy tolerances and + # failure probabilities, but this is sound. + n1 = min_num_samples_for_dkwm_mean_test( + discrepancy / 2., low1, high1, + false_fail_rate / 2., false_pass_rate / 2.) + n2 = min_num_samples_for_dkwm_mean_test( + discrepancy / 2., low2, high2, + false_fail_rate / 2., false_pass_rate / 2.) + return n1, n2 diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 0c747f8e68529484ae6f695b8500cde74857bb11..971d65c4a69140161461fdac93bb588014dd3e88 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -181,7 +181,7 @@ def quadrature_scheme_softmaxnormal_quantiles( edges = array_ops.reshape(edges, shape=array_ops.concat([ [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) quantiles = dist.quantile(edges) - quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + quantiles = SoftmaxCentered().forward(quantiles) # Cyclically permute left by one. perm = array_ops.concat([ math_ops.range(1, 1 + batch_ndims), [0]], axis=0) @@ -248,11 +248,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of the quantiles of `p(z)` (generalized quantiles if `K > 2`). - See [1] for more details. - - [1]. "Quadrature Compound: An approximating family of distributions" - Joshua Dillon, Ian Langmore, arXiv preprints - https://arxiv.org/abs/1801.03080 + See [Dillon and Langmore (2018)][1] for more details. #### About `Vector` distributions in TensorFlow. @@ -313,6 +309,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): is_positive_definite=True), ], validate_args=True) + ``` + + #### References + + [1]: Joshua Dillon and Ian Langmore. Quadrature Compound: An approximating + family of distributions. _arXiv preprint arXiv:1801.03080_, 2018. + https://arxiv.org/abs/1801.03080 """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index e1ccf116457a97261b9ce3965552764771d3bdd2..003c66b9413fdcad20fbcc8b4bf47259692932e7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -227,7 +227,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): c = 2 * scale_diag_part / f_noskew.forward( ops.convert_to_tensor(2, dtype=dtype)) affine = bijectors.Affine( - shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1) + shift=loc, scale_diag=c, validate_args=validate_args) bijector = bijectors.Chain([affine, f]) diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index e4ac65012b9c7e3ed5ada3ed75020f3905740156..5a8c94dabf4c3c430bee544a48ee7acfe7dd7ed0 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -228,9 +228,12 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbk) # This parametrization is equivalent to Chi2, i.e., # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2) + expanded_df = self.df * array_ops.ones( + self.scale_operator.batch_shape_tensor(), + dtype=self.df.dtype.base_dtype) g = random_ops.random_gamma(shape=[n], alpha=self._multi_gamma_sequence( - 0.5 * self.df, self.dimension), + 0.5 * expanded_df, self.dimension), beta=0.5, dtype=self.dtype, seed=distribution_util.gen_new_seed( diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index a26ec8513f4b7b9c278edddc95e6acd2523194f2..80176397c02f22095a3a9be3d12c2115ec4eca29 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -11,12 +11,14 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":checkpointable_utils", ":datasets", ":metrics", ":network", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", @@ -26,7 +28,6 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", - "//tensorflow/python/eager:custom_gradient", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", ], @@ -69,6 +70,7 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + ":checkpointable_utils", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", @@ -116,6 +118,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/eager/python:checkpointable_utils", "//tensorflow/contrib/summary:summary_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -230,24 +233,27 @@ py_library( "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", - "//tensorflow/python:io_ops", + "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", "//tensorflow/python:tensor_shape", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", ], ) -py_test( +cuda_py_test( name = "checkpointable_utils_test", srcs = ["checkpointable_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":checkpointable_utils", ":network", + "@six_archive//:six", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -262,7 +268,12 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "@six_archive//:six", + "//tensorflow/python/keras", + ], + tags = [ + "no_oss", # b/74395663 + "no_windows", # TODO: needs investigation on Windows + "notsan", ], ) diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py index d9648ffb03d19429911b23d8c40c532e75e1cdfd..91a7aded11db6b4c8bcb061da6d6c69253603c85 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import weakref @@ -26,17 +27,18 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import io_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpointable as core_checkpointable from tensorflow.python.training import checkpointable_utils as core_checkpointable_utils from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import deprecation _ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. @@ -217,12 +219,16 @@ def _serialize_checkpointables( object_proto = object_graph_proto.nodes.add() object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) object_name = object_names[checkpointable] - for name, saveable in ( + for name, saveable_factory in ( checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) + if callable(saveable_factory): + saveable = saveable_factory(name=attribute.checkpoint_key) + else: + saveable = saveable_factory # Figure out the name-based Saver's name for this variable. saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( [saveable], convert_variable_to_tensor=False) @@ -278,6 +284,37 @@ def _serialize_object_graph(root_checkpointable): slot_variables=slot_variables) +def gather_initializers(root_checkpointable): + """Traverse the object graph and find initialization ops. + + Looks for `Checkpointable` objects which are dependencies of + `root_checkpointable` and which have an `initializer` property. Includes + initializers for slot variables only if the variable they are slotting for and + the optimizer are dependencies of `root_checkpointable` (i.e. if they would be + saved with a checkpoint). + + Args: + root_checkpointable: A `Checkpointable` object to gather initializers for. + Returns: + A list of initialization ops. + """ + # TODO(allenl): Extract out gathering logic so the naming logic doesn't have + # to run. + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_names = { + obj: _object_prefix_from_path(path) + for obj, path in path_to_root.items()} + node_ids = {node: node_id for node_id, node + in enumerate(checkpointable_objects)} + _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return [c.initializer for c in checkpointable_objects + if hasattr(c, "initializer") and c.initializer is not None] + + class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): def __init__(self, tensor, name): @@ -288,7 +325,26 @@ class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return control_flow_ops.no_op() -class CheckpointLoadStatus(object): +class _LoadStatus(object): + """Abstract base for load status callbacks.""" + + @abc.abstractmethod + def assert_consumed(self): + """Raises an exception unless a non-trivial restoration has completed.""" + pass + + @abc.abstractmethod + def run_restore_ops(self, session=None): + """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" + pass + + @abc.abstractmethod + def initialize_or_restore(self, session=None): + """Runs restore ops from the checkpoint, or initializes variables.""" + pass + + +class CheckpointLoadStatus(_LoadStatus): """Checks the status of checkpoint loading and manages restore ops. Returned from `Saver.restore`. Since `restore` may defer the loading of values @@ -342,12 +398,112 @@ class CheckpointLoadStatus(object): def run_restore_ops(self, session=None): """Run operations to restore objects in the dependency graph.""" - if context.in_eager_mode(): + if context.executing_eagerly(): return # Run eagerly if session is None: session = ops.get_default_session() session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) + def initialize_or_restore(self, session=None): + """Alias for `run_restore_ops`. + + This method has a sibling in `InitializationOnlyStatus` which instead + initializes variables. That type is returned if no checkpoint is specified + in `Saver.restore`. + + Args: + session: The session to run restore ops in. If `None`, uses the default + session. + """ + self.run_restore_ops(session=session) + + +class InitializationOnlyStatus(_LoadStatus): + """Returned from `Saver.restore` when no checkpoint has been specified. + + Objects of this type have the same `assert_consumed` method as + `CheckpointLoadStatus`, but it always fails. However, + `initialize_or_restore` works on objects of both types, and will + initialize variables in `InitializationOnlyStatus` objects or restore them + otherwise. + """ + + def __init__(self, root_checkpointable): + self._root_checkpointable = root_checkpointable + + def assert_consumed(self): + """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" + raise AssertionError( + "No checkpoint specified (save_path=None); nothing is being restored.") + + def run_restore_ops(self, session=None): + """For consistency with `CheckpointLoadStatus`. + + Use `initialize_or_restore` for initializing if no checkpoint was passed + to `Saver.restore` and restoring otherwise. + + Args: + session: Not used. + """ + raise AssertionError( + "No checkpoint specified, so no restore ops are available " + "(save_path=None to Saver.restore).") + + def initialize_or_restore(self, session=None): + """Runs initialization ops for variables. + + Only objects which would be saved by `Saver.save` will be initialized. See + `gather_initializers` for details. + + This method does nothing when executing eagerly (initializers get run + eagerly). + + Args: + session: The session to run initialization ops in. If `None`, uses the + default session. + """ + if context.executing_eagerly(): + return # run eagerly + if session is None: + session = ops.get_default_session() + session.run(gather_initializers(self._root_checkpointable)) + + +_DEPRECATED_RESTORE_INSTRUCTIONS = ( + "Restoring a name-based tf.train.Saver checkpoint using the object-based " + "restore API. This mode uses global names to match variables, and so is " + "somewhat fragile. It also adds new restore ops to the graph each time it " + "is called. Prefer re-encoding training checkpoints in the object-based " + "format: run save() on the object-based saver (the same one this message " + "is coming from) and use that checkpoint in the future.") + + +class NameBasedSaverStatus(_LoadStatus): + """Status for loading a name-based training checkpoint.""" + + def __init__(self, object_saver, save_path): + self._object_saver = object_saver + self._save_path = save_path + + def assert_consumed(self): + """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" + raise AssertionError( + "Restoring a name-based checkpoint. No load status is available.") + + @deprecation.deprecated( + date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) + def run_restore_ops(self, session=None): + """Load the name-based training checkpoint using a new `tf.train.Saver`.""" + if session is None and not context.executing_eagerly(): + session = ops.get_default_session() + with ops.device("/cpu:0"): + saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access + sess=session, save_path=self._save_path) + + def initialize_or_restore(self, session=None): + """Alias for `run_restore_ops`.""" + self.run_restore_ops(session=session) + class _SessionWithFeedDictAdditions(session_lib.SessionInterface): """Pretends to be a session, inserts extra feeds on run().""" @@ -366,7 +522,19 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface): fetches=fetches, feed_dict=feed_dict, **kwargs) -class Saver(object): +def _copy_saver_with_new_var_list(old_saver, new_var_list): + """Copy a `tf.train.Saver`'s state to a new Saver with different variables.""" + new_saver = saver_lib.Saver(var_list=new_var_list) + # TODO(allenl): Move to copying functionality to Saver? + # pylint: disable=protected-access + new_saver._last_checkpoints = old_saver._last_checkpoints + new_saver._checkpoints_to_be_deleted = old_saver._checkpoints_to_be_deleted + new_saver._next_checkpoint_time = old_saver._next_checkpoint_time + # pylint: enable=protected-access + return new_saver + + +class CheckpointableSaver(object): """Saves and restores a `Checkpointable` object and its dependencies. See `Checkpointable` for details of dependency management. `Saver` wraps @@ -396,8 +564,9 @@ class Saver(object): # Allow passing in a weak reference to avoid reference cycles when # `Checkpointable` objects save themselves. self._root_checkpointable_ref = root_checkpointable - if context.in_graph_mode(): - self._file_prefix_placeholder = constant_op.constant("model") + if not context.executing_eagerly(): + with ops.device("/cpu:0"): + self._file_prefix_placeholder = constant_op.constant("model") else: self._file_prefix_placeholder = None @@ -407,7 +576,6 @@ class Saver(object): self._last_save_saver = None # Op caching for restore - self._object_graph_restore_tensor = None self._last_restore_object_graph = None self._last_restore_checkpoint = None @@ -429,7 +597,7 @@ class Saver(object): Args: file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this - prefix and the global step, if provided. + prefix and `checkpoint_number`, if provided. checkpoint_number: An integer variable or Tensor, used to number checkpoints. Typically this value is saved along with other variables in training checkpoints, which will happen automatically if it was created @@ -444,46 +612,58 @@ class Saver(object): """ named_variables, graph_proto = _serialize_object_graph( self._root_checkpointable) - in_graph_mode = context.in_graph_mode() - if in_graph_mode: + if not context.executing_eagerly(): if session is None: session = ops.get_default_session() if self._object_graph_feed_tensor is None: - self._object_graph_feed_tensor = constant_op.constant( - "", dtype=dtypes.string) + with ops.device("/cpu:0"): + self._object_graph_feed_tensor = constant_op.constant( + "", dtype=dtypes.string) object_graph_tensor = self._object_graph_feed_tensor feed_additions = {object_graph_tensor: graph_proto.SerializeToString()} else: session = None - object_graph_tensor = constant_op.constant( - graph_proto.SerializeToString(), dtype=dtypes.string) + with ops.device("/cpu:0"): + object_graph_tensor = constant_op.constant( + graph_proto.SerializeToString(), dtype=dtypes.string) feed_additions = None assert _OBJECT_GRAPH_PROTO_KEY not in named_variables named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( tensor=object_graph_tensor, name=_OBJECT_GRAPH_PROTO_KEY) - if not in_graph_mode or self._last_save_object_graph != graph_proto: - if self._last_save_object_graph is not None and in_graph_mode: - raise NotImplementedError( - "Using a single Saver to save a mutated object graph is not " - "currently supported when graph building. Use a different Saver " - "when the object graph changes (save ops will be duplicated), or " - "file a feature request if this limitation bothers you.") - saver = saver_lib.Saver(var_list=named_variables) - if in_graph_mode: - self._last_save_saver = saver - self._last_save_object_graph = graph_proto - else: - saver = self._last_save_saver - save_path = saver.save( - sess=_SessionWithFeedDictAdditions( - session=session, feed_additions=feed_additions), - save_path=file_prefix, - write_meta_graph=False, - global_step=checkpoint_number) + if (self._last_save_object_graph != graph_proto + # When executing eagerly, we need to re-create SaveableObjects each time + # save() is called so they pick up new Tensors passed to their + # constructors. That means the Saver needs to be copied with a new + # var_list. + or context.executing_eagerly()): + if self._last_save_object_graph is not None: + self._last_save_saver = _copy_saver_with_new_var_list( + old_saver=self._last_save_saver, new_var_list=named_variables) + else: + self._last_save_saver = saver_lib.Saver(var_list=named_variables) + self._last_save_object_graph = graph_proto + with ops.device("/cpu:0"): + save_path = self._last_save_saver.save( + sess=_SessionWithFeedDictAdditions( + session=session, feed_additions=feed_additions), + save_path=file_prefix, + write_meta_graph=False, + global_step=checkpoint_number) return save_path - def restore(self, save_path, session=None): + def _global_variable_names(self): + """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s.""" + named_saveables, graph_proto = _serialize_object_graph( + self._root_checkpointable) + saver_names = {} + for object_proto in graph_proto.nodes: + for attribute_proto in object_proto.attributes: + saver_names[attribute_proto.full_name] = named_saveables[ + attribute_proto.checkpoint_key] + return saver_names + + def restore(self, save_path): """Restore a training checkpoint. Restores `root_checkpointable` and any objects that it tracks @@ -493,8 +673,7 @@ class Saver(object): constructor after this call will be matched if they have a corresponding object in the checkpoint. - When building a graph, restorations are added to the graph but not run. A - session is required to retrieve checkpoint metadata. + When building a graph, restorations are added to the graph but not run. To disallow deferred loading, assert immediately that all checkpointed variables have been matched to variable objects: @@ -518,45 +697,48 @@ class Saver(object): If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph. + Name-based `tf.train.Saver` checkpoints can be loaded using this + method. There is no deferred loading, and names are used to match + variables. No restore ops are created/run until `run_restore_ops()` or + `initialize_or_restore()` are called on the returned status object, even + when executing eagerly. Re-encode name-based checkpoints using this + object-based `Saver.save` as soon as possible. + Args: save_path: The path to the checkpoint, as returned by `save` or `tf.train.latest_checkpoint`. If None (as when there is no latest - checkpoint for `tf.train.latest_checkpoint` to return), does nothing. - session: The session to retrieve metadata with. Ignored when executing - eagerly. If not provided when graph building, the default session is - used. + checkpoint for `tf.train.latest_checkpoint` to return), returns an + object which may run initializers for objects in the dependency + graph. If the checkpoint was written by the name-based `tf.train.Saver`, + names are used to match variables. Returns: - A `CheckpointLoadStatus` object, which can be used to make assertions - about the status of checkpoint restoration and run restore ops. + A load status object, which can be used to make assertions about the + status of checkpoint restoration and run initialization/restore ops + (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if + `save_path` is `None`). + + If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` + object is returned which runs restore ops from a name-based saver. """ if save_path is None: - return - in_graph_mode = context.in_graph_mode() + return InitializationOnlyStatus(self._root_checkpointable) + in_graph_mode = not context.executing_eagerly() if in_graph_mode: - if session is None: - session = ops.get_default_session() file_prefix_tensor = self._file_prefix_placeholder file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} else: - session = None - file_prefix_tensor = constant_op.constant(save_path) + with ops.device("/cpu:0"): + file_prefix_tensor = constant_op.constant(save_path) file_prefix_feed_dict = None - if not in_graph_mode or self._object_graph_restore_tensor is None: - object_graph_string, = io_ops.restore_v2( - prefix=file_prefix_tensor, - tensor_names=[_OBJECT_GRAPH_PROTO_KEY], - shape_and_slices=[""], - dtypes=[dtypes.string], - name="object_graph_proto_read") - if in_graph_mode: - self._object_graph_restore_tensor = object_graph_string - if in_graph_mode: - object_graph_string = session.run( - self._object_graph_restore_tensor, - feed_dict=file_prefix_feed_dict) - else: - object_graph_string = object_graph_string.numpy() + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + try: + object_graph_string = reader.get_tensor(_OBJECT_GRAPH_PROTO_KEY) + except errors_impl.NotFoundError: + # The object graph proto does not exist in this checkpoint. Try again with + # name-based saving. + return NameBasedSaverStatus(self, save_path) + object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) object_graph_proto.ParseFromString(object_graph_string) @@ -566,7 +748,6 @@ class Saver(object): if in_graph_mode: dtype_map = None else: - reader = pywrap_tensorflow.NewCheckpointReader(save_path) dtype_map = reader.get_variable_to_dtype_map() checkpoint = core_checkpointable_utils._Checkpoint( # pylint: disable=protected-access object_graph_proto=object_graph_proto, @@ -586,3 +767,103 @@ class Saver(object): load_status = CheckpointLoadStatus( checkpoint, feed_dict=file_prefix_feed_dict) return load_status + + +class Checkpoint(core_checkpointable.Checkpointable): + """A utility class which groups `Checkpointable` objects. + + Accepts arbitrary keyword arguments to its constructor and saves those values + with a checkpoint. Maintains a `save_counter` for numbering checkpoints. + + Example usage: + + ```python + import tensorflow as tf + import tensorflow.contrib.eager as tfe + import os + + checkpoint_directory = "/tmp/training_checkpoints" + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + root = tfe.Checkpoint(optimizer=optimizer, model=model) + root.restore(tf.train.latest_checkpoint(checkpoint_directory)) + for _ in range(num_training_steps): + optimizer.minimize( ... ) + root.save(file_prefix=checkpoint_prefix) + ``` + + For more manual control over saving, use `tfe.CheckpointableSaver` directly. + + Attributes: + save_counter: Incremented when `save()` is called. Used to number + checkpoints. + """ + + def __init__(self, **kwargs): + """Group objects into a training checkpoint. + + Args: + **kwargs: Keyword arguments are set as attributes of this object, and are + saved with the checkpoint. Attribute values must derive from + `CheckpointableBase`. + Raises: + ValueError: If objects in `kwargs` are not Checkpointable. + """ + super(Checkpoint, self).__init__() + for k, v in sorted(kwargs.items(), key=lambda item: item[0]): + if not isinstance(v, core_checkpointable.CheckpointableBase): + raise ValueError( + ("`Checkpoint` was expecting an object derived from " + "`CheckpointableBase`, got %s.") % (v,)) + setattr(self, k, v) + self._save_counter = None # Created lazily for restore-on-create. + self._saver = CheckpointableSaver(weakref.ref(self)) + + def _maybe_create_save_counter(self): + """Create a save counter if it does not yet exist.""" + if self._save_counter is None: + # Initialized to 0 and incremented before saving. + with ops.device("/cpu:0"): + self._save_counter = add_variable( + self, name="save_counter", initializer=0, dtype=dtypes.int64) + + @property + def save_counter(self): + """An integer variable which starts at zero and is incremented on save. + + Used to number checkpoints. + + Returns: + The save counter variable. + """ + self._maybe_create_save_counter() + return self._save_counter + + def save(self, file_prefix, session=None): + """Save a checkpoint. Wraps `tfe.CheckpointableSaver.save`.""" + in_graph_mode = not context.executing_eagerly() + if in_graph_mode: + if session is None: + session = ops.get_default_session() + if self._save_counter is None: + # When graph building, if this is a new save counter variable then it + # needs to be initialized before assign_add. This is only an issue if + # restore() has not been called first. + session.run(self.save_counter.initializer) + with ops.colocate_with(self.save_counter): + assign_op = self.save_counter.assign_add(1) + if in_graph_mode: + session.run(assign_op) + return self._saver.save( + file_prefix=file_prefix, + checkpoint_number=self.save_counter, + session=session) + + def restore(self, save_path): + """Restore a checkpoint. Wraps `tfe.CheckpointableSaver.restore`.""" + status = self._saver.restore(save_path=save_path) + # Create the save counter now so it gets initialized with other variables + # when graph building. Creating it earlier would lead to double + # initialization when executing eagerly. + self._maybe_create_save_counter() + return status diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index b7554defde57201d21eefe70cbf4e4e98651e175..a8c47d76d1682296850c488f09aa6c358c5e6ee1 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -18,98 +18,32 @@ from __future__ import print_function import functools import os -import weakref import six from tensorflow.contrib.eager.python import checkpointable_utils -from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.layers import base +from tensorflow.python.keras._impl.keras.engine import training from tensorflow.python.layers import core +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables from tensorflow.python.training import adam from tensorflow.python.training import checkpointable from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - core.Dense.__init__(self, *args, **kwargs) - - def add_variable(self, name, shape, **kwargs): - # Calls both Checkpointable._add_variable and Layer.add_variable. Eventually - # Layer.add_variable should inherit from Checkpointable and simply call - # super and then do post-processing. - return checkpointable.Checkpointable._add_variable_with_custom_getter( - self, - name=name, - shape=shape, - getter=functools.partial(core.Dense.add_variable, self), - **kwargs) - - -# pylint: disable=not-callable -class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): - - def __setattr__(self, name, value): - if isinstance(value, base.Layer): - self.track_layer(value, name=name) - # Checkpointable is next in the method resolution order, so this will catch - # Checkpointable objects which aren't Layers. - super(CheckpointableNetwork, self).__setattr__(name, value) - - def track_layer(self, layer, name): - self._track_checkpointable(layer, name=name) - return super(CheckpointableNetwork, self).track_layer(layer) - - -class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): - - # NOTE: Copied from Optimizer with modifications to use add_variable - # for non-slot variables. These contortions are necessary to maintain - # checkpoint compatibility with variable.name based saving. - # TODO(allenl): Make this cleaner. - def _create_non_slot_variable(self, initial_value, name, colocate_with): - """Add an extra variable, not associated with a slot.""" - if context.in_graph_mode(): - graph = colocate_with.graph - else: - graph = None - - key = (name, graph) - v = self._non_slot_dict.get(key, None) - if v is None: - with ops.colocate_with(colocate_with): - def _variable_getter(name, shape, dtype, initializer): - del shape, dtype # not used, but there for compatibility - return variable_scope.variable( - name=name, initial_value=initializer, trainable=False) - - initial_value = ops.convert_to_tensor(initial_value) - v = self._add_variable_with_custom_getter( - name=name, - shape=initial_value.get_shape(), - initializer=initial_value, - getter=_variable_getter) - - self._non_slot_dict[key] = v - - return v - - class NonLayerCheckpointable(checkpointable.Checkpointable): def __init__(self): @@ -118,60 +52,20 @@ class NonLayerCheckpointable(checkpointable.Checkpointable): self, name="a_variable", shape=[]) -class MyNetwork(CheckpointableNetwork): - """A concrete Network for testing.""" +# pylint: disable=not-callable +class MyModel(training.Model): + """A concrete Model for testing.""" def __init__(self): - super(MyNetwork, self).__init__() - self._named_dense = CheckpointableDenseLayer(1, use_bias=True) - self._via_track_layer = self.track_layer( - CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") + super(MyModel, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) # We can still track Checkpointables which aren't Layers. self._non_layer = NonLayerCheckpointable() def call(self, values): - return self._via_track_layer(self._named_dense(values)) - - -class Checkpoint(checkpointable.Checkpointable): - """A utility class which groups `Checkpointable` objects.""" - - def __init__(self, **kwargs): - super(Checkpoint, self).__init__() - for k, v in sorted(kwargs.items(), key=lambda item: item[0]): - setattr(self, k, v) - self._save_counter = None - self._saver = checkpointable_utils.Saver(weakref.ref(self)) - - @property - def save_counter(self): - """An integer variable which starts at zero and is incremented on save. - - Used to number checkpoints. - - Returns: - The save counter variable. - """ - if self._save_counter is None: - # Initialized to 0 and incremented before saving. - self._save_counter = checkpointable_utils.add_variable( - self, name="save_counter", initializer=0, dtype=dtypes.int64) - return self._save_counter - - def save(self, file_prefix, session=None): - assign_op = self.save_counter.assign_add(1) - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - session.run(assign_op) - return self._saver.save( - file_prefix=file_prefix, - checkpoint_number=self.save_counter, - session=session) - - def restore(self, save_path): - return self._saver.restore( - save_path=save_path) + ret = self._second(self._named_dense(values)) + return ret class InterfaceTests(test.TestCase): @@ -206,8 +100,7 @@ class InterfaceTests(test.TestCase): with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) - if context.in_graph_mode(): - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(obj)) self.assertEqual("constant_initializer:0", constant_initializer.name) self.assertEqual(1, self.evaluate(constant_initializer)) self.assertEqual("some_variable_scope/ones_initializer:0", @@ -217,14 +110,14 @@ class InterfaceTests(test.TestCase): [0., 0.]], self.evaluate(bare_initializer)) self.assertEqual("a_variable:0", obj.a_variable.name) self.assertEqual("duplicate:0", other_duplicate.name) - if context.in_graph_mode(): - # The .name attribute may be globally influenced, but the checkpoint name - # won't be (tested below). - self.assertEqual("duplicate_1:0", duplicate.name) - else: + if context.executing_eagerly(): # When executing eagerly, there's no uniquification of variable names. The # checkpoint name will be the same. self.assertEqual("duplicate:0", duplicate.name) + else: + # The .name attribute may be globally influenced, but the checkpoint name + # won't be (tested below). + self.assertEqual("duplicate_1:0", duplicate.name) named_variables, _ = checkpointable_utils._serialize_object_graph(obj) expected_checkpoint_names = ( "a_variable/.ATTRIBUTES/VARIABLE_VALUE", @@ -261,57 +154,99 @@ class InterfaceTests(test.TestCase): self.assertAllEqual([1., 1., 1.], self.evaluate(v2)) +class _MirroringSaveable(core_saver.BaseSaverBuilder.SaveableObject): + + def __init__(self, primary_variable, mirrored_variable, name): + self._primary_variable = primary_variable + self._mirrored_variable = mirrored_variable + tensor = self._primary_variable.read_value() + spec = core_saver.BaseSaverBuilder.SaveSpec( + tensor=tensor, + slice_spec="", + name=name) + super(_MirroringSaveable, self).__init__( + tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return control_flow_ops.group( + self._primary_variable.assign(tensor), + self._mirrored_variable.assign(tensor)) + + +class _OwnsMirroredVariables(checkpointable.CheckpointableBase): + """A Checkpointable object which returns a more complex SaveableObject.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + self.mirrored = variable_scope.get_variable( + name="mirrored", initializer=15., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + def _saveable_factory(name=self.non_dep_variable.name): + return _MirroringSaveable( + primary_variable=self.non_dep_variable, + mirrored_variable=self.mirrored, + name=name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNamingWithOptimizer(self): input_value = constant_op.constant([[3.]]) - network = MyNetwork() - # A nuisance Network using the same optimizer. Its slot variables should not + model = MyModel() + # A nuisance Model using the same optimizer. Its slot variables should not # go in the checkpoint, since it is never depended on. - other_network = MyNetwork() - optimizer = CheckpointableAdam(0.001) + other_model = MyModel() + optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = Checkpoint( - optimizer=optimizer, network=network, optimizer_step=optimizer_step) - if context.in_eager_mode(): + root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=optimizer_step) + if context.executing_eagerly(): optimizer.minimize( - lambda: network(input_value), + lambda: model(input_value), global_step=optimizer_step) optimizer.minimize( - lambda: other_network(input_value), + lambda: other_model(input_value), global_step=optimizer_step) else: train_op = optimizer.minimize( - network(input_value), global_step=optimizer_step) + model(input_value), global_step=optimizer_step) optimizer.minimize( - other_network(input_value), + other_model(input_value), global_step=optimizer_step) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) self.evaluate(train_op) named_variables, serialized_graph = ( checkpointable_utils._serialize_object_graph(root_checkpointable)) expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", - # No name provided to track_checkpointable(), so the position is used - # instead (one-based). - "network/via_track_layer/kernel", - # track_checkpointable() with a name provided, so that's used - "network/_named_dense/kernel", - "network/_named_dense/bias", - # non-Layer dependency of the network - "network/_non_layer/a_variable", + "model/_second/kernel", + "model/_named_dense/kernel", + "model/_named_dense/bias", + # non-Layer dependency of the model + "model/_non_layer/a_variable", # The optimizer creates two non-slot variables "optimizer/beta1_power", "optimizer/beta2_power", # Slot variables - "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/m", - "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/v", - "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m", - "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v", - "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m", - "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v", + "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m", + "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v", + "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m", + "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v", + "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m", + "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v", ) suffix = "/.ATTRIBUTES/VARIABLE_VALUE" expected_checkpoint_names = [ @@ -323,11 +258,11 @@ class CheckpointingTests(test.TestCase): "global_step:0", named_variables["optimizer_step" + suffix].name) self.assertEqual( - "my_network/checkpointable_dense_layer_1/kernel:0", - named_variables["network/via_track_layer/kernel" + suffix].name) + "my_model/dense_1/kernel:0", + named_variables["model/_second/kernel" + suffix].name) self.assertEqual( - "my_network/checkpointable_dense_layer/kernel:0", - named_variables["network/_named_dense/kernel" + suffix].name) + "my_model/dense/kernel:0", + named_variables["model/_named_dense/kernel" + suffix].name) self.assertEqual( "beta1_power:0", named_variables["optimizer/beta1_power" + suffix].name) @@ -345,107 +280,150 @@ class CheckpointingTests(test.TestCase): serialized_graph.nodes[optimizer_node.children[0].node_id] .attributes[0].full_name) self.assertEqual( - "my_network/checkpointable_dense_layer/kernel", + "my_model/dense/kernel", serialized_graph.nodes[optimizer_node.slot_variables[0] .original_variable_node_id] .attributes[0].full_name) # We strip off the :0 suffix, as variable.name-based saving does. self.assertEqual( - "my_network/checkpointable_dense_layer/kernel/Adam", + "my_model/dense/kernel/Adam", serialized_graph.nodes[optimizer_node.slot_variables[0] .slot_variable_node_id] .attributes[0].full_name) self.assertEqual( - "my_network/checkpointable_dense_layer/kernel/Adam:0", + "my_model/dense/kernel/Adam:0", optimizer.get_slot( - var=named_variables["network/_named_dense/kernel" + suffix], + var=named_variables["model/_named_dense/kernel" + suffix], name="m").name) self.assertEqual( - "network/_named_dense/kernel" + suffix, + "model/_named_dense/kernel" + suffix, serialized_graph.nodes[ optimizer_node.slot_variables[0] .original_variable_node_id].attributes[0].checkpoint_key) self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) self.assertEqual( - "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix, + "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix, serialized_graph.nodes[ optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testMoreComplexSaveableReturned(self): + v = _OwnsMirroredVariables() + checkpoint = checkpointable_utils.Checkpoint(v=v) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + save_path = checkpoint.save(prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + checkpoint.restore(save_path).assert_consumed().initialize_or_restore() + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) + self.evaluate(v.non_dep_variable.assign(44.)) + save_path = checkpoint.save(prefix) + self.evaluate(v.non_dep_variable.assign(45.)) + checkpoint.restore(save_path).assert_consumed().initialize_or_restore() + self.assertEqual(44., self.evaluate(v.non_dep_variable)) + self.assertEqual(44., self.evaluate(v.mirrored)) + + @test_util.run_in_graph_and_eager_modes() + def testMoreComplexSaveableReturnedWithGlobalName(self): + # The same object can also be saved using the name-based saver. + v = _OwnsMirroredVariables() + saver = core_saver.Saver(var_list=[v]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + with self.test_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) + @test_util.run_in_graph_and_eager_modes() def testSaveRestore(self): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Checkpoint(optimizer=optimizer, network=network) + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) - if context.in_eager_mode(): + if context.executing_eagerly(): optimizer.minimize( - lambda: network(input_value)) + lambda: model(input_value)) else: - train_op = optimizer.minimize(network(input_value)) + train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. root_checkpointable.save_counter # pylint: disable=pointless-statement - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) - m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") + self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m") self.evaluate(state_ops.assign(m_bias_slot, [1.5])) save_path = root_checkpointable.save(file_prefix=prefix) - self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) + self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.])) self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) optimizer_variables = self.evaluate(optimizer.variables()) self.evaluate(state_ops.assign(m_bias_slot, [-2.])) # Immediate restoration status = root_checkpointable.restore(save_path=save_path).assert_consumed() status.run_restore_ops() - self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) + self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1])) self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) - if context.in_graph_mode(): + if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly - on_create_network = MyNetwork() - on_create_optimizer = CheckpointableAdam(0.001) - on_create_root = Checkpoint( - optimizer=on_create_optimizer, network=on_create_network) + on_create_model = MyModel() + on_create_optimizer = adam.AdamOptimizer( + 0.001, + # Preserve beta1_power and beta2_power when appying gradients so we can + # test that they've been restored correctly. + beta1=1.0, beta2=1.0) + on_create_root = checkpointable_utils.Checkpoint( + optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration status = on_create_root.restore(save_path=save_path) - on_create_network(constant_op.constant([[3.]])) # create variables + on_create_model(constant_op.constant([[3.]])) # create variables self.assertAllEqual(1, self.evaluate(on_create_root.save_counter)) self.assertAllEqual([42.], self.evaluate( - on_create_network._named_dense.variables[1])) + on_create_model._named_dense.variables[1])) on_create_m_bias_slot = on_create_optimizer.get_slot( - on_create_network._named_dense.variables[1], "m") + on_create_model._named_dense.variables[1], "m") # Optimizer slot variables are created when the original variable is # restored. self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) self.assertAllEqual(optimizer_variables[2:], self.evaluate(on_create_optimizer.variables())) - on_create_optimizer._create_slots( - [resource_variable_ops.ResourceVariable([1.])]) + dummy_var = resource_variable_ops.ResourceVariable([1.]) + on_create_optimizer.minimize(loss=dummy_var.read_value) status.assert_consumed() beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + # TODO(allenl): Debug garbage created by this test in python3. def testDeferredRestorationUsageEager(self): """An idiomatic eager execution example.""" num_training_steps = 10 checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") for training_continuation in range(3): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Checkpoint( - optimizer=optimizer, network=network, + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + root = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) root.restore(core_saver.latest_checkpoint(checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. input_value = constant_op.constant([[3.]]) optimizer.minimize( - lambda: network(input_value), # pylint: disable=cell-var-from-loop + lambda: model(input_value), # pylint: disable=cell-var-from-loop global_step=root.optimizer_step) root.save(file_prefix=checkpoint_prefix) self.assertEqual((training_continuation + 1) * num_training_steps, @@ -459,37 +437,66 @@ class CheckpointingTests(test.TestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") for training_continuation in range(3): with ops.Graph().as_default(): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Checkpoint( - optimizer=optimizer, network=network, + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + root = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) input_value = constant_op.constant([[3.]]) train_op = optimizer.minimize( - network(input_value), + model(input_value), global_step=root.global_step) - root.save_counter # pylint: disable=pointless-statement - init_op = variables.global_variables_initializer() checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: + status = root.restore(save_path=checkpoint_path) + status.initialize_or_restore(session=session) if checkpoint_path is None: self.assertEqual(0, training_continuation) - session.run(init_op) - # Another alternative would be to run initializers automatically - # if no checkpoint is being loaded. This would make deferred - # loading a bit more useful with graph execution. + with self.assertRaises(AssertionError): + status.assert_consumed() else: - status = root.restore(save_path=checkpoint_path).assert_consumed() - status.run_restore_ops() + status.assert_consumed() for _ in range(num_training_steps): session.run(train_op) - root.save(file_prefix=checkpoint_prefix, - session=session) + root.save(file_prefix=checkpoint_prefix, session=session) self.assertEqual((training_continuation + 1) * num_training_steps, session.run(root.global_step)) self.assertEqual(training_continuation + 1, session.run(root.save_counter)) + @test_util.run_in_graph_and_eager_modes() + def testAgnosticUsage(self): + """Graph/eager agnostic usage.""" + # Does create garbage when executing eagerly due to ops.Graph() creation. + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + with ops.Graph().as_default(), self.test_session( + graph=ops.get_default_graph()), test_util.device(use_gpu=True): + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + root = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, + global_step=training_util.get_or_create_global_step()) + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + status = root.restore(save_path=checkpoint_path) + input_value = constant_op.constant([[3.]]) + train_fn = functools.partial( + optimizer.minimize, + functools.partial(model, input_value), + global_step=root.global_step) + if not context.executing_eagerly(): + train_fn = functools.partial(self.evaluate, train_fn()) + status.initialize_or_restore() + for _ in range(num_training_steps): + train_fn() + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + self.evaluate(root.global_step)) + self.assertEqual(training_continuation + 1, + self.evaluate(root.save_counter)) + def _get_checkpoint_name(self, name): root = checkpointable.Checkpointable() checkpointable_utils.add_variable( @@ -531,6 +538,35 @@ class CheckpointingTests(test.TestCase): name, = named_variables.keys() self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE") + def testAnonymousVarsInInit(self): + + class Model(training.Model): + + def __init__(self): + super(Model, self).__init__() + self.w = resource_variable_ops.ResourceVariable(0.0) + self.b = resource_variable_ops.ResourceVariable(0.0) + self.vars = [self.w, self.b] + + def call(self, x): + return x * self.w + self.b + + with context.eager_mode(): + model = Model() + optimizer = adam.AdamOptimizer(learning_rate=0.05) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + checkpoint = checkpointable_utils.Checkpoint( + model=model, optimizer=optimizer) + for _ in range(2): + checkpoint.save(checkpoint_prefix) + with backprop.GradientTape() as tape: + loss = (constant_op.constant(1.) + - model(constant_op.constant(1.))) ** 2 + grad = tape.gradient(loss, model.vars) + optimizer.apply_gradients( + [(g, v) for g, v in zip(grad, model.vars)]) + @test_util.run_in_graph_and_eager_modes() def testLateDependencyTracking(self): @@ -551,9 +587,11 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(original.dep.var, 123.)) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_path = checkpointable_utils.Saver(original).save(checkpoint_prefix) + save_path = checkpointable_utils.CheckpointableSaver( + original).save(checkpoint_prefix) load_into = LateDependencies() - status = checkpointable_utils.Saver(load_into).restore(save_path) + status = checkpointable_utils.CheckpointableSaver( + load_into).restore(save_path) with self.assertRaises(AssertionError): status.assert_consumed() load_into.add_dep() @@ -582,11 +620,12 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.)) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_path = checkpointable_utils.Saver(dep_after_var).save( + save_path = checkpointable_utils.CheckpointableSaver(dep_after_var).save( checkpoint_prefix) loaded_dep_after_var = DepAfterVar() - status = checkpointable_utils.Saver(loaded_dep_after_var).restore(save_path) + status = checkpointable_utils.CheckpointableSaver( + loaded_dep_after_var).restore(save_path) loaded_dep_after_var.add_dep() status.assert_consumed() status.run_restore_ops() @@ -599,27 +638,33 @@ class CheckpointingTests(test.TestCase): root = checkpointable.Checkpointable() root.var = checkpointable_utils.add_variable( root, name="var", initializer=0.) - optimizer = CheckpointableAdam(0.1) - if context.in_graph_mode(): + optimizer = adam.AdamOptimizer(0.1) + if context.executing_eagerly(): + optimizer.minimize(root.var.read_value) + else: train_op = optimizer.minimize(root.var) - self.evaluate(variables.global_variables_initializer()) + # Note that `optimizer` has not been added as a dependency of + # `root`. Create a one-off grouping so that slot variables for `root.var` + # get initialized too. + self.evaluate(checkpointable_utils.gather_initializers( + checkpointable_utils.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) - else: - optimizer.minimize(root.var.read_value) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = checkpointable_utils.Saver(root).save( + no_slots_path = checkpointable_utils.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = checkpointable_utils.Saver(root).save( + slots_path = checkpointable_utils.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) new_root = checkpointable.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = checkpointable_utils.Saver(new_root).restore(slots_path) - no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path) + slot_status = checkpointable_utils.CheckpointableSaver( + new_root).restore(slots_path) + no_slot_status = checkpointable_utils.CheckpointableSaver( + new_root).restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() new_root.var = checkpointable_utils.add_variable( @@ -627,11 +672,11 @@ class CheckpointingTests(test.TestCase): no_slot_status.assert_consumed() no_slot_status.run_restore_ops() self.assertEqual(12., self.evaluate(new_root.var)) - new_root.optimizer = CheckpointableAdam(0.1) + new_root.optimizer = adam.AdamOptimizer(0.1) with self.assertRaisesRegexp(AssertionError, "beta1_power"): slot_status.assert_consumed() self.assertEqual(12., self.evaluate(new_root.var)) - if context.in_eager_mode(): + if context.executing_eagerly(): # Slot variables are only created with restoring initializers when # executing eagerly. self.assertEqual(14., self.evaluate( @@ -639,7 +684,9 @@ class CheckpointingTests(test.TestCase): else: self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var), None) - if context.in_graph_mode(): + if context.executing_eagerly(): + new_root.optimizer.minimize(new_root.var.read_value) + else: train_op = new_root.optimizer.minimize(new_root.var) # The slot variable now exists; restore() didn't create it, but we should # now have a restore op for it. @@ -647,8 +694,6 @@ class CheckpointingTests(test.TestCase): self.assertEqual(14., self.evaluate( new_root.optimizer.get_slot(name="m", var=new_root.var))) self.evaluate(train_op) - else: - new_root.optimizer.minimize(new_root.var.read_value) slot_status.assert_consumed() @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -659,15 +704,17 @@ class CheckpointingTests(test.TestCase): save_root.dep.var = checkpointable_utils.add_variable( save_root.dep, name="var", initializer=0.) self.evaluate(state_ops.assign(save_root.dep.var, 12.)) - saver = checkpointable_utils.Saver(save_root) + saver = checkpointable_utils.CheckpointableSaver(save_root) first_path = saver.save(os.path.join(checkpoint_directory, "first")) self.evaluate(state_ops.assign(save_root.dep.var, 13.)) second_path = saver.save(os.path.join(checkpoint_directory, "second")) first_root = checkpointable.Checkpointable() second_root = checkpointable.Checkpointable() - first_status = checkpointable_utils.Saver(first_root).restore(first_path) - second_status = checkpointable_utils.Saver(second_root).restore(second_path) + first_status = checkpointable_utils.CheckpointableSaver( + first_root).restore(first_path) + second_status = checkpointable_utils.CheckpointableSaver( + second_root).restore(second_path) load_dep = checkpointable.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) @@ -684,8 +731,10 @@ class CheckpointingTests(test.TestCase): # determines the final value. first_root = checkpointable.Checkpointable() second_root = checkpointable.Checkpointable() - second_status = checkpointable_utils.Saver(second_root).restore(second_path) - first_status = checkpointable_utils.Saver(first_root).restore(first_path) + second_status = checkpointable_utils.CheckpointableSaver( + second_root).restore(second_path) + first_status = checkpointable_utils.CheckpointableSaver( + first_root).restore(first_path) load_dep = checkpointable.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) @@ -709,11 +758,11 @@ class CheckpointingTests(test.TestCase): save_root.dep_one.dep_three = dep_three save_root.dep_two.dep_three = dep_three checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) - self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.Saver(save_root).save( + self.evaluate(checkpointable_utils.gather_initializers(save_root)) + save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) load_root = checkpointable.Checkpointable() - checkpointable_utils.Saver(load_root).restore(save_path) + checkpointable_utils.CheckpointableSaver(load_root).restore(save_path) load_root.dep_one = checkpointable.Checkpointable() load_root.dep_two = checkpointable.Checkpointable() load_root.dep_one.dep_three = checkpointable.Checkpointable() @@ -732,8 +781,8 @@ class CheckpointingTests(test.TestCase): save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) checkpointable_utils.add_variable( save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64) - self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.Saver(save_root).save( + self.evaluate(checkpointable_utils.gather_initializers(save_root)) + save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) load_root = checkpointable.Checkpointable() load_root.dep_one = checkpointable.Checkpointable() @@ -742,7 +791,7 @@ class CheckpointingTests(test.TestCase): load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) v2 = checkpointable_utils.add_variable( load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64) - status = checkpointable_utils.Saver(load_root).restore( + status = checkpointable_utils.CheckpointableSaver(load_root).restore( save_path).assert_consumed() status.run_restore_ops() self.assertEqual(32., self.evaluate(v1)) @@ -760,14 +809,15 @@ class CheckpointingTests(test.TestCase): first, "v1", initializer=[3., 1., 4.]) second.v = checkpointable_utils.add_variable( second, "v2", initializer=[1., 1., 2., 3.]) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(first)) checkpoint_directory = self.get_temp_dir() - save_path = checkpointable_utils.Saver(first).save( + save_path = checkpointable_utils.CheckpointableSaver(first).save( os.path.join(checkpoint_directory, "ckpt")) # Test deferred loading first_load = checkpointable.Checkpointable() - status = checkpointable_utils.Saver(first_load).restore(save_path) + status = checkpointable_utils.CheckpointableSaver( + first_load).restore(save_path) second_load = checkpointable.Checkpointable() first_load.second = second_load second_load.first = first_load @@ -787,7 +837,7 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v)) self.evaluate(second_load.v.assign([2., 7., 1., 8.])) self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v)) - status = checkpointable_utils.Saver(first_load).restore( + status = checkpointable_utils.CheckpointableSaver(first_load).restore( save_path).assert_consumed() status.run_restore_ops() self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) @@ -806,14 +856,15 @@ class CheckpointingTests(test.TestCase): name="blah", initializer=0.) self.evaluate(first.var1.assign(4.)) self.evaluate(first.var2.assign(8.)) - save_path = checkpointable_utils.Saver(first).save( + save_path = checkpointable_utils.CheckpointableSaver(first).save( checkpoint_prefix) restore_graph = ops.Graph() with restore_graph.as_default(), self.test_session(restore_graph): second = checkpointable.Checkpointable() second.var2 = variable_scope.get_variable( name="blah", initializer=0.) - status = checkpointable_utils.Saver(second).restore(save_path) + status = checkpointable_utils.CheckpointableSaver( + second).restore(save_path) recreated_var1 = variable_scope.get_variable( name="outside_var", initializer=0.) status.run_restore_ops() @@ -833,15 +884,81 @@ class CheckpointingTests(test.TestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") obj = checkpointable.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) - obj.opt = CheckpointableAdam(0.1) + obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(variables.global_variables_initializer()) - saver = checkpointable_utils.Saver(obj) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + saver = checkpointable_utils.CheckpointableSaver(obj) saver.save(checkpoint_prefix) before_ops = graph.get_operations() saver.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testCheckpointCleanup(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + saver = checkpointable_utils.Checkpoint(obj=obj) + for _ in range(10): + saver.save(checkpoint_prefix) + expected_filenames = ["checkpoint"] + for checkpoint_number in range(6, 11): + expected_filenames.append("ckpt-%d.index" % (checkpoint_number,)) + expected_filenames.append( + "ckpt-%d.data-00000-of-00001" % (checkpoint_number,)) + six.assertCountEqual( + self, + expected_filenames, + os.listdir(checkpoint_directory)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testCheckpointCleanupChangingVarList(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + checkpoint = checkpointable_utils.Checkpoint(obj=obj) + looped_variables = [] + for iteration in range(10): + new_variable = resource_variable_ops.ResourceVariable(iteration) + self.evaluate(new_variable.initializer) + setattr(checkpoint, "var_%d" % iteration, new_variable) + checkpoint.save(checkpoint_prefix) + looped_variables.append(new_variable) + expected_filenames = ["checkpoint"] + # We've copied the saver each time, but checkpoint management should still + # be consistent. + for checkpoint_number in range(6, 11): + expected_filenames.append("ckpt-%d.index" % (checkpoint_number,)) + expected_filenames.append( + "ckpt-%d.data-00000-of-00001" % (checkpoint_number,)) + six.assertCountEqual( + self, + expected_filenames, + os.listdir(checkpoint_directory)) + for v in looped_variables: + self.evaluate(v.assign(314)) + checkpoint.restore(checkpoint_prefix + "-6").run_restore_ops() + self.assertEqual(314, self.evaluate(checkpoint.var_9)) + self.assertEqual(314, self.evaluate(checkpoint.var_8)) + self.assertEqual(314, self.evaluate(checkpoint.var_6)) + self.assertEqual(5, self.evaluate(checkpoint.var_5)) + self.assertEqual(1, self.evaluate(checkpoint.var_1)) + self.assertEqual(0, self.evaluate(checkpoint.var_0)) + if context.executing_eagerly(): + checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() + self.assertEqual(9, self.evaluate(checkpoint.var_9)) + self.assertEqual(8, self.evaluate(checkpoint.var_8)) + self.assertEqual(1, self.evaluate(checkpoint.var_1)) + self.assertEqual(0, self.evaluate(checkpoint.var_0)) + else: + # Restoring into modified graphs is an error while graph building. + with self.assertRaises(NotImplementedError): + checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() + def testManyRestoresGraph(self): """Restores after the first should not modify the graph.""" with context.graph_mode(): @@ -851,15 +968,263 @@ class CheckpointingTests(test.TestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") obj = checkpointable.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) - obj.opt = CheckpointableAdam(0.1) + obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(variables.global_variables_initializer()) - saver = checkpointable_utils.Saver(obj) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + saver = checkpointable_utils.CheckpointableSaver(obj) save_path = saver.save(checkpoint_prefix) saver.restore(save_path) before_ops = graph.get_operations() saver.restore(save_path) self.assertEqual(before_ops, graph.get_operations()) + def testMultipleGraphsNonSlotVariables(self): + with context.graph_mode(): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + optimizer = adam.AdamOptimizer(0.001) + # Construct a model in one graph + first_graph = ops.Graph() + first_session = session_lib.Session(graph=first_graph) + with first_graph.as_default(), first_session.as_default(): + first_variable = resource_variable_ops.ResourceVariable([1.]) + first_root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, variable=first_variable) + train_op = optimizer.minimize(first_variable.read_value) + self.evaluate(checkpointable_utils.gather_initializers( + first_root_checkpointable)) + self.evaluate(train_op) + self.evaluate(first_variable.assign([1.])) + self.evaluate(optimizer.get_slot( + var=first_variable, name="m").assign([2.])) + beta1_power, _ = optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(3.)) + + # Save and load in a second graph + second_graph = ops.Graph() + with second_graph.as_default(), session_lib.Session(graph=second_graph): + second_variable = resource_variable_ops.ResourceVariable([1.]) + second_root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, variable=second_variable) + train_op = optimizer.minimize(second_variable.read_value) + second_root_checkpointable.restore(None).initialize_or_restore() + self.evaluate(train_op) + self.evaluate(second_variable.assign([4.])) + self.evaluate(optimizer.get_slot( + var=second_variable, name="m").assign([5.])) + beta1_power, _ = optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(6.)) + save_path = second_root_checkpointable.save(checkpoint_prefix) + self.evaluate(second_variable.assign([7.])) + self.evaluate(optimizer.get_slot( + var=second_variable, name="m").assign([8.])) + beta1_power, _ = optimizer._get_beta_accumulators() + self.assertAllEqual(6., self.evaluate(beta1_power)) + status = second_root_checkpointable.restore(save_path) + status.assert_consumed().run_restore_ops() + self.assertAllEqual([4.], self.evaluate(second_variable)) + self.assertAllEqual([5.], self.evaluate(optimizer.get_slot( + var=second_variable, name="m"))) + beta1_power, _ = optimizer._get_beta_accumulators() + self.assertAllEqual(6., self.evaluate(beta1_power)) + + # Check that the first graph is unmolested + with first_graph.as_default(), first_session.as_default(): + self.assertAllEqual([1.], self.evaluate(first_variable)) + self.assertAllEqual([2.], self.evaluate(optimizer.get_slot( + var=first_variable, name="m"))) + beta1_power, _ = optimizer._get_beta_accumulators() + self.assertAllEqual(3., self.evaluate(beta1_power)) + + +class TemplateTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_checkpointable_save_restore(self): + + def _templated(): + v = variable_scope.get_variable( + "v", shape=[1], initializer=init_ops.zeros_initializer()) + v2 = variable_scope.get_variable( + "v2", shape=[1], initializer=init_ops.zeros_initializer()) + return v, v + 1., v2 + + save_template = template.make_template("s1", _templated) + save_root = checkpointable_utils.Checkpoint(my_template=save_template) + v1_save, _, v2_save = save_template() + self.evaluate(v1_save.assign([12.])) + self.evaluate(v2_save.assign([14.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_root.save(checkpoint_prefix) + + load_template = template.make_template("s2", _templated) + load_root = checkpointable_utils.Checkpoint(my_template=load_template) + status = load_root.restore(save_path) + var, var_plus_one, var2 = load_template() + self.assertEqual(2, len(load_template._checkpoint_dependencies)) + self.assertEqual("v", load_template._checkpoint_dependencies[0].name) + self.assertEqual("v2", load_template._checkpoint_dependencies[1].name) + status.assert_consumed().run_restore_ops() + self.assertAllEqual([12.], self.evaluate(var)) + self.assertAllEqual([13.], self.evaluate(var_plus_one)) + self.assertAllEqual([14.], self.evaluate(var2)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_checkpointable_save_restore_nested(self): + + def _inner_template(): + v = variable_scope.get_variable( + "v", shape=[1], initializer=init_ops.zeros_initializer()) + return v + + def _outer_template(): + first_inner = template.make_template("i1", _inner_template) + second_inner = template.make_template("i2", _inner_template) + v1 = first_inner() + v2 = second_inner() + v3 = second_inner() + return (first_inner, second_inner), (v1, v2, v3) + + with variable_scope.variable_scope("ignored"): + save_template = template.make_template("s1", _outer_template) + save_root = checkpointable_utils.Checkpoint(my_template=save_template) + (inner_template_one, inner_template_two), _ = save_template() + self.evaluate(inner_template_one.variables[0].assign([20.])) + self.evaluate(inner_template_two.variables[0].assign([25.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_root.save(checkpoint_prefix) + + load_template = template.make_template("s2", _outer_template) + load_root = checkpointable_utils.Checkpoint(my_template=load_template) + status = load_root.restore(save_path) + (inner_template_one, inner_template_two), (v1, v2, v3) = load_template() + outer_template_dependencies = load_root.my_template._checkpoint_dependencies + self.assertEqual(2, len(outer_template_dependencies)) + self.assertEqual("i1", outer_template_dependencies[0].name) + self.assertIs(inner_template_one, outer_template_dependencies[0].ref) + self.assertEqual("i2", outer_template_dependencies[1].name) + self.assertIs(inner_template_two, outer_template_dependencies[1].ref) + self.assertEqual(1, len(inner_template_one._checkpoint_dependencies)) + self.assertEqual("v", inner_template_one._checkpoint_dependencies[0].name) + self.assertEqual(1, len(inner_template_two._checkpoint_dependencies)) + self.assertEqual("v", inner_template_two._checkpoint_dependencies[0].name) + status.assert_consumed().run_restore_ops() + self.assertAllEqual([20.], self.evaluate(v1)) + self.assertAllEqual([25.], self.evaluate(v2)) + self.assertAllEqual([25.], self.evaluate(v3)) + + +class CheckpointCompatibilityTests(test.TestCase): + + def _initialized_model(self): + input_value = constant_op.constant([[3.]]) + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + optimizer_step = training_util.get_or_create_global_step() + root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=optimizer_step) + train_op = optimizer.minimize( + functools.partial(model, input_value), + global_step=optimizer_step) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + # A regular variable, a slot variable, and a non-slot Optimizer variable + # with known values to check when loading. + self.evaluate(model._named_dense.bias.assign([1.])) + self.evaluate(optimizer.get_slot( + var=model._named_dense.bias, name="m").assign([2.])) + beta1_power, _ = optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(3.)) + return root_checkpointable + + def _set_sentinels(self, root_checkpointable): + self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.])) + self.evaluate( + root_checkpointable.optimizer.get_slot( + var=root_checkpointable.model._named_dense.bias, name="m") + .assign([102.])) + beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(103.)) + + def _check_sentinels(self, root_checkpointable): + self.assertAllEqual( + [1.], self.evaluate(root_checkpointable.model._named_dense.bias)) + self.assertAllEqual([2.], self.evaluate( + root_checkpointable.optimizer.get_slot( + var=root_checkpointable.model._named_dense.bias, name="m"))) + beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + self.assertAllEqual(3., self.evaluate(beta1_power)) + + def _write_name_based_checkpoint(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with context.graph_mode(): + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph) as session: + root = self._initialized_model() + name_saver = core_saver.Saver() + return name_saver.save( + sess=session, save_path=checkpoint_prefix, + global_step=root.optimizer_step) + + @test_util.run_in_graph_and_eager_modes() + def testLoadFromNameBasedSaver(self): + """Save a name-based checkpoint, load it using the object-based API.""" + with test_util.device(use_gpu=True): + save_path = self._write_name_based_checkpoint() + root = self._initialized_model() + self._set_sentinels(root) + with self.assertRaises(AssertionError): + self._check_sentinels(root) + object_saver = checkpointable_utils.CheckpointableSaver(root) + status = object_saver.restore(save_path) + with self.assertRaises(AssertionError): + status.assert_consumed() + status.run_restore_ops() + self._check_sentinels(root) + self._set_sentinels(root) + status.initialize_or_restore() + self._check_sentinels(root) + + # TODO(allenl): Test for the core name-based saver loading object-based + # checkpoints once object-based checkpointing is in core. + + def testSaveGraphLoadEager(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with context.graph_mode(): + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph) as session: + root = self._initialized_model() + object_saver = checkpointable_utils.CheckpointableSaver(root) + save_path = object_saver.save( + session=session, file_prefix=checkpoint_prefix) + with context.eager_mode(): + root = self._initialized_model() + self._set_sentinels(root) + root.restore(save_path).assert_consumed() + self._check_sentinels(root) + + def testSaveEagerLoadGraph(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with context.eager_mode(): + root = self._initialized_model() + object_saver = checkpointable_utils.CheckpointableSaver(root) + save_path = object_saver.save(file_prefix=checkpoint_prefix) + with context.graph_mode(): + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph): + root = self._initialized_model() + self._set_sentinels(root) + root.restore(save_path).assert_consumed().run_restore_ops() + self._check_sentinels(root) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index d177bfeab2d1fdc05d7ced54df8723fae2c77fdb..a4c3283dac9194880a1297371ea7591af6dddb2b 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -27,11 +27,12 @@ from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 _uid_lock = threading.Lock() @@ -45,8 +46,13 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) -class Iterator(object): - """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" +class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): + """An iterator producing tf.Tensor objects from a tf.data.Dataset. + + NOTE: Unlike the iterator created by the + @{tf.data.Dataset.make_one_shot_iterator} method, this class enables + additional experimental functionality, such as prefetching to the GPU. + """ def __init__(self, dataset): """Creates a new iterator over the given dataset. @@ -67,37 +73,12 @@ class Iterator(object): Raises: RuntimeError: When invoked without eager execution enabled. """ - - if not context.in_eager_mode(): - raise RuntimeError( - "{} objects can only be used when eager execution is enabled, use " - "tf.data.Dataset.make_iterator or " - "tf.data.Dataset.make_one_shot_iterator for graph construction". - format(type(self))) - with ops.device("/device:CPU:0"): - ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._output_types, self._output_classes)) - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.iterator( - shared_name="", - container=_generate_shared_name("eageriterator"), - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device="/device:CPU:0") - self._device = context.context().device_name - self._buffer_resource_handle = None + super(Iterator, self).__init__(dataset) if not context.context().device_spec.device_type: is_remote_device = False else: is_remote_device = context.context().device_spec.device_type != "CPU" + self._buffer_resource_handle = None if is_remote_device: with ops.device("/device:CPU:0"): iter_string_handle = gen_dataset_ops.iterator_to_string_handle( @@ -106,7 +87,7 @@ class Iterator(object): @function.Defun(dtypes.string) def remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, self._output_types, self._output_shapes) + h, self.output_types, self.output_shapes, self.output_classes) return remote_iterator.get_next() remote_fn.add_to_graph(None) @@ -124,89 +105,43 @@ class Iterator(object): handle=self._buffer_resource_handle, handle_device=self._device) - def __iter__(self): - return self - - def __next__(self): # For Python 3 compatibility - return self.next() - def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ - with ops.device(self._device): - if self._buffer_resource_handle is not None: + if self._buffer_resource_handle is not None: + with ops.device(self._device): ret = prefetching_ops.function_buffering_resource_get_next( function_buffer_resource=self._buffer_resource_handle, output_types=self._flat_output_types) - else: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` - # because in eager mode this code will run synchronously on the calling - # thread. Therefore we do not need to make a defensive context switch - # to a background thread, and can achieve a small constant performance - # boost by invoking the iterator synchronously. - ret = gen_dataset_ops.iterator_get_next_sync( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - - def next(self): - """Returns a nested structure of `tf.Tensor`s containing the next element. - """ - try: - return self._next_internal() - except errors.OutOfRangeError: - raise StopIteration - - @property - def output_classes(self): - """Returns the class of each component of an element of this iterator. - - The expected values are `tf.Tensor` and `tf.SparseTensor`. - - Returns: - A nested structure of Python `type` objects corresponding to each - component of an element of this dataset. - """ - return self._output_classes - - @property - def output_shapes(self): - """Returns the shape of each component of an element of this iterator. + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) + else: + return super(Iterator, self)._next_internal() - Returns: - A nested structure of `tf.TensorShape` objects corresponding to each - component of an element of this dataset. - """ - return self._output_shapes + # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset + # attributes(potential). - @property - def output_types(self): - """Returns the type of each component of an element of this iterator. + class _Saveable(BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" - Returns: - A nested structure of `tf.DType` objects corresponding to each component - of an element of this dataset. - """ - return self._output_types + def __init__(self, iterator_resource, name): + serialized_iterator = gen_dataset_ops.serialize_iterator( + iterator_resource) + specs = [ + BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") + ] + # pylint: disable=protected-access + super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - def get_next(self, name=None): - """Returns a nested structure of `tf.Tensor`s containing the next element. + def restore(self, restored_tensors, restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, + restored_tensors[0]) - Args: - name: (Optional.) A name for the created operation. Currently unused. + def _gather_saveables_for_checkpoint(self): - Returns: - A nested structure of `tf.Tensor` objects. + def _saveable_factory(name): + return self._Saveable(self._resource, name) - Raises: - `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. - """ - del name - return self._next_internal() + return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 35c3c5d3fad0a84bbe4d24c7bb17878583bded4b..c658505de41bb6a0007440f4850fef720c3e97f1 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import threading import time @@ -24,6 +26,7 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.data.python.ops import threadpool from tensorflow.contrib.data.python.ops import unique +from tensorflow.contrib.eager.python import checkpointable_utils from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test @@ -44,6 +47,18 @@ class IteratorTest(test.TestCase): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) + def testBasicOneShotIterator(self): + got = [] + for t in Dataset.range(4).make_one_shot_iterator(): + got.append(t.numpy()) + self.assertAllEqual([0, 1, 2, 3], got) + + def testBasicImplicitIterator(self): + got = [] + for t in Dataset.range(4): + got.append(t.numpy()) + self.assertAllEqual([0, 1, 2, 3], got) + def testGetNext(self): iterator = datasets.Iterator(Dataset.range(4)) self.assertEqual(0, iterator.get_next().numpy()) @@ -53,6 +68,15 @@ class IteratorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): iterator.get_next() + def testGetNextOneShotIterator(self): + iterator = Dataset.range(4).make_one_shot_iterator() + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + self.assertEqual(2, iterator.get_next().numpy()) + self.assertEqual(3, iterator.get_next().numpy()) + with self.assertRaises(errors.OutOfRangeError): + iterator.get_next() + def testMultipleIteratorsOnTheSameDataset(self): ds = Dataset.range(4) it1 = datasets.Iterator(ds) @@ -200,6 +224,61 @@ class IteratorTest(test.TestCase): # perform work. self.assertLessEqual(len(thread_ids), num_threads) + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertAllEqual([1, 4], iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + + def testSaveRestoreMultipleIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator_1 = datasets.Iterator(dataset) + iterator_2 = datasets.Iterator(dataset) + dataset_2 = Dataset.range(10) + iterator_3 = datasets.Iterator(dataset_2) + + checkpoint = checkpointable_utils.Checkpoint( + iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) + self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) + self.assertEqual(0, iterator_3.get_next().numpy()) + self.assertEqual(1, iterator_3.get_next().numpy()) + self.assertEqual(2, iterator_3.get_next().numpy()) + + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertAllEqual([9, 16], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator_1.get_next().numpy()) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + + def testRestoreExhaustedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(3) + iterator = datasets.Iterator(dataset) + + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertEqual(2, iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertEqual(2, iterator.get_next().numpy()) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 68e7b5421fec7f73f10e381ca45f9d900de299d7..37c8f0d47adbde6932bf409cdcae9a1845d700b5 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -57,7 +57,7 @@ class Evaluator(object): self._model = model self._metrics = {} self._evaluators = {} - if context.in_graph_mode(): + if not context.executing_eagerly(): self.call = function.defun(self.call) # ---- API for users ---- @@ -90,7 +90,7 @@ class Evaluator(object): Only for graph execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Evaluator.init_variables() not needed when " "eager execution is enabled.") return control_flow_ops.group([m.init_variables() for _, m in self.metrics]) @@ -113,7 +113,8 @@ class Evaluator(object): with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() - if context.in_eager_mode(): + + if context.executing_eagerly(): return f() else: return function.defun(f)() @@ -158,16 +159,16 @@ class Evaluator(object): @end_compatibility """ summary_logdir = kwargs.pop("summary_logdir", None) - if context.in_graph_mode(): - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), - *args, **kwargs) - init_op = self.init_variables() - results_op = self.all_metric_results(summary_logdir) - return (init_op, call_op, results_op) - # Eager case - for example in datasets.Iterator(dataset): - self.__call__(example, *args, **kwargs) - return self.all_metric_results(summary_logdir) + if context.executing_eagerly(): + for example in datasets.Iterator(dataset): + self.__call__(example, *args, **kwargs) + return self.all_metric_results(summary_logdir) + # Graph construction + call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, + **kwargs) + init_op = self.init_variables() + results_op = self.all_metric_results(summary_logdir) + return (init_op, call_op, results_op) @staticmethod def run_evaluation(init_op, call_op, results_op, sess=None): @@ -192,7 +193,7 @@ class Evaluator(object): Only for graph execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Evaluator.run_evaluation() not supported when " "eager execution is enabled.") sess = sess or ops.get_default_session() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b9ac79f46c83bb709918e3b72830b90ddcfd71b4..b80c90902353709b7f739585291ec3b5890c27c7 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -32,10 +32,11 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data +layers = tf.keras.layers FLAGS = None -class Discriminator(tfe.Network): +class Discriminator(tf.keras.Model): """GAN Discriminator. A network to differentiate between generated and real handwritten digits. @@ -56,19 +57,15 @@ class Discriminator(tfe.Network): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', - data_format=data_format, - activation=tf.tanh)) - self.pool1 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, - data_format=data_format, - activation=tf.tanh)) - self.pool2 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.flatten = self.track_layer(tf.layers.Flatten()) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) - self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + self.conv1 = layers.Conv2D( + 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) + self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.conv2 = layers.Conv2D( + 128, 5, data_format=data_format, activation=tf.tanh) + self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.flatten = layers.Flatten() + self.fc1 = layers.Dense(1024, activation=tf.tanh) + self.fc2 = layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. @@ -95,7 +92,7 @@ class Discriminator(tfe.Network): return x -class Generator(tfe.Network): +class Generator(tf.keras.Model): """Generator of handwritten digits similar to the ones in the MNIST dataset. """ @@ -116,18 +113,17 @@ class Generator(tfe.Network): else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] - self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, - activation=tf.tanh)) + self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) - self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( - 64, 4, strides=2, activation=None, data_format=data_format)) + self.conv1 = layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) - self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( - 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + self.conv2 = layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): """Return a batch of generated images. @@ -168,7 +164,8 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): """ loss_on_real = tf.losses.sigmoid_cross_entropy( - tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + tf.ones_like(discriminator_real_outputs), + discriminator_real_outputs, label_smoothing=0.25) loss_on_generated = tf.losses.sigmoid_cross_entropy( tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) @@ -198,9 +195,9 @@ def generator_loss(discriminator_gen_outputs): return loss -def train_one_epoch(generator, discriminator, - generator_optimizer, discriminator_optimizer, - dataset, log_interval, noise_dim): +def train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, dataset, step_counter, + log_interval, noise_dim): """Trains `generator` and `discriminator` models on `dataset`. Args: @@ -209,7 +206,8 @@ def train_one_epoch(generator, discriminator, generator_optimizer: Optimizer to use for generator. discriminator_optimizer: Optimizer to use for discriminator. dataset: Dataset of images to train on. - log_interval: How many global steps to wait between logging and collecting + step_counter: An integer variable, used to write summaries regularly. + log_interval: How many steps to wait between logging and collecting summaries. noise_dim: Dimension of noise vector to use. """ @@ -218,18 +216,23 @@ def train_one_epoch(generator, discriminator, total_discriminator_loss = 0.0 for (batch_index, images) in enumerate(tfe.Iterator(dataset)): with tf.device('/cpu:0'): - tf.assign_add(tf.train.get_global_step(), 1) + tf.assign_add(step_counter, 1) - with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + with tf.contrib.summary.record_summaries_every_n_global_steps( + log_interval, global_step=step_counter): current_batch_size = images.shape[0] - noise = tf.random_uniform(shape=[current_batch_size, noise_dim], - minval=-1., maxval=1., seed=batch_index) + noise = tf.random_uniform( + shape=[current_batch_size, noise_dim], + minval=-1., + maxval=1., + seed=batch_index) with tfe.GradientTape(persistent=True) as g: generated_images = generator(noise) - tf.contrib.summary.image('generated_images', - tf.reshape(generated_images, [-1, 28, 28, 1]), - max_images=10) + tf.contrib.summary.image( + 'generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) discriminator_gen_outputs = discriminator(generated_images) discriminator_real_outputs = discriminator(images) @@ -244,18 +247,16 @@ def train_one_epoch(generator, discriminator, discriminator_grad = g.gradient(discriminator_loss_val, discriminator.variables) - with tf.variable_scope('generator'): - generator_optimizer.apply_gradients(zip(generator_grad, - generator.variables)) - with tf.variable_scope('discriminator'): - discriminator_optimizer.apply_gradients(zip(discriminator_grad, - discriminator.variables)) + generator_optimizer.apply_gradients( + zip(generator_grad, generator.variables)) + discriminator_optimizer.apply_gradients( + zip(discriminator_grad, discriminator.variables)) if log_interval and batch_index > 0 and batch_index % log_interval == 0: print('Batch #%d\tAverage Generator Loss: %.6f\t' - 'Average Discriminator Loss: %.6f' % ( - batch_index, total_generator_loss/batch_index, - total_discriminator_loss/batch_index)) + 'Average Discriminator Loss: %.6f' % + (batch_index, total_generator_loss / batch_index, + total_discriminator_loss / batch_index)) def main(_): @@ -266,18 +267,18 @@ def main(_): # Load the datasets data = input_data.read_data_sets(FLAGS.data_dir) - dataset = (tf.data.Dataset - .from_tensor_slices(data.train.images) - .shuffle(60000) - .batch(FLAGS.batch_size)) - - # Create the models and optimizers - generator = Generator(data_format) - discriminator = Discriminator(data_format) - with tf.variable_scope('generator'): - generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) - with tf.variable_scope('discriminator'): - discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + dataset = ( + tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000) + .batch(FLAGS.batch_size)) + + # Create the models and optimizers. + model_objects = { + 'generator': Generator(data_format), + 'discriminator': Discriminator(data_format), + 'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), + 'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), + 'step_counter': tf.train.get_or_create_global_step(), + } # Prepare summary writer and checkpoint info summary_writer = tf.contrib.summary.create_summary_file_writer( @@ -286,28 +287,22 @@ def main(_): latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) + checkpoint = tfe.Checkpoint(**model_objects) + # Restore variables on creation if a checkpoint exists. + checkpoint.restore(latest_cpkt) with tf.device(device): - for epoch in range(1, 101): - with tfe.restore_variables_on_create(latest_cpkt): - global_step = tf.train.get_or_create_global_step() - start = time.time() - with summary_writer.as_default(): - train_one_epoch(generator, discriminator, generator_optimizer, - discriminator_optimizer, - dataset, FLAGS.log_interval, FLAGS.noise) - end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) - - all_variables = ( - generator.variables - + discriminator.variables - + generator_optimizer.variables() - + discriminator_optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) + for _ in range(100): + start = time.time() + with summary_writer.as_default(): + train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval, + noise_dim=FLAGS.noise, **model_objects) + end = time.time() + checkpoint.save(checkpoint_prefix) + print('\nTrain time for epoch #%d (step %d): %f' % + (checkpoint.save_counter.numpy(), + checkpoint.step_counter.numpy(), + end - start)) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index 4a3ca8d82bc2619b05a734f6d2e58431c1a45995..bd35e50c1f434d167c5a8c5aa7d224912523ce28 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -62,7 +62,7 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): for _ in range(measure_batches)] measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images) - tf.train.get_or_create_global_step() + step_counter = tf.train.get_or_create_global_step() with tf.device(device()): # Create the models and optimizers generator = mnist.Generator(data_format()) @@ -78,13 +78,15 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): # warm up mnist.train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, - burn_dataset, log_interval=SUMMARY_INTERVAL, + burn_dataset, step_counter, + log_interval=SUMMARY_INTERVAL, noise_dim=NOISE_DIM) # measure start = time.time() mnist.train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, - measure_dataset, log_interval=SUMMARY_INTERVAL, + measure_dataset, step_counter, + log_interval=SUMMARY_INTERVAL, noise_dim=NOISE_DIM) self._report('train', start, measure_batches, batch_size) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 6ce4de6ee0bf50400eff339ac04e132252a2b53e..4e1380afb2e6e722de65c691d4fbf44621072e87 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -32,24 +32,16 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe +layers = tf.keras.layers -class LinearModel(tfe.Network): - """A TensorFlow linear regression model. - Uses TensorFlow's eager execution. - - For those familiar with TensorFlow graphs, notice the absence of - `tf.Session`. The `forward()` method here immediately executes and - returns output values. The `loss()` method immediately compares the - output of `forward()` with the target and returns the MSE loss value. - The `fit()` performs gradient-descent training on the model's weights - and bias. - """ +class LinearModel(tf.keras.Model): + """A TensorFlow linear regression model.""" def __init__(self): """Constructs a LinearModel object.""" super(LinearModel, self).__init__() - self._hidden_layer = self.track_layer(tf.layers.Dense(1)) + self._hidden_layer = layers.Dense(1) def call(self, xs): """Invoke the linear model. @@ -64,7 +56,7 @@ class LinearModel(tfe.Network): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(model(xs) - ys)) + return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) def fit(model, dataset, optimizer, verbose=False, logdir=None): diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index 9982fdb07eefa665379e7be095f4f8017d92cf97..a28bc8a43d7c90737c9baf9a634d736e9de52948 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -27,10 +27,11 @@ from __future__ import print_function import functools import tensorflow as tf -import tensorflow.contrib.eager as tfe +layers = tf.keras.layers -class _IdentityBlock(tfe.Network): + +class _IdentityBlock(tf.keras.Model): """_IdentityBlock is the block that has no conv layer at shortcut. Args: @@ -50,31 +51,24 @@ class _IdentityBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - data_format=data_format, - name=conv_name_base + '2b')) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) + self.conv2a = layers.Conv2D( + filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format) + self.bn2a = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = layers.Conv2D( + filters2, + kernel_size, + padding='same', + data_format=data_format, + name=conv_name_base + '2b') + self.bn2b = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -92,7 +86,7 @@ class _IdentityBlock(tfe.Network): return tf.nn.relu(x) -class _ConvBlock(tfe.Network): +class _ConvBlock(tf.keras.Model): """_ConvBlock is the block that has a conv layer at shortcut. Args: @@ -121,41 +115,35 @@ class _ConvBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - strides=strides, - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - name=conv_name_base + '2b', - data_format=data_format)) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) - - self.conv_shortcut = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - strides=strides, - name=conv_name_base + '1', - data_format=data_format)) - self.bn_shortcut = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')) + self.conv2a = layers.Conv2D( + filters1, (1, 1), + strides=strides, + name=conv_name_base + '2a', + data_format=data_format) + self.bn2a = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = layers.Conv2D( + filters2, + kernel_size, + padding='same', + name=conv_name_base + '2b', + data_format=data_format) + self.bn2b = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') + + self.conv_shortcut = layers.Conv2D( + filters3, (1, 1), + strides=strides, + name=conv_name_base + '1', + data_format=data_format) + self.bn_shortcut = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '1') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -176,7 +164,8 @@ class _ConvBlock(tfe.Network): return tf.nn.relu(x) -class ResNet50(tfe.Network): +# pylint: disable=not-callable +class ResNet50(tf.keras.Model): """Instantiates the ResNet50 architecture. Args: @@ -220,32 +209,28 @@ class ResNet50(tfe.Network): self.include_top = include_top def conv_block(filters, stage, block, strides=(2, 2)): - l = _ConvBlock( + return _ConvBlock( 3, filters, stage=stage, block=block, data_format=data_format, strides=strides) - return self.track_layer(l) def id_block(filters, stage, block): - l = _IdentityBlock( + return _IdentityBlock( 3, filters, stage=stage, block=block, data_format=data_format) - return self.track_layer(l) - - self.conv1 = self.track_layer( - tf.layers.Conv2D( - 64, (7, 7), - strides=(2, 2), - data_format=data_format, - padding='same', - name='conv1')) + + self.conv1 = layers.Conv2D( + 64, (7, 7), + strides=(2, 2), + data_format=data_format, + padding='same', + name='conv1') bn_axis = 1 if data_format == 'channels_first' else 3 - self.bn_conv1 = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')) - self.max_pool = self.track_layer( - tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format)) + self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1') + self.max_pool = layers.MaxPooling2D( + (3, 3), strides=(2, 2), data_format=data_format) self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) self.l2b = id_block([64, 64, 256], stage=2, block='b') @@ -267,13 +252,12 @@ class ResNet50(tfe.Network): self.l5b = id_block([512, 512, 2048], stage=5, block='b') self.l5c = id_block([512, 512, 2048], stage=5, block='c') - self.avg_pool = self.track_layer( - tf.layers.AveragePooling2D( - (7, 7), strides=(7, 7), data_format=data_format)) + self.avg_pool = layers.AveragePooling2D( + (7, 7), strides=(7, 7), data_format=data_format) if self.include_top: - self.fc1000 = self.track_layer( - tf.layers.Dense(classes, name='fc1000')) + self.flatten = layers.Flatten() + self.fc1000 = layers.Dense(classes, name='fc1000') else: reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] reduction_indices = tf.constant(reduction_indices) @@ -288,7 +272,7 @@ class ResNet50(tfe.Network): else: self.global_pooling = None - def call(self, input_tensor, training=False): + def call(self, input_tensor, training): x = self.conv1(input_tensor) x = self.bn_conv1(x, training=training) x = tf.nn.relu(x) @@ -317,7 +301,7 @@ class ResNet50(tfe.Network): x = self.avg_pool(x) if self.include_top: - return self.fc1000(tf.layers.flatten(x)) + return self.fc1000(self.flatten(x)) elif self.global_pooling: return self.global_pooling(x) else: diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 23317886e712323f4b520000e0fd372734fc53a1..551c76b0df71c88919df9cd6d81b4176b23b0ba3 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -55,7 +55,7 @@ class ResNet50GraphTest(tf.test.TestCase): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() @@ -114,7 +114,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 0ff8746884c288f824f5f22ab4c550370d0e0302..d6923293a374f29ab77be70fa9fea44efd1ea40b 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -64,28 +64,35 @@ def train_one_step(model, images, labels, optimizer): class ResNet50Test(tf.test.TestCase): - def _apply(self, defun=False): + def _apply(self, defun=False, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) + tfe.async_wait() self.assertEqual((2, 1000), output.shape) def test_apply(self): self._apply(defun=False) + def test_apply_async(self): + self._apply(defun=False, execution_mode=tfe.ASYNC) + def test_apply_with_defun(self): self._apply(defun=True) + def test_apply_with_defun_async(self): + self._apply(defun=True, execution_mode=tfe.ASYNC) + def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) @@ -95,10 +102,10 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) self.assertEqual((2, 2048), output.shape) - def test_train(self): + def _test_train(self, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() @@ -106,15 +113,22 @@ class ResNet50Test(tf.test.TestCase): with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) + tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') + def test_train(self): + self._test_train() + + def test_train_async(self): + self._test_train(execution_mode=tfe.ASYNC) + def test_no_garbage(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) @@ -183,59 +197,84 @@ class ResNet50Benchmarks(tf.test.Benchmark): # a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False): - device, data_format = device_and_data_format() - model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) - batch_size = 64 - num_burn = 5 - num_iters = 30 - with tf.device(device): - images, _ = random_batch(batch_size) - for _ in xrange(num_burn): - model(images).cpu() - gc.collect() - start = time.time() - for _ in xrange(num_iters): - model(images).cpu() - self._report(label, start, num_iters, device, batch_size, data_format) - - def benchmark_eager_apply(self): - self._benchmark_eager_apply('eager_apply', defun=False) - - def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) - - def _benchmark_eager_train(self, label, make_iterator, defun=False): - device, data_format = device_and_data_format() - for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size) - num_burn = 3 - num_iters = 10 + def _benchmark_eager_apply(self, label, defun=False, execution_mode=None): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(0.1) - + batch_size = 64 + num_burn = 5 + num_iters = 30 with tf.device(device): - iterator = make_iterator((images, labels)) + images, _ = random_batch(batch_size) for _ in xrange(num_burn): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() gc.collect() - start = time.time() for _ in xrange(num_iters): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() self._report(label, start, num_iters, device, batch_size, data_format) + def benchmark_eager_apply(self): + self._benchmark_eager_apply('eager_apply', defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + defun=False, + execution_mode=None): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_data_format() + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size) + num_burn = 3 + num_iters = 10 + model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) + optimizer = tf.train.GradientDescentOptimizer(0.1) + + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_gpu_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_gpu_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + def benchmark_eager_train(self): self._benchmark_eager_train('eager_train', MockIterator, defun=False) + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + 'eager_train_async', + MockIterator, + defun=False, + execution_mode=tfe.ASYNC) + def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( 'eager_train_with_defun', MockIterator, defun=True) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index aa87b94e7b0876e65405f6bcb2d6aabde36582bf..492adbe1d80941f9df96d6636e4933d11239408e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -60,6 +60,7 @@ import functools import os import sys import time +import urllib import six import tensorflow as tf @@ -72,6 +73,8 @@ try: except ImportError: HAS_MATPLOTLIB = False +layers = tf.keras.layers + def parse(line): """Parse a line from the colors dataset.""" @@ -89,13 +92,35 @@ def parse(line): return rgb, chars, length +def maybe_download(filename, work_directory, source_url): + """Download the data from source url, unless it's already here. + + Args: + filename: string, name of the file in the directory. + work_directory: string, path to working directory. + source_url: url to download from if file doesn't exist. + + Returns: + Path to resulting file. + """ + if not tf.gfile.Exists(work_directory): + tf.gfile.MakeDirs(work_directory) + filepath = os.path.join(work_directory, filename) + if not tf.gfile.Exists(filepath): + temp_file_name, _ = urllib.request.urlretrieve(source_url) + tf.gfile.Copy(temp_file_name, filepath) + with tf.gfile.GFile(filepath) as f: + size = f.size() + print("Successfully downloaded", filename, size, "bytes.") + return filepath + + def load_dataset(data_dir, url, batch_size): """Loads the colors data at path into a PaddedDataset.""" # Downloads data at url into data_dir/basename(url). The dataset has a header # row (color_name, r, g, b) followed by comma-separated lines. - path = tf.contrib.learn.datasets.base.maybe_download( - os.path.basename(url), data_dir, url) + path = maybe_download(os.path.basename(url), data_dir, url) # This chain of commands loads our data by: # 1. skipping the header; (.skip(1)) @@ -109,7 +134,7 @@ def load_dataset(data_dir, url, batch_size): # pylint: disable=not-callable -class RNNColorbot(tfe.Network): +class RNNColorbot(tf.keras.Model): """Multi-layer (LSTM) RNN that regresses on real-valued vector labels. """ @@ -127,23 +152,20 @@ class RNNColorbot(tfe.Network): self.label_dimension = label_dimension self.keep_prob = keep_prob - # Note the calls to `track_layer` below; these calls register the layers as - # network components that house trainable variables. - self.cells = [ - self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size)) - for size in rnn_cell_sizes - ] - self.relu = self.track_layer( - tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu")) + self.cells = self._add_cells( + [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) + self.relu = layers.Dense( + label_dimension, activation=tf.nn.relu, name="relu") - def call(self, chars, sequence_length, training=False): + def call(self, inputs, training=False): """Implements the RNN logic and prediction generation. Args: - chars: a Tensor of dimension [batch_size, time_steps, 256] holding a - batch of one-hot encoded color names - sequence_length: a Tensor of dimension [batch_size] holding the length - of each character sequence (i.e., color name) + inputs: A tuple (chars, sequence_length), where chars is a batch of + one-hot encoded color names represented as a Tensor with dimensions + [batch_size, time_steps, 256] and sequence_length holds the length + of each character sequence (color name) as a Tensor with dimension + [batch_size]. training: whether the invocation is happening during training Returns: @@ -151,6 +173,7 @@ class RNNColorbot(tfe.Network): passing chars through a multi-layer RNN and applying a ReLU to the final hidden state. """ + (chars, sequence_length) = inputs # Transpose the first and second dimensions so that chars is of shape # [time_steps, batch_size, dimension]. chars = tf.transpose(chars, [1, 0, 2]) @@ -181,6 +204,14 @@ class RNNColorbot(tfe.Network): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) + def _add_cells(self, cells): + # "Magic" required for keras.Model classes to track all the variables in + # a list of layers.Layer objects. + # TODO(ashankar): Figure out API so user code doesn't have to do this. + for i, c in enumerate(cells): + setattr(self, "cell-%d" % i, c) + return cells + def loss(labels, predictions): """Computes mean squared loss.""" @@ -191,7 +222,7 @@ def test(model, eval_data): """Computes the average loss on eval_data, which should be a Dataset.""" avg_loss = tfe.metrics.Mean("loss") for (labels, chars, sequence_length) in tfe.Iterator(eval_data): - predictions = model(chars, sequence_length, training=False) + predictions = model((chars, sequence_length), training=False) avg_loss(loss(labels, predictions)) print("eval/loss: %.6f\n" % avg_loss.result()) with tf.contrib.summary.always_record_summaries(): @@ -204,7 +235,7 @@ def train_one_epoch(model, optimizer, train_data, log_interval=10): tf.train.get_or_create_global_step() def model_loss(labels, chars, sequence_length): - predictions = model(chars, sequence_length, training=True) + predictions = model((chars, sequence_length), training=True) loss_value = loss(labels, predictions) tf.contrib.summary.scalar("loss", loss_value) return loss_value @@ -277,7 +308,7 @@ def main(_): (chars, length) = (tf.identity(chars), tf.identity(length)) chars = tf.expand_dims(chars, 0) length = tf.expand_dims(length, 0) - preds = tf.unstack(model(chars, length, training=False)[0]) + preds = tf.unstack(model((chars, length), training=False)[0]) # Predictions cannot be negative, as they are generated by a ReLU layer; # they may, however, be greater than 1. diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 5c5c59c87744f4ffa6db90e5d8d3aa3bc8132756..a90048d813bf345e8be32e9674a452175471b268 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -38,22 +38,26 @@ import tensorflow as tf from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.eager.python import tfe +layers = tf.keras.layers -class RNN(tfe.Network): + +class RNN(tf.keras.Model): """A static RNN. - Similar to tf.nn.static_rnn, implemented as a tf.layer.Layer. + Similar to tf.nn.static_rnn, implemented as a class. """ def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - for _ in range(num_layers): - self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)) + self.cells = self._add_cells([ + tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) + for _ in range(num_layers) + ]) def call(self, input_seq, training): batch_size = int(input_seq.shape[1]) - for c in self.layers: + for c in self.cells: state = c.zero_state(batch_size, tf.float32) outputs = [] input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0) @@ -64,10 +68,22 @@ class RNN(tfe.Network): input_seq = tf.stack(outputs, axis=0) if training: input_seq = tf.nn.dropout(input_seq, self.keep_ratio) - return input_seq, None - - -class Embedding(tf.layers.Layer): + # Returning a list instead of a single tensor so that the line: + # y = self.rnn(y, ...)[0] + # in PTBModel.call works for both this RNN and CudnnLSTM (which returns a + # tuple (output, output_states). + return [input_seq] + + def _add_cells(self, cells): + # "Magic" required for keras.Model classes to track all the variables in + # a list of Layer objects. + # TODO(ashankar): Figure out API so user code doesn't have to do this. + for i, c in enumerate(cells): + setattr(self, "cell-%d" % i, c) + return cells + + +class Embedding(layers.Layer): """An Embedding layer.""" def __init__(self, vocab_size, embedding_dim, **kwargs): @@ -87,7 +103,8 @@ class Embedding(tf.layers.Layer): return tf.nn.embedding_lookup(self.embedding, x) -class PTBModel(tfe.Network): +# pylint: disable=not-callable +class PTBModel(tf.keras.Model): """LSTM for word language modeling. Model described in: @@ -109,19 +126,16 @@ class PTBModel(tfe.Network): self.keep_ratio = 1 - dropout_ratio self.use_cudnn_rnn = use_cudnn_rnn - self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim)) + self.embedding = Embedding(vocab_size, embedding_dim) if self.use_cudnn_rnn: self.rnn = cudnn_rnn.CudnnLSTM( num_layers, hidden_dim, dropout=dropout_ratio) else: self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio) - self.track_layer(self.rnn) - self.linear = self.track_layer( - tf.layers.Dense( - vocab_size, - kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))) + self.linear = layers.Dense( + vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) self._output_shape = [-1, embedding_dim] def call(self, input_seq, training): @@ -136,7 +150,7 @@ class PTBModel(tfe.Network): y = self.embedding(input_seq) if training: y = tf.nn.dropout(y, self.keep_ratio) - y, _ = self.rnn(y, training=training) + y = self.rnn(y, training=training)[0] return self.linear(tf.reshape(y, self._output_shape)) @@ -148,7 +162,7 @@ def clip_gradients(grads_and_vars, clip_ratio): def loss_fn(model, inputs, targets, training): labels = tf.reshape(targets, [-1]) - outputs = model(inputs, training) + outputs = model(inputs, training=training) return tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=outputs)) diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index a1f8a759e2a556bc219f0aa13942f293c4f34cfa..5966f1d4873e8e77b3ad5914da7bfc7e69d4e341 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -38,5 +38,9 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], - tags = ["no_pip"], # because spinn.py is under third_party/. + tags = [ + "no-internal-py3", # flaky + "no_cuda_on_cpu_tap", + "no_pip", # because spinn.py is under third_party/. + ], ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 081b0af14fcc983a3f85d2a50e2bb04d2f2493b3..667365341829124060b724b8a5d6e542149ba704 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -33,6 +33,7 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -172,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase): right_in.append(tf.random_normal((1, size * 2))) tracking.append(tf.random_normal((1, tracker_size * 2))) - out = reducer(left_in, right_in, tracking=tracking) + out = reducer(left_in, right_in=right_in, tracking=tracking) self.assertEqual(batch_size, len(out)) self.assertEqual(tf.float32, out[0].dtype) self.assertEqual((1, size * 2), out[0].shape) @@ -226,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual((batch_size, size * 2), stacks[0][0].shape) for _ in range(2): - out1, out2 = tracker(bufs, stacks) + out1, out2 = tracker(bufs, stacks=stacks) self.assertIsNone(out2) self.assertEqual(batch_size, len(out1)) self.assertEqual(tf.float32, out1[0].dtype) @@ -259,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual(tf.int64, transitions.dtype) self.assertEqual((num_transitions, 1), transitions.shape) - out = s(buffers, transitions, training=True) + out = s(buffers, transitions=transitions, training=True) self.assertEqual(tf.float32, out.dtype) self.assertEqual((1, embedding_dims), out.shape) @@ -285,12 +286,15 @@ class SpinnTest(test_util.TensorFlowTestCase): vocab_size) # Invoke model under non-training mode. - logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + logits = model( + prem, premise_transition=prem_trans, hypothesis=hypo, + hypothesis_transition=hypo_trans, training=False) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) # Invoke model under training model. - logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + logits = model(prem, premise_transition=prem_trans, hypothesis=hypo, + hypothesis_transition=hypo_trans, training=True) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) @@ -417,12 +421,17 @@ class SpinnTest(test_util.TensorFlowTestCase): if event.summary.value and event.summary.value[0].tag == "train/loss"] self.assertEqual(config.epochs, len(train_losses)) - self.assertLess(train_losses[-1], train_losses[0]) # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - ckpt_variable_names = [ - item[0] for item in checkpoint_utils.list_variables(config.logdir)] + object_graph_string = checkpoint_utils.load_variable( + config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") + object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() + object_graph.ParseFromString(object_graph_string) + ckpt_variable_names = set() + for node in object_graph.nodes: + for attribute in node.attributes: + ckpt_variable_names.add(attribute.full_name) self.assertIn("global_step", ckpt_variable_names) for v in trainer.variables: variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index ebb05051f27841f1cd3d21b6218986e774ed4c9f..11064981c6257a607f88c6f4414418c8d1f8eac7 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -273,9 +273,9 @@ assert 6 == df(3.)[0].numpy() d2f = tfe.gradients_function(lambda x: df(x)[0]) assert 2 == d2f(3.)[0].numpy() -# Third order derivative. +# Third order derivative: Will be None d3f = tfe.gradients_function(lambda x : d2f(x)[0]) -assert 0 == d3f(3.)[0].numpy() +assert None == d3f(3.)[0] ``` These functions can be used to train models. For example, consider the following @@ -574,49 +574,45 @@ repository](https://github.com/tensorflow/models/tree/master/official/mnist/mnis ### Checkpointing trained variables -TensorFlow Variables (`tfe.Variable`) provides a way to represent shared, -persistent state of your model. The `tfe.Saver` class (which is a thin wrapper -over the -[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver) -class) provides a means to save and restore variables to and from _checkpoints_. +TensorFlow Variables (`tfe.Variable`) provide a way to represent shared, +persistent state of your model. The `tfe.Checkpoint` class provides a means to +save and restore variables to and from _checkpoints_. For example: ```python # Create variables. -x = tfe.Variable(10., name='x') -y = tfe.Variable(5., name='y') +x = tfe.Variable(10.) +y = tfe.Variable(5.) -# Create a Saver. -saver = tfe.Saver([x, y]) +# Indicate that the variables should be saved as "x" and "y". +checkpoint = tfe.Checkpoint(x=x, y=y) # Assign new values to the variables and save. x.assign(2.) -saver.save('/tmp/ckpt') +save_path = checkpoint.save('/tmp/ckpt') # Change the variable after saving. x.assign(11.) assert 16. == (x + y).numpy() # 11 + 5 # Restore the values in the checkpoint. -saver.restore('/tmp/ckpt') +checkpoint.restore(save_path) # save_path='/tmp/ckpt-1' assert 7. == (x + y).numpy() # 2 + 5 ``` -### `tfe.Network` +### `tf.keras.Model` You may often want to organize your models using classes, like the `MNISTModel` -class described above. We recommend inheriting from the `tfe.Network` class as -it provides conveniences like keeping track of all model variables and methods -to save and restore from checkpoints. +class described above. We recommend inheriting from the `tf.keras.Model` class +as it provides conveniences like keeping track of all model variables. -Sub-classes of `tfe.Network` may register `Layer`s (like classes in -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), -or [Keras -layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) -using a call to `self.track_layer()` and define the computation in an -implementation of `call()`. +Sub-classes of `tf.keras.Model` may register `Layer`s (like classes in +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), or [Keras +layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) by +assigning them to attributes (`self.name = layer_object`) and define the +computation in an implementation of `call()`. Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables lazily, when the first input is encountered. @@ -624,12 +620,11 @@ lazily, when the first input is encountered. For example, consider the following two-layer neural network: ```python -class TwoLayerNet(tfe.Network): +class TwoLayerNet(tf.keras.Model): def __init__(self): super(TwoLayerNet, self).__init__() - self.layer1 = self.track_layer( - tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False)) - self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False)) + self.layer1 = tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False) + self.layer2 = tf.layers.Dense(3, use_bias=False) def call(self, x): return self.layer2(self.layer1(x)) @@ -653,15 +648,16 @@ assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1. assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2. ``` -The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows -instances of `tfe.Network` to be embedded in other networks. For example: +The `tf.keras.Model` class is itself a sub-class of `tf.layers.Layer`. This +allows instances of `tf.keras.Model` to be embedded in other models. For +example: ```python -class ThreeLayerNet(tfe.Network): +class ThreeLayerNet(tf.keras.Model): def __init__(self): super(ThreeLayerNet, self).__init__() - self.a = self.track_layer(TwoLayerNet()) - self.b = self.track_layer(tf.layers.Dense(4, use_bias=False)) + self.a = TwoLayerNet() + self.b = tf.layers.Dense(4, use_bias=False) def call(self, x): return self.b(self.a(x)) @@ -678,9 +674,8 @@ assert [3, 4] == net.variables[2].shape.as_list() See more examples in [`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples). -`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a -convenient way to save and load checkpoints without changing the program once -the checkpoint has been created. For example, we can set an objective for the +`tfe.Checkpoint` provides a convenient way to save and load training +checkpoints. Let's define something simple to train. We set an objective for the output of our network, choose an optimizer, and a location for the checkpoint: ```python @@ -691,30 +686,27 @@ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') net = ThreeLayerNet() ``` -Note that variables have not been created yet. We want them to be restored from -a checkpoint, if one exists, so we create them inside a -`tfe.restore_variables_on_create` context manager. Then our training loop is the -same whether starting training or resuming from a previous checkpoint: +We group them in a `tfe.Checkpoint` and request that it be restored. This +ensures that variables created by these objects are restored before their values +are used. Our training loop is the same whether starting training or resuming +from a previous checkpoint: ```python -with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(checkpoint_directory)): - global_step = tf.train.get_or_create_global_step() - for _ in range(100): - loss_fn = lambda: tf.norm(net(inp) - objective) - optimizer.minimize(loss_fn, global_step=global_step) - if tf.equal(global_step % 20, 0): - print("Step %d, output %s" % (global_step.numpy(), - net(inp).numpy())) - all_variables = ( - net.variables - + optimizer.variables() - + [global_step]) - # Save the checkpoint. - tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) -``` - -The first time it runs, `Network` variables are initialized randomly. Then the +global_step = tf.train.get_or_create_global_step() +checkpoint = tfe.Checkpoint( + global_step=global_step, optimizer=optimizer, network=net) +checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) +for _ in range(100): + loss_fn = lambda: tf.norm(net(inp) - objective) + optimizer.minimize(loss_fn, global_step=global_step) + if tf.equal(global_step % 20, 0): + print("Step %d, output %s" % (global_step.numpy(), + net(inp).numpy())) + # Save the checkpoint. + checkpoint.save(checkpoint_prefix) +``` + +The first time it runs, `Model` variables are initialized randomly. Then the output is trained to match the objective we've set: ``` diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index ea8dbf2b46ea4bd0e33645ae3c590c4dd13f7a52..2f2347736a073c7d9b3fb6685f52f8d58cc40570 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -30,12 +30,12 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope - +from tensorflow.python.training import checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(object): +class Metric(checkpointable.CheckpointableBase): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -93,11 +93,12 @@ class Metric(object): `aggregate()`, it is for use by TensorFlow infrastructure. """ - def __init__(self, name=None): + def __init__(self, name=None, use_global_variables=False): self._built = False self._vars = [] self._initial_values = {} self._updates = [] + self._use_global_variables = use_global_variables name = name or self.__class__.__name__ # Replace things like spaces in name to create a valid scope name. scope_name = _to_replace.sub("_", name) @@ -108,13 +109,25 @@ class Metric(object): pos = scope.name.rfind(scope_name) self._name = name + scope.name[pos + len(scope_name):] self._scope = scope - if context.in_graph_mode(): + + # Ensures that if the user calls build directly we still set self._built to + # True to prevent variables from being recreated. + self._build = self.build + + def actual_build(*args, **kwargs): + self._build(*args, **kwargs) + self._built = True + self.build = actual_build + self.build.__doc__ = self._build.__doc__ + + # Captures construction scope for proper initialization. + if context.executing_eagerly(): + self._construction_scope = context.eager_mode + else: # We make self.call() into a graph callable here, so that we can # return a single op that performs all of the variable updates. self._construction_scope = ops.get_default_graph().as_default self.call = function.defun(self.call) - else: - self._construction_scope = context.eager_mode # ---- API for users ---- def __call__(self, *args, **kwargs): @@ -155,10 +168,11 @@ class Metric(object): initialization. Under eager execution, the variables are reset to their initial values as a side effect and this function returns None. """ - if context.in_graph_mode(): + if context.executing_eagerly(): + for v in self._vars: + v.assign(self._initial_values[v]) + else: return control_flow_ops.group([v.initializer for v in self._vars]) - for v in self._vars: - v.assign(self._initial_values[v]) # ---- To be implemented by descendants --- def build(self, *args, **kwargs): @@ -200,10 +214,10 @@ class Metric(object): def value(self): """In graph mode returns the result Tensor while in eager the callable.""" - if context.in_graph_mode(): - return self.result() - else: + if context.executing_eagerly(): return self.result + else: + return self.result() # We can support two different strategies of for doing data-parallel # distributed metric computations: @@ -245,19 +259,31 @@ class Metric(object): """***Only for use by descendants of Metric***.""" if self._built: raise RuntimeError("Can't call add_variable() except in build().") - collections = None if context.in_eager_mode() else [ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ] - v = variable_scope.get_variable( - name, - shape, - dtype, - initializer, + if context.executing_eagerly(): + collections = None + else: + if self._use_global_variables: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + else: + collections = [ops.GraphKeys.LOCAL_VARIABLES] + collections += [ops.GraphKeys.METRIC_VARIABLES] + # Variables are Checkpointable dependencies of Metrics regardless of the + # global/local distinction. Users can avoid saving variables by not adding a + # dependency on the Metric. + v = self._add_variable_with_custom_getter( + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, trainable=False, collections=collections, - use_resource=True) + use_resource=True, + getter=variable_scope.get_variable, + # Raise duplicate variable exceptions from get_variable rather than + # Checkpointable. + overwrite=True) self._vars.append(v) - if context.in_eager_mode(): + if context.executing_eagerly(): self._initial_values[v] = v.value() return v @@ -267,8 +293,10 @@ class Mean(Metric): # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? # Or defaults to type of the input if it is tf.float32, else tf.float64? - def __init__(self, name=None, dtype=dtypes.float64): - super(Mean, self).__init__(name=name) + def __init__(self, name=None, dtype=dtypes.float64, + use_global_variables=False): + super(Mean, self).__init__(name=name, + use_global_variables=use_global_variables) self.dtype = dtype def build(self, *args, **kwargs): diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index a9ecaa3f8bced3043ea0eb0ac3aa8bfa65e9e1ff..15ac889191e0fe51269bc5740d5e0ab1bc0e2b72 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile +from tensorflow.contrib.eager.python import checkpointable_utils from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_util @@ -50,6 +52,19 @@ class MetricsTest(test.TestCase): self.assertEqual( set(m.variables), set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))) + self.assertEqual(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), []) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + + def testUseGlobalVariablesCollections(self): + with context.graph_mode(), ops.Graph().as_default(): + m = metrics.Mean(use_global_variables=True) + m(1000) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + self.assertEqual(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES), []) self.assertEqual( set(m.variables), set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) @@ -180,6 +195,15 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testBuildMean(self): + # Verify that calling build() on Mean and then calling it won't recreate + # variables. + m = metrics.Mean() + m.build() + old_numer = m.numer + m(0.0) + self.assertTrue(old_numer is m.numer) + def testMetricsChain(self): with context.graph_mode(), self.test_session(): m1 = metrics.Mean() @@ -193,6 +217,31 @@ class MetricsTest(test.TestCase): self.assertAllEqual(m2.result().eval(), 2.0) self.assertAllEqual(m1.result().eval(), 1.0) + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + mean = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=mean) + mean.build() + mean._built = True + self.evaluate(mean.init_variables()) + self.evaluate(mean(100.)) + self.evaluate(mean(200.)) + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(mean(1000.)) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(mean(300.)) + self.assertAllEqual(200., self.evaluate(mean.value())) + + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertAllEqual(200., self.evaluate(restore_mean.value())) + self.assertEqual(3, self.evaluate(restore_mean.denom)) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d..e55a9276ab53f44f76dc5e537b3bdde7c975f463 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -149,7 +149,7 @@ class Network(base.Layer): # check we might have name collisions if the parent scope on init gets # closed before build is called. self._variable_scope_counts_on_init = ( - variable_scope._get_default_variable_store().variable_scopes_count) + variable_scope.get_variable_scope_store().variable_scopes_count) def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" @@ -639,7 +639,7 @@ def _make_custom_getter_for_deferred_restorations(): # Mark as already restored from this checkpoint. delayed_restoration.checkpointed_variables_to_restore[ checkpoint_name] = None - if context.in_graph_mode(): + if not context.executing_eagerly(): delayed_restoration.session.run(variable.initializer) if found_value: # Error checking should run even if we've already restored a value. @@ -772,7 +772,7 @@ def save_network_checkpoint( variable_map[mapped_name]._shared_name, variable._shared_name, network.scope_name)) - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -853,7 +853,7 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func): network_name=network.name, network_scope_name=network.scope_name)) if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -880,7 +880,7 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, # _DeferredRestoration objects once a Network has been built (so that # restoring in a loop does not take increasing amounts of memory). if checkpointed_variables_to_restore: - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 62421849c766a1124c726812428985c913c653a3..fdaca90fd13576e6ca8a3408aaf528dbc2384b0c 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -73,7 +73,7 @@ def restore_variables_on_create(save_path, map_func=None): NotFoundError: If the variable is not found in checkpoint. ValueError: If not used in eager mode or map_func is not callable. """ - if context.in_graph_mode(): + if not context.executing_eagerly(): raise ValueError( "Currently, restore_variables_on_create can only be used with " "eager execution enabled.") @@ -131,7 +131,7 @@ class Saver(object): Raises: RuntimeError: if invoked when eager execution has not been enabled. """ - if context.in_graph_mode(): + if not context.executing_eagerly(): raise RuntimeError("tfe.Saver can only be used when eager " "execution is enabled. Use tf.train.Saver when " "building graphs.") diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index d32bebf90c1e768d1efec26b3b78bf1a522a8f00..c6f3f20e781147140f2c4b339ed465ab7e919d37 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -56,14 +56,24 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@save_network_checkpoint @@restore_network_checkpoint +@@Checkpoint +@@Checkpointable +@@CheckpointableSaver + +@@executing_eagerly @@in_eager_mode -@@in_graph_mode +@@set_execution_mode +@@execution_mode +@@async_wait +@@async_clear_error @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT +@@SYNC +@@ASYNC """ from __future__ import absolute_import @@ -74,6 +84,8 @@ from __future__ import print_function # pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import # from tensorflow.contrib.eager.python import metrics +from tensorflow.contrib.eager.python.checkpointable_utils import CheckpointableSaver +from tensorflow.contrib.eager.python.checkpointable_utils import Checkpoint from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential @@ -87,11 +99,15 @@ from tensorflow.python.eager import function from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT -from tensorflow.python.eager.context import in_eager_mode -from tensorflow.python.eager.context import in_graph_mode +from tensorflow.python.eager.context import executing_eagerly from tensorflow.python.eager.context import list_devices +from tensorflow.python.eager.context import set_execution_mode +from tensorflow.python.eager.context import execution_mode +from tensorflow.python.eager.context import async_wait +from tensorflow.python.eager.context import async_clear_error +from tensorflow.python.eager.context import SYNC +from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback @@ -101,10 +117,12 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template +from tensorflow.python.training.checkpointable import Checkpointable from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func @@ -115,5 +133,6 @@ implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function value_and_gradients_function = backprop.val_and_grad_function GradientTape = backprop.GradientTape # pylint: disable=invalid-name +in_eager_mode = executing_eagerly remove_undocumented(__name__) diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index b6659c2a1797feab261d756e78b45231dbea5a02..e80ccbb74d8623e977a98cb7fa5eb41f3c9bf250 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -47,7 +47,8 @@ class TFETest(test_util.TensorFlowTestCase): def testVariableError(self): with self.assertRaisesRegexp( - RuntimeError, r'Variable not supported in Eager mode'): + RuntimeError, + r'Variable not supported when eager execution is enabled'): variables.Variable(initial_value=1.0) def testGradients(self): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index ddccfce3c07d20bde78de297db25437a347d75cb..c846343d6d23198726153e6b693660f61232bee5 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -142,6 +142,7 @@ py_test( deps = [ ":extenders", "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/predictor", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", @@ -170,9 +171,11 @@ py_library( "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", + "//tensorflow/python:nn", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", @@ -192,6 +195,7 @@ py_test( ":head", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -289,6 +293,8 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", @@ -352,6 +358,7 @@ cuda_py_test( size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 0f75b77050b0ba4c752a6a74fdc7024170b6f318..6b9f9575b606f1822d760e8597c55994dd8af04c 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -39,6 +39,7 @@ _allowed_symbols = [ 'multi_class_head', 'multi_head', 'multi_label_head', + 'poisson_regression_head', 'regression_head', 'DNNEstimator', 'DNNLinearCombinedEstimator', diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index c99bf8badb35e6fffb7cae8761db9d402b8b3a8f..266ae933052b11b9ab3edb662e95c90aae207dae 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops @@ -33,7 +34,7 @@ _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) def add_metrics(estimator, metric_fn): - """Creates a new ${tf.estimator.Estimator} which has given metrics. + """Creates a new @{tf.estimator.Estimator} which has given metrics. Example: @@ -60,7 +61,7 @@ def add_metrics(estimator, metric_fn): ``` Args: - estimator: A ${tf.estimator.Estimator} object. + estimator: A @{tf.estimator.Estimator} object. metric_fn: A function which should obey the following signature: - Args: can only have following four arguments in any order: * predictions: Predictions `Tensor` or dict of `Tensor` created by given @@ -78,7 +79,7 @@ def add_metrics(estimator, metric_fn): function, namely a `(metric_tensor, update_op)` tuple. Returns: - A new ${tf.estimator.Estimator} which has a union of original metrics with + A new @{tf.estimator.Estimator} which has a union of original metrics with given ones. """ _verify_metric_fn_args(metric_fn) @@ -161,14 +162,14 @@ def forward_features(estimator, keys=None): ``` Args: - estimator: A ${tf.estimator.Estimator} object. + estimator: A @{tf.estimator.Estimator} object. keys: a `string` or a `list` of `string`. If it is `None`, all of the `features` in `dict` is forwarded to the `predictions`. If it is a `string`, only given key is forwarded. If it is a `list` of strings, all the given `keys` are forwarded. Returns: - A new ${tf.estimator.Estimator} which forwards features to predictions. + A new @{tf.estimator.Estimator} which forwards features to predictions. Raises: ValueError: @@ -233,7 +234,17 @@ def forward_features(estimator, keys=None): 'argument of forward_features to filter unwanted features. Type of ' 'features[{}] is {}.'.format(key, key, type(feature))) predictions[key] = feature - return spec._replace(predictions=predictions) + spec = spec._replace(predictions=predictions) + if spec.export_outputs: + for ekey in ['predict', 'serving_default']: + if (ekey in spec.export_outputs and + isinstance(spec.export_outputs[ekey], + PredictOutput)): + export_outputs = spec.export_outputs[ekey].outputs + for key in get_keys(features): + export_outputs[key] = predictions[key] + + return spec return estimator_lib.Estimator( model_fn=new_model_fn, diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index ad1a8ef152b07ecbab33d9eb3184a2ae89def27d..407af2deaf0928361a4f0b0e44e842b7750118cb 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -18,20 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile import numpy as np from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.predictor import from_saved_model from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import training +from tensorflow.python.util import compat def get_input_fn(x, y): @@ -177,6 +184,44 @@ class ForwardFeaturesTest(test.TestCase): self.assertIn('id', predictions) self.assertEqual(101, predictions['id']) + def test_forward_in_exported(self): + + def serving_input_fn(): + features_ph = { + 'x': array_ops.placeholder(dtypes.float32, [None]), + 'id': array_ops.placeholder(dtypes.int32, [None]) + } + features = { + key: array_ops.expand_dims(tensor, -1) + for key, tensor in features_ph.items() + } + return estimator_lib.export.ServingInputReceiver(features, features_ph) + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + # create estimator + feature_columns = [fc.numeric_column('x')] + estimator = linear.LinearRegressor(feature_columns) + estimator.train(input_fn=input_fn, steps=1) + estimator = extenders.forward_features(estimator, 'id') + + # export saved model + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn) + self.assertTrue(gfile.Exists(export_dir)) + + # restore model + predict_fn = from_saved_model(export_dir, signature_def_key='predict') + predictions = predict_fn({'x': [3], 'id': [101]}) + + # verify that 'id' exists in predictions + self.assertIn('id', predictions) + self.assertEqual(101, predictions['id']) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + def test_forward_list(self): def input_fn(): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index a45f6934cc5b9bb7bccf148edbd7553b702c2127..74da2cbb3f4557b4ddbbeb6debaae085407a0023 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -31,14 +31,17 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY +# TODO(b/65403806): Switch loss_reduction default to SUM_OVER_BATCH_SIZE. def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, @@ -237,11 +240,71 @@ def regression_head(weight_column=None, name=name) +def poisson_regression_head( + weight_column=None, + label_dimension=1, + loss_reduction=losses.Reduction.SUM, + compute_full_loss=True, + name=None): + """Creates a `_Head` for poisson regression using `tf.nn.log_poisson_loss`. + + The loss is the weighted sum over all input dimensions. Namely, if the input + labels have shape `[batch_size, label_dimension]`, the loss is the weighted + sum over both `batch_size` and `label_dimension`. + + The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. + In many applications, the shape is `[batch_size, label_dimension]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape + `[D0, D1, ... DN]` is also supported. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or + `[D0, D1, ... DN, label_dimension]`. + + This is implemented as a generalized linear model, see + https://en.wikipedia.org/wiki/Generalized_linear_model. + + Args: + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + label_dimension: Number of regression labels per example. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. + compute_full_loss: Whether to include the constant `log(z!)` term in + computing the poisson loss. See `tf.nn.log_poisson_loss` for the full + documentation. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. + + Returns: + An instance of `_Head` for poisson regression. + + Raises: + ValueError: If `label_dimension` or `loss_reduction` is invalid. + """ + def _poisson_loss(labels, logits): + return nn.log_poisson_loss( + targets=labels, log_input=logits, compute_full_loss=compute_full_loss) + return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=label_dimension, + loss_reduction=loss_reduction, + loss_fn=_poisson_loss, + inverse_link_fn=math_ops.exp, + name=name) + + def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -292,7 +355,8 @@ def multi_label_head(n_classes, string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -341,7 +405,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): self._n_classes = n_classes @@ -428,8 +492,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access processed_labels=processed_labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -441,8 +505,11 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -452,7 +519,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -504,8 +572,16 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -531,7 +607,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _eval_metric_ops( self, labels, probabilities, weights, unreduced_loss, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 1411635228457218578c0297d4d901e9c86ca91a..8837dfdc6c2d83495157f0d30b80ac8f6f245c60 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops @@ -271,9 +272,9 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - expected_training_loss = np.sum( + # loss = (labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / 2 + expected_training_loss = 0.5 * np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -297,7 +298,7 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_training_loss = np.sum( + expected_training_loss = 0.5 * np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -360,7 +361,7 @@ class MultiLabelHead(test.TestCase): labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -437,16 +438,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -467,18 +469,17 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -509,7 +510,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -532,18 +533,17 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -561,19 +561,18 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3., keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3., @@ -602,8 +601,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2 + expected_loss = 12.5 spec = head.create_estimator_spec( features={ @@ -616,12 +616,12 @@ class MultiLabelHead(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { - # Average loss over weighted examples. - keys.LOSS_MEAN: expected_loss / 3, + # Average loss over weighted examples (denominator is sum(weights)). + keys.LOSS_MEAN: expected_loss * (2. / 3.), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, - keys.AUC_PR: 0.5833, + keys.AUC_PR: 0.7833, } # Assert spec contains expected tensors. @@ -662,7 +662,7 @@ class MultiLabelHead(test.TestCase): # (1 - labels) * (logits > 0) * logits expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] expected_weights = [[1.], [2.]] - expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), @@ -808,11 +808,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol) def test_train(self): head = head_lib.multi_label_head(n_classes=2) @@ -822,8 +819,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -839,8 +837,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -857,11 +856,49 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_optimizer(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) @@ -915,8 +952,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2 + expected_loss = 12.5 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -950,11 +988,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over weighted examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss,}, summary_str, tol) def test_multi_dim_weighted_train_create_loss(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" @@ -971,8 +1006,8 @@ class MultiLabelHead(test.TestCase): expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] # weights are reshaped to [2, 2, 1] to match logits. expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_training_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_training_loss = 9.9167 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, @@ -998,8 +1033,8 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -1087,15 +1122,15 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 keys = metric_keys.MetricKeys expected_metrics = { - keys.LOSS_MEAN: expected_loss / np.sum(weights), + keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, - keys.AUC_PR: 0.4037, + keys.AUC_PR: 0.6645, } self._test_eval( head=head, @@ -1106,5 +1141,75 @@ class MultiLabelHead(test.TestCase): expected_metrics=expected_metrics) +class PoissonRegressionHead(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def test_train(self): + head = head_lib.poisson_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[1], [2], [3]], dtype=np.int32) + # With x = exp(logits), z = labels. + # loss = -ln(exp(-x) * (x^z) / z!) + # = x - z * ln(x) + ln(z!) + # = exp(logits) - labels * logits - ln(labels!) + # But for ln(z!) and z > 1, the Stirling approximation is used + # ln(z!) = z*ln(z) - z + 0.5*ln(2*pi*z) + # loss = [exp(0) - 1 * 0 + ln(1!), + # exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2), + # exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)] + # = [1.0, 3.020, 1.482] + # sum_loss = 5.502 + expected_loss = 5.502 + atol = 0.001 + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_near( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + atol=atol, name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run([spec.loss, spec.train_op]) + self.assertAlmostEqual(expected_loss, loss, delta=atol) + self.assertEqual(expected_train_result, train_result) + + def test_predict(self): + head = head_lib.poisson_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + expected_predictions = np.exp(logits) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 0346ddc24bffd61068177f4622bd03be4acd53d9..bbbc19cc4dfb4b23f9b707023fbfdd124f1f48de 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -30,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -226,8 +228,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access weights=example_weights_by_head, processed_labels=labels_by_head) + # TODO(b/65403806): Support regularization_losses arg. def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): """See `_Head`.""" if isinstance(logits, dict): logits_dict = logits @@ -248,9 +252,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access train_op_fn=_no_op_train_fn)) if mode == model_fn.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None in TRAIN mode.') - spec = self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train( + all_estimator_spec=all_estimator_spec, + optimizer=optimizer, + train_op_fn=train_op_fn) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec @@ -279,16 +284,21 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, train_op_fn): + def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): """Merges list of `EstimatorSpec` for training. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. - train_op_fn: Function to create train op. See `create_estimator_spec` - documentation for more details. + optimizer: `Optimizer` instance to create train op. See + `create_estimator_spec` documentation for more details. + train_op_fn: Function to create train op. Used if `optimizer` is `None`. Returns: `EstimatorSpec` that merges all heads for TRAIN. + + Raises: + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode. """ losses = [] metrics = {} @@ -297,11 +307,20 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access # Metric keys already contain head.name. metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + loss, global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op_fn(loss), + train_op=train_op, eval_metric_ops=metrics) def _merge_predict(self, all_estimator_spec): @@ -319,6 +338,7 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access all_estimator_spec[0].export_outputs, self._heads[0].name), } + merged_predict_outputs = {} for head, spec in zip(self._heads, all_estimator_spec): head_name = head.name for k, v in six.iteritems(spec.export_outputs): @@ -327,8 +347,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access else: key = '%s/%s' % (k, head_name) export_outputs[key] = v + if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access + isinstance(v, export_output_lib.PredictOutput)): + for kp, vp in six.iteritems(v.outputs): + key = '%s/%s' % (head_name, kp) + merged_predict_outputs[key] = vp for k, v in six.iteritems(spec.predictions): predictions[(head_name, k)] = v + export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access + export_output_lib.PredictOutput(merged_predict_outputs)) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index e47a6788f3b5440c4906b9f0430c802cf73237e3..74d3d6d728554587290301b6ddd5b9aaeb8cebac 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1', + 'predict/head1', 'head2', 'classification/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -158,6 +158,22 @@ class MultiHeadTest(test.TestCase): self.assertAllClose( expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['predict'].outputs['head1/probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['predict'].outputs['head2/probabilities'])) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['predict/head1'].outputs['probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['predict/head2'].outputs['probabilities'])) def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" @@ -181,8 +197,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1', + 'predict/head1', 'head2', 'classification/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -238,8 +254,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', - 'head2', 'regression/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1', + 'predict/head1', 'head2', 'regression/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -283,10 +299,11 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15. expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 spec = multi_head.create_estimator_spec( @@ -300,14 +317,14 @@ class MultiHeadTest(test.TestCase): keys.LOSS + '/head1': expected_loss_head1, keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. - keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, + keys.LOSS_MEAN + '/head1': expected_loss_head1, + keys.LOSS_MEAN + '/head2': expected_loss_head2, # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but # this assert tests that the algorithm remains consistent. keys.AUC + '/head1': 0.1667, keys.AUC + '/head2': 0.3333, - keys.AUC_PR + '/head1': 0.49999964, - keys.AUC_PR + '/head2': 0.33333313, + keys.AUC_PR + '/head1': 0.6667, + keys.AUC_PR + '/head2': 0.5000, } # Assert spec contains expected tensors. @@ -347,8 +364,8 @@ class MultiHeadTest(test.TestCase): tol = 1e-3 with self.test_session(): # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] - # (averaged over classes, sum-reduced over examples). - self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) + # (averaged over classes, averaged over examples). + self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): # Use different example weighting for each head weighting. @@ -383,18 +400,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -431,18 +448,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -495,8 +512,8 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -530,10 +547,46 @@ class MultiHeadTest(test.TestCase): _assert_simple_summaries(self, { metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, }, summary_str, tol) + def test_train_one_head_with_optimizer(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + multi_head = multi_head_lib.multi_head([head1]) + + logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} + labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -553,10 +606,12 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -592,9 +647,6 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, }, summary_str, tol) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index e0fae2c99292385c6dd32cc6002cee2076a2bb20..fa2697800ec1a44f215f3d5fc9be2197a9e58219 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -136,7 +136,7 @@ def replicate_model_fn(model_fn, the train_op argument of `EstimatorSpec`. loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This - argument can be used to replice only on the subset of available GPUs. + argument can be used to replicate only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index d46a18aacfcd911c56a9f22dc9581060c7b458a6..144b45982c8aec2e2b115c812b24e8843d60ce1e 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import re import shutil import tempfile +from absl.testing import parameterized import numpy as np import six @@ -57,26 +58,19 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import training -# TODO(isaprykin): Parametrize all the tests on -# replicate_model_fn._VariableDistributionMode when it's supported. -class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow_with_public_version(self): - return self._complete_flow_with_mode(mode=None) - - def test_complete_flow_with_mode_local_ps_server(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode. - SHARED_LOCAL_PARAMETER_SERVER) - - def test_complete_flow_with_mode_round_robin(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) - - def _complete_flow_with_mode(self, mode): + @parameterized.named_parameters( + ('PublicInterface', None), + ('ParameterServerMode', replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER), + ('RoundRobinMode', + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)) + def test_complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 180f1b68f3b56113dfbbfc100bd04efc3bb8b31f..ad8568ad44ea84f96b97e98567a276c70520d53d 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -66,6 +66,7 @@ tf_custom_op_py_library( "//tensorflow/python:variables", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -223,7 +224,10 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = ["notsan"], # b/67512932 + tags = [ + "nomac", # b/73741358 + "notsan", # b/67512932 + ], deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", @@ -238,6 +242,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:training", "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index dd61f59585aee2e0245cfd6797b313b972c19bc5..2a6c97e8b9526894eba057505a2bf823ad778f56 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -353,7 +353,7 @@ class NearestNeighborsOp : public OpKernel { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); const int64 num_threads = worker_threads.num_threads; // This kernel might be configured to use fewer than the total number of - // available CPUs on the host machine. To avoid descructive interference + // available CPUs on the host machine. To avoid destructive interference // with other jobs running on the host machine, we must only use a fraction // of total available L3 cache. Unfortunately, we cannot query the host // machine to get the number of physical CPUs. So, we use a fixed per-CPU diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 054888e734086c153f7af59f4548d4d20abab813..8e0ed1d80ec2603862aedb19cef1532626edb37c 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -106,7 +106,7 @@ class WALSModel(object): # the prep_gramian_op for row(column) can be run. worker_init_op = model.worker_init - # To be run once per interation sweep before the row(column) update + # To be run once per integration sweep before the row(column) update # initialize ops can be run. Note that in the distributed training # situations, this should only be run by the chief trainer. All other # trainers need to block until this is done. @@ -118,9 +118,9 @@ class WALSModel(object): init_row_update_op = model.initialize_row_update_op init_col_update_op = model.initialize_col_update_op - # Ops to upate row(column). This can either take the entire sparse tensor - # or slices of sparse tensor. For distributed trainer, each trainer - # handles just part of the matrix. + # Ops to update row(column). This can either take the entire sparse + # tensor or slices of sparse tensor. For distributed trainer, each + # trainer handles just part of the matrix. _, row_update_op, unreg_row_loss, row_reg, _ = model.update_row_factors( sp_input=matrix_slices_from_queue_for_worker_shard) row_loss = unreg_row_loss + row_reg @@ -220,7 +220,7 @@ class WALSModel(object): in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number of inner lists matching the number of row factor shards and the elements in each inner list are the weights for the rows of the corresponding row - factor shard. In this case, w_ij = unonbserved_weight + + factor shard. In this case, w_ij = unobserved_weight + row_weights[i] * col_weights[j]. - If this is a single non-negative real number, this value is used for all row weights and w_ij = unobserved_weight + row_weights * @@ -435,7 +435,7 @@ class WALSModel(object): gramian: Variable storing the gramian calculated from the factors. Returns: - A op that updates the gramian with the calcuated value from the factors. + A op that updates the gramian with the calculated value from the factors. """ partial_gramians = [] for f in factors: @@ -564,7 +564,7 @@ class WALSModel(object): Note that specifically this initializes the cache of the row and column weights on workers when `use_factors_weights_cache` is True. In this case, - if these weights are being calcualted and reset after the object is created, + if these weights are being calculated and reset after the object is created, it is important to ensure this ops is run afterwards so the cache reflects the correct values. """ diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index c8137339155ef1da8ee53967eea84a550f12ecbc..002f9cfbddd67b6b124f4e22dd43b808c4d48b2a 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -210,7 +210,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -283,7 +283,7 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 3 column feature vectors. - # This is expected to reprodue the same column factors in the model as the + # This is expected to reproduce the same column factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( @@ -385,7 +385,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -462,7 +462,7 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 2 column feature vectors. - # This is expected to reprodue the same column factors in the model as the + # This is expected to reproduce the same column factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 98d6434f4752b224201e38bed05ccd14428a758b..14d4c733e379a35d1ea3085bc633df174d12b01c 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -280,7 +280,7 @@ class GmmAlgorithm(object): self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): - """Defines the full covariance probabilties per example in a class. + """Defines the full covariance probabilities per example in a class. Updates a matrix with dimension num_examples X num_classes. @@ -344,7 +344,7 @@ class GmmAlgorithm(object): def _define_prior_log_prob_operation(self, shard_id): """Computes the prior probability of all samples. - Updates a vector where each item is the prior probabibility of an + Updates a vector where each item is the prior probability of an input example. Args: diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 00a4734eb6d89cd02484f1c5161366377cc71208..4fc9c96e9d0a317ef757d5e1bb6563ed7c8832af 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -210,7 +210,7 @@ class GMMTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index c861cfff544a78617aa1ace730b50c094cf16330..38faca119d0b5ee883de3b215428a0db8a021016 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,6 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -61,8 +62,8 @@ class _LossRelativeChangeHook(session_run_hook.SessionRunHook): loss = run_values.results assert loss is not None if self._prev_loss: - relative_change = (abs(loss - self._prev_loss) / - (1 + abs(self._prev_loss))) + relative_change = ( + abs(loss - self._prev_loss) / (1 + abs(self._prev_loss))) if relative_change < self._tolerance: run_context.request_stop() self._prev_loss = loss @@ -105,24 +106,32 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): logging.info(e) -def _parse_tensor_or_dict(features): +def _parse_features_if_necessary(features, feature_columns): """Helper function to convert the input points into a usable format. Args: - features: The input points. + features: The input features. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column instances + that can be passed to `tf.feature_column.input_layer`. If this is None, + all features will be used. Returns: - If `features` is a dict of `k` features, each of which is a vector of `n` - scalars, the return value is a Tensor of shape `(n, k)` representing `n` - input points, where the items in the `k` dimension are sorted - lexicographically by `features` key. If `features` is not a dict, it is - returned unmodified. + If `features` is a dict of `k` features (optionally filtered by + `feature_columns`), each of which is a vector of `n` scalars, the return + value is a Tensor of shape `(n, k)` representing `n` input points, where the + items in the `k` dimension are sorted lexicographically by `features` key. + If `features` is not a dict, it is returned unmodified. """ - if isinstance(features, dict): - keys = sorted(features.keys()) - with ops.colocate_with(features[keys[0]]): - features = array_ops.concat([features[k] for k in keys], axis=1) - return features + if not isinstance(features, dict): + return features + + if feature_columns: + return fc.input_layer(features, feature_columns) + + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + return array_ops.concat([features[k] for k in keys], axis=1) class _ModelFn(object): @@ -130,7 +139,8 @@ class _ModelFn(object): def __init__(self, num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance): + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns): self._num_clusters = num_clusters self._initial_clusters = initial_clusters self._distance_metric = distance_metric @@ -139,6 +149,7 @@ class _ModelFn(object): self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries self._relative_tolerance = relative_tolerance + self._feature_columns = feature_columns def model_fn(self, features, mode, config): """Model function for the estimator. @@ -166,7 +177,7 @@ class _ModelFn(object): # input_points is a single Tensor. Therefore, the sharding functionality # in clustering_ops is unused, and some of the values below are lists of a # single item. - input_points = _parse_tensor_or_dict(features) + input_points = _parse_features_if_necessary(features, self._feature_columns) # Let N = the number of input_points. # all_distances: A list of one matrix of shape (N, num_clusters). Each value @@ -233,7 +244,57 @@ class _ModelFn(object): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + Example: + ``` + import numpy as np + import tensorflow as tf + + num_points = 100 + dimensions = 2 + points = np.random.uniform(0, 1000, [num_points, dimensions]) + + def input_fn(): + return tf.train.limit_epochs( + tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1) + + num_clusters = 5 + kmeans = tf.contrib.factorization.KMeansClustering( + num_clusters=num_clusters, use_mini_batch=False) + + # train + num_iterations = 10 + previous_centers = None + for _ in xrange(num_iterations): + kmeans.train(input_fn) + cluster_centers = kmeans.cluster_centers() + if previous_centers is not None: + print 'delta:', cluster_centers - previous_centers + previous_centers = cluster_centers + print 'score:', kmeans.score(input_fn) + print 'cluster centers:', cluster_centers + + # map the input points to their clusters + cluster_indices = list(kmeans.predict_cluster_index(input_fn)) + for i, point in enumerate(points): + cluster_index = cluster_indices[i] + center = cluster_centers[cluster_index] + print 'point:', point, 'is in cluster', cluster_index, 'centered at', center + ``` + + The `SavedModel` saved by the `export_savedmodel` method does not include the + cluster centers. However, the cluster centers may be retrieved by the + latest checkpoint saved during training. Specifically, + ``` + kmeans.cluster_centers() + ``` + is equivalent to + ``` + tf.train.load_variable( + kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME) + ``` + """ # Valid values for the distance_metric constructor argument. SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE @@ -253,6 +314,9 @@ class KMeansClustering(estimator.Estimator): CLUSTER_INDEX = 'cluster_index' ALL_DISTANCES = 'all_distances' + # Variable name used by cluster_centers(). + CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME + def __init__(self, num_clusters, model_dir=None, @@ -263,7 +327,8 @@ class KMeansClustering(estimator.Estimator): mini_batch_steps_per_iteration=1, kmeans_plus_plus_num_retries=2, relative_tolerance=None, - config=None): + config=None, + feature_columns=None): """Creates an Estimator for running KMeans training and inference. This Estimator implements the following variants of the K-means algorithm: @@ -330,6 +395,10 @@ class KMeansClustering(estimator.Estimator): iterations. Stops learning if the loss changes less than this amount. This may not work correctly if `use_mini_batch=True`. config: See @{tf.estimator.Estimator}. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column + instances that can be passed to `tf.feature_column.input_layer`. If this + is None, all features will be used. Raises: ValueError: An invalid argument was passed to `initial_clusters` or @@ -349,7 +418,8 @@ class KMeansClustering(estimator.Estimator): model_fn=_ModelFn( num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns).model_fn, model_dir=model_dir, config=config) @@ -406,4 +476,4 @@ class KMeansClustering(estimator.Estimator): def cluster_centers(self): """Returns the cluster centers.""" - return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) + return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index f9598bfc08c05ea3bba88b3135da0cf2e6bb0c95..88eb9cf692992fe2e1fc4f060ac98dd721c22307 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,6 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -226,6 +227,44 @@ class KMeansTest(KMeansTestBase): self._infer_helper(kmeans, clusters, 10) self._infer_helper(kmeans, clusters, 1) + def _parse_feature_dict_helper(self, features, parsed_feature_dict): + # Perform a sanity check. + self.assertEqual(features.shape, parsed_feature_dict.shape) + self.assertEqual(features.dtype, parsed_feature_dict.dtype) + # Then check that running the tensor yields the original list of points. + with self.test_session() as sess: + parsed_points = sess.run(parsed_feature_dict) + self.assertAllEqual(self.points, parsed_points) + + def test_parse_features(self): + """Tests the various behaviours of kmeans._parse_features_if_necessary.""" + + # No-op if a tensor is passed in. + features = constant_op.constant(self.points) + parsed_features = kmeans_lib._parse_features_if_necessary(features, None) + self.assertAllEqual(features, parsed_features) + + # All values from a feature dict are transformed into a tensor. + feature_dict = { + 'x': [[point[0]] for point in self.points], + 'y': [[point[1]] for point in self.points] + } + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict, None) + self._parse_feature_dict_helper(features, parsed_feature_dict) + + # Only the feature_columns of a feature dict are transformed into a tensor. + feature_dict_with_extras = { + 'foo': 'bar', + 'x': [[point[0]] for point in self.points], + 'baz': {'fizz': 'buzz'}, + 'y': [[point[1]] for point in self.points] + } + feature_columns = [fc.numeric_column(key='x'), fc.numeric_column(key='y')] + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict_with_extras, feature_columns) + self._parse_feature_dict_helper(features, parsed_feature_dict) + class KMeansTestMultiStageInit(KMeansTestBase): @@ -374,7 +413,7 @@ class KMeansCosineDistanceTest(KMeansTestBase): self.assertAllClose(score, self.true_score, atol=1e-2) def test_predict_kmeans_plus_plus(self): - # Most points are concetrated near one center. KMeans++ is likely to find + # Most points are concentrated near one center. KMeans++ is likely to find # the less populated centers. points = np.array( [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], @@ -394,7 +433,6 @@ class KMeansCosineDistanceTest(KMeansTestBase): true_assignments = [0] * 2 + [1] * 2 + [2] * 8 true_score = len(points) - np.tensordot( normalize(points), true_centers[true_assignments]) - kmeans = kmeans_lib.KMeansClustering( 3, initial_clusters=self.initial_clusters, @@ -566,7 +604,7 @@ class KMeansTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9..62db3bb4c40e0b1e7adfeb682734f8efbfff9cdb 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -235,7 +235,7 @@ def _wals_factorization_model_function(features, labels, mode, params): num_items: An integer, the total number of items of this axis. update_fn: A function that takes one argument (`sp_input`), and that returns a tuple of - * new_factors: A flot Tensor of the factor values after update. + * new_factors: A float Tensor of the factor values after update. * update_op: a TensorFlow op which updates the factors. * loss: A float Tensor, the unregularized loss. * reg_loss: A float Tensor, the regularization loss. diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 6fc053759c58d30c24657dd22e7d12be46fc7a7e..3614b2b15a6cbdd73f9f24c7e4e4534228d31499 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -25,13 +25,42 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":sequential_feature_column", + ":sequence_feature_column", + "//tensorflow/python:util", ], ) py_library( - name = "sequential_feature_column", - srcs = ["python/feature_column/sequential_feature_column.py"], + name = "sequence_feature_column", + srcs = ["python/feature_column/sequence_feature_column.py"], srcs_version = "PY2AND3", - deps = [], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + ], +) + +py_test( + name = "sequence_feature_column_test", + srcs = ["python/feature_column/sequence_feature_column_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/contrib/feature_column/__init__.py b/tensorflow/contrib/feature_column/__init__.py index 6da7b126931effae9cc97091a27070d7013450d4..baa8c1567a5aeb39976ab04c54ae2728ba050a7c 100644 --- a/tensorflow/contrib/feature_column/__init__.py +++ b/tensorflow/contrib/feature_column/__init__.py @@ -19,12 +19,18 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.feature_column.python.feature_column.sequential_feature_column import * +from tensorflow.contrib.feature_column.python.feature_column.sequence_feature_column import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + 'sequence_categorical_column_with_hash_bucket', + 'sequence_categorical_column_with_identity', + 'sequence_categorical_column_with_vocabulary_list', + 'sequence_categorical_column_with_vocabulary_file', + 'sequence_input_layer', + 'sequence_numeric_column', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..555beddeaab419bcb23d06f960d370b706d744c8 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -0,0 +1,447 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental methods for tf.feature_column sequence input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import collections + + +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# pylint: disable=protected-access +# TODO(b/73827486): Support SequenceExample. + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + features: A dict mapping keys to tensors. + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + weight_collections: A list of collection names to which the Variable will be + added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc._clean_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, fc._SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'You can wrap a sequence_categorical_column with an embedding_column ' + 'or indicator_column. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + None, default_name='sequence_input_layer', values=features.values()): + builder = fc._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + output_tensors.append( + array_ops.reshape( + dense_tensor, + shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) + fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length + + +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + """Returns a feature column that represents sequences of integers. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [watches_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + num_buckets: Range of inputs. Namely, inputs are expected to be in the + range `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace out-of-range inputs. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +def sequence_categorical_column_with_hash_bucket( + key, hash_bucket_size, dtype=dtypes.string): + """A sequence of categorical terms where ids are set by hashing. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + tokens = sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=1000) + tokens_embedding = embedding_column(tokens, dimension=10) + columns = [tokens_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + hash_bucket_size: An int > 1. The number of buckets. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """A sequence of categorical terms where ids use a vocabulary file. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + states = sequence_categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + states_embedding = embedding_column(states, dimension=10) + columns = [states_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. If None, it is set to the length of `vocabulary_file`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=num_oov_buckets, + default_value=default_value, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): + """A sequence of categorical terms where ids use an in-memory list. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + colors = sequence_categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), + num_oov_buckets=2) + colors_embedding = embedding_column(colors, dimension=3) + columns = [colors_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a + hash of the input value. A positive `num_oov_buckets` can not be specified + with `default_value`. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( + key=key, + vocabulary_list=vocabulary_list, + dtype=dtype, + default_value=default_value, + num_oov_buckets=num_oov_buckets)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32): + """Returns a feature column that represents sequences of numeric data. + + Example: + + ```python + temperature = sequence_numeric_column('temperature') + columns = [temperature] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input features. + shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`, + each example must contain `2 * sequence_length` values. + default_value: A single value compatible with `dtype` that is used for + padding the sparse data into a dense `Tensor`. + dtype: The type of values. + + Returns: + A `_SequenceNumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int. + ValueError: if any dimension in shape is not a positive integer. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + shape = fc._check_shape(shape=shape, key=key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + + return _SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype) + + +def _assert_all_equal_and_return(tensors, name=None): + """Asserts that all tensors are equal and returns the first one.""" + with ops.name_scope(name, 'assert_all_equal', values=tensors): + if len(tensors) == 1: + return tensors[0] + assert_equal_ops = [] + for t in tensors[1:]: + assert_equal_ops.append(check_ops.assert_equal(tensors[0], t)) + with ops.control_dependencies(assert_equal_ops): + return array_ops.identity(tensors[0]) + + +class _SequenceNumericColumn( + fc._SequenceDenseColumn, + collections.namedtuple( + '_SequenceNumericColumn', + ['key', 'shape', 'default_value', 'dtype'])): + """Represents sequences of numeric data.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_spec(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.key) + + @property + def _variable_shape(self): + return tensor_shape.TensorShape(self.shape) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + sp_tensor = inputs.get(self) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + sequence_length = fc._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=self._variable_shape.num_elements()) + return fc._SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py new file mode 100644 index 0000000000000000000000000000000000000000..88f5d535162939e063eb1e7f43d495137c5adef4 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -0,0 +1,816 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase): + + def test_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = fc.embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + vocabulary_size_a = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + vocabulary_size_b = 2 + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [1, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 1, 0), + dense_shape=(2, 2)) + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size_a) + indicator_column_a = fc.indicator_column(categorical_column_a) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size_b) + indicator_column_b = fc.indicator_column(categorical_column_b) + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[indicator_column_b, indicator_column_a]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_indicator_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + def test_numeric_column(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_input_layer = [ + [[0.], [1.]], + [[10.], [0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column_multi_dim(self): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + # The output of numeric_column._get_dense_tensor should be flattened. + expected_input_layer = [ + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_not_equal(self): + """Tests that an error is raised when sequence lengths are not equal.""" + # Input a with sequence_length = [2, 1] + sparse_input_a = sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + # Input b with sequence_length = [1, 1] + sparse_input_b = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0)), + values=(1., 10.), + dense_shape=(2, 2)) + numeric_column_a = sfc.sequence_numeric_column('aaa') + numeric_column_b = sfc.sequence_numeric_column('bbb') + + _, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=[numeric_column_a, numeric_column_b]) + + with monitored_session.MonitoredSession() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Condition x == y did not hold element-wise:\] ' + r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + sess.run(sequence_length) + + +class InputLayerTest(test.TestCase): + """Tests input_layer with sequence feature columns.""" + + def test_embedding_column(self): + """Tests that error is raised for sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc.input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + """Tests that error is raised for sequence indicator column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc.input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + _assert_sparse_tensor_indices_shape(test_case, expected, actual) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + +def _assert_sparse_tensor_indices_shape(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_inputs3d(self): + """Tests _get_sparse_tensors when the input is already 3D Tensor.""" + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 2, 0), + dense_shape=(2, 2, 1)) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'Column aaa expected ID tensor of rank 2\.\s*' + r'id_tensor shape:\s*\[2 2 1\]'): + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': inputs})) + with monitored_session.MonitoredSession() as sess: + id_weight_pair.id_tensor.eval(session=sess) + + +class SequenceCategoricalColumnWithHashBucketTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_hash_bucket( + 'aaa', hash_bucket_size=10) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + values=np.array((0, 0, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_indices_shape( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): + + def _write_vocab(self, vocab_strings, file_name): + vocab_file = os.path.join(self.get_temp_dir(), file_name) + with open(vocab_file, 'w') as f: + f.write('\n'.join(vocab_strings)) + return vocab_file + + def setUp(self): + super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp() + + vocab_strings = ['omar', 'stringer', 'marlo'] + self._wire_vocabulary_file_name = self._write_vocab(vocab_strings, + 'wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + expected_lookups = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceIndicatorColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + expected_lookups = [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceNumericColumnTest(test.TestCase): + + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + + def test_get_sequence_dense_tensor(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_dense_tensor = [ + [[0.], [1.]], + [[10.], [0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_shape(self): + """Tests get_sequence_dense_tensor with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_dense_tensor = [ + [[0., 1., 2.], [3., 4., 5.]], + [[10., 11., 12.], [0., 0., 0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_dense_tensor_multi_dim(self): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + expected_dense_tensor = [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_sequence_length(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_shape(self): + """Tests _sequence_length with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index e61221a6b0d34373279a379f356c99c379488182..35341406a08dc681c861aea30fcff784e3b963ef 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -256,6 +256,9 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height, if (p != std::string::npos) { string rgb24 = line.substr(p + 9, line.find(" ", p + 9)); rgb24 = rgb24.substr(0, rgb24.find(",")); + // Strip anything after " ", in case the format is + // `640x360 [SAR 1:1 DAR 16:9]` + rgb24 = rgb24.substr(0, rgb24.find(" ")); string rgb24_width = rgb24.substr(0, rgb24.find("x")); string rgb24_height = rgb24.substr(rgb24_width.length() + 1); if (strings::safe_strtou32(rgb24_width, &width_value) && @@ -270,8 +273,10 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height, // We only look for the first stream mapping to have the number of the // frames. // Once processed we will not further process stream mapping section. - if (line.find("frame= ") == 0) { - string number = line.substr(8, line.find(" ", 8)); + if (line.find("frame=") == 0) { + // The format might be `frame= 166 ` or `frame=12488 ` + string number = line.substr(6); + number = number.substr(number.find_first_not_of(" ")); number = number.substr(0, number.find(" ")); if (strings::safe_strtou32(number, &frames_value)) { in_mapping = false; diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index dbdb5cfaaca1a687fefb81cee200295d5cbb7fd5..ac043fda0638e61f422e769ab3047a53a1b377bd 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", - "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -63,7 +62,9 @@ tf_custom_op_py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:smart_cond", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -161,23 +162,6 @@ py_test( ], ) -py_test( - name = "accumulate_n_v2_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - cuda_py_test( name = "critical_section_test", size = "medium", @@ -196,26 +180,6 @@ cuda_py_test( ], ) -py_test( - name = "accumulate_n_v2_eager_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_eager_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", - "//third_party/py/numpy", - ], -) - py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index deeb5bec79341f3e0468a127aeead69f960114d8..cbb68bd3eb257f9472515e5c29ce4f02057be321 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -71,6 +71,7 @@ See the @{$python/contrib.framework} guide. @@model_variable @@variable @@VariableDeviceChooser +@@convolutional_delta_orthogonal @@zero_initializer @@load_checkpoint @@ -82,11 +83,16 @@ See the @{$python/contrib.framework} guide. @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer +@@argsort @@py_func @@sort @@get_placeholders +@@smart_cond +@@smart_constant_value +@@smart_case + @@CriticalSection @@BoundedTensorSpec @@ -104,10 +110,12 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope +from tensorflow.python.framework.smart_cond import smart_case +from tensorflow.python.framework.smart_cond import smart_cond +from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec -from tensorflow.python.ops.control_flow_ops import smart_cond -from tensorflow.python.ops.control_flow_ops import smart_constant_value +from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index 8e54e09e04ee3c0ddbd4fa84cc0912cb70c93e62..cfdc7df7d8fd4c1406bf447a79038ac33b11e047 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -49,7 +49,6 @@ class ExperimentalTest(test.TestCase): "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " "be removed at any time, and without warning." "\n" - "\n" "\nArgs:" "\n arg0: Arg 0." "\n arg1: Arg 1." diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 49eec3a3f1a0f357ea3adfade51e71cb0f89942d..2703224b1bf62831b6088558d4f93950fe938c10 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -85,14 +85,19 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, if n not in reachable_by_input and n not in output_nodes_set: # n is between input and output, i.e., part of the fused op next_to_visit = [n] + visited = set() while next_to_visit: cur_node = next_to_visit[0] + visited.add(cur_node) del next_to_visit[0] if cur_node in reachable_by_input and cur_node not in input_nodes_set: raise TypeError("Node %s uses input %s not in input_nodes." % (n, cur_node)) if cur_node not in input_nodes_set: - next_to_visit += name_to_input_name[cur_node] + next_to_visit += [ + input_node for input_node in name_to_input_name[cur_node] + if input_node not in visited + ] elif n not in reachable_by_input: nodes_post_output.append(n) diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index b8a6d109e19211d271c2b15bac66ddacd38fe395..812c5fbd8cb759aef6eb1aad532c03794b2ceaf4 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -42,7 +42,8 @@ class GraphUtilTest(test.TestCase): graph_def = graph_pb2.GraphDef() node_a = GetNewNode('A', 'Placeholder', []) node_b = GetNewNode('B', 'Op1', ['A']) - node_c = GetNewNode('C', 'Op1', ['B']) + # A loop in the part that will be fused. + node_c = GetNewNode('C', 'Op1', ['B', 'C']) node_d = GetNewNode('D', 'Op1', ['C']) node_e = GetNewNode('E', 'Op1', ['D']) graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 8cdb340f2ddd9b3a7f55c1937ef045f4627e99be..a2834b648933772cab53002462c3edbe9a553e94 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -209,6 +209,7 @@ class WithShapeTest(test.TestCase): self.assertRaisesRegexp(errors_impl.OpError, "Wrong shape", tensor_2x2.eval, {tensor_no_shape: [42.0]}) + @test_util.enable_c_shapes def test_with_shape_partial(self): with self.test_session(): tensor_partial_shape = array_ops.placeholder(dtypes.float32) diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py deleted file mode 100644 index 476528b0dd3df05239d5dc402b466e06dd789985..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops - - - -def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): - """Returns the element-wise sum of a list of tensors. - - Optionally, pass `shape` and `tensor_dtype` for shape and type checking, - otherwise, these are inferred. - - `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not - wait for all of its inputs to be ready before beginning to sum. This can - save memory if inputs are ready at different times, since minimum temporary - storage is proportional to the output size rather than the inputs size. - - Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. - - For example: - - ```python - a = tf.constant([[1, 2], [3, 4]]) - b = tf.constant([[5, 0], [0, 6]]) - tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] - - # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) - # [[7, 4], - # [6, 14]] - ``` - - Args: - inputs: A list of `Tensor` objects, each with same shape and type. - shape: Shape of elements of `inputs`. - tensor_dtype: The type of `inputs`. - name: A name for the operation (optional). - - Returns: - A `Tensor` of same shape and type as the elements of `inputs`. - - Raises: - ValueError: If `inputs` don't all have same shape and dtype or the shape - cannot be inferred. - """ - _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" - "with the same dtype and shape") - if not inputs or not isinstance(inputs, (list, tuple)): - raise _INPUTS_ERR_MSG - inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): - raise _INPUTS_ERR_MSG - if not all(x.dtype == inputs[0].dtype for x in inputs): - raise _INPUTS_ERR_MSG - if shape is not None: - shape = tensor_shape.as_shape(shape) - else: - shape = tensor_shape.unknown_shape() - for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): - shape = shape.merge_with(input_tensor.get_shape()) - - # tensor_dtype is for safety only; operator's output type computed in C++ - if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}" - .format(tensor_dtype, inputs[0].dtype)) - - if len(inputs) == 1 and name is None: - return inputs[0] - elif len(inputs) == 1 and name is not None: - return array_ops.identity(inputs[0], name=name) - elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back - # onto AddN for now. - # TODO(frreiss) remove this once the lifetime of eager variables gets - # addressed - return math_ops.add_n(inputs, name=name) - else: - return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) - -# The following code should eventually be merged into -# tensorflow/python/ops/math_grad.py -@ops.RegisterGradient("AccumulateNV2") -def _AddNGrad(op, grad): - """Same as gradient for AddN. Copies the gradient to all inputs.""" - # Not broadcasting. - return [grad] * len(op.inputs) diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 409657fe1da0e5540cd2ad6070d86737c039e91f..3cad1fee1984042e3a9ab91a0af70cbaca25cece 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -142,7 +142,7 @@ def arg_scope(list_ops_or_scope, **kwargs): else: # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs. if not isinstance(list_ops_or_scope, (list, tuple)): - raise TypeError('list_ops_or_scope must either be a list/tuple or reused' + raise TypeError('list_ops_or_scope must either be a list/tuple or reused ' 'scope (i.e. dict)') try: current_scope = current_arg_scope().copy() diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 3c5c55ed656432a33f19462130a9e58c2ab14efb..bd764ed57a6da0a4d356235108e998a80ac34362 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -24,10 +24,8 @@ import collections # from tensorflow.core.protobuf import critical_section_pb2 from tensorflow.python.eager import context -from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -48,6 +46,26 @@ class _ExecutionSignature( pass +def _identity(x): + """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.""" + if isinstance(x, tensor_array_ops.TensorArray): + return x.identity() + elif isinstance(x, ops.Operation): + return control_flow_ops.group(x) + elif context.executing_eagerly() and x is None: + return None + else: + return array_ops.identity(x) + + +def _get_colocation(op): + """Get colocation symbol from op, if any.""" + try: + return op.get_attr("_class") + except ValueError: + return None + + class CriticalSection(object): """Critical section. @@ -143,7 +161,7 @@ class CriticalSection(object): def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name """Initialize the CriticalSection from constructor arguments.""" with ops.name_scope(name, "CriticalSection", []) as name: - with ops.control_dependencies(None): + with ops.init_scope(): # pylint: disable=protected-access container = ops.get_default_graph()._container # pylint: enable=protected-access @@ -154,7 +172,7 @@ class CriticalSection(object): self._handle = gen_resource_variable_ops.mutex_v2( shared_name=shared_name, container=container, name=name) - if context.in_graph_mode(): + if not context.executing_eagerly(): ops.add_to_collections(CRITICAL_SECTIONS, self) @property @@ -180,8 +198,8 @@ class CriticalSection(object): The tensors returned from `fn(*args, **kwargs)`. Raises: - ValueError: If `fn` attempts to use this `CriticalSection` in any nested - way. + ValueError: If `fn` attempts to lock this `CriticalSection` in any nested + or lazy way that may cause a deadlock. ValueError: If `exclusive_resource_access` is not provided (is `True`) and another `CriticalSection` has an execution requesting the same resources as in `*args`, `**kwargs`, and any additionaly captured @@ -193,67 +211,52 @@ class CriticalSection(object): exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) with ops.name_scope(name, "critical_section_execute", []): + + # Ensure that mutex locking only happens *after* all args and + # kwargs have been executed. This avoids certain types of deadlocks. lock = gen_resource_variable_ops.mutex_lock(self._handle) - with ops.control_dependencies([lock]): - c_known_ops = set() - c_captured_tensors = set() + if not context.executing_eagerly(): + # NOTE(ebrevdo): This is to ensure we don't pick up spurious + # Operations created by other threads. + with ops.get_default_graph()._lock: # pylint: disable=protected-access + existing_ops = ops.get_default_graph().get_operations() + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + # TODO(ebrevdo): If creating critical sections in a python loop, this + # makes graph creation time quadratic. Revisit if this + # becomes a problem. + created_ops = (set(ops.get_default_graph().get_operations()) + .difference(existing_ops)) + else: + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + + if not context.executing_eagerly(): + self._add_control_dependencies_to_lock(created_ops, lock.op) - def add_op_internal(op): - c_known_ops.add(op) - for i in op.inputs: - if i.op not in c_known_ops: - c_captured_tensors.add(i) + # captured_resources is a list of resources that are directly + # accessed only by ops created during fn(), not by any + # ancestors of those ops in the graph. + captured_resources = set([ + input_ for op in created_ops + for input_ in op.inputs + if input_.dtype == dtypes.resource + ]) - c = function.HelperContext(add_op_internal) - with c: - r = fn(*args, **kwargs) + # NOTE(ebrevdo): The only time self._is_self_handle() is True + # in this call is if one of the recently created ops, within + # the execute(), themselves attempt to access the + # CriticalSection. This will cause a deadlock. + if any(self._is_self_handle(x) for x in captured_resources): + raise ValueError("The function fn attempts to directly access the " + "CriticalSection in which it would be running. " + "This is illegal and would cause deadlocks.") - resource_inputs = set([ - x for x in - list(nest.flatten(args)) + nest.flatten(kwargs.values()) + - list(c_captured_tensors) - if tensor_util.is_tensor(x) and x.dtype == dtypes.resource]) - - if self._handle in resource_inputs: - raise ValueError("The function fn attempts to access the " - "CriticalSection in which it would be running. " - "This is illegal and would cause deadlocks. " - "CriticalSection: %s." % self._handle) - - if context.in_graph_mode(): - # Collections and op introspection does not work in eager - # mode. This is generally ok; since eager mode (as of - # writing) executes sequentially anyway. - for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): - if sg.handle.name == self._handle.name: - # Other executions in the same critical section are allowed. - continue - if not (exclusive_resource_access or sg.exclusive_resource_access): - # Neither execution requested exclusive access. - continue - resource_intersection = resource_inputs.intersection(sg.resources) - if resource_intersection: - raise ValueError( - "This execution would access resources: %s. Either this " - "lock (CriticalSection: %s) or lock '%s' " - "(CriticalSection: %s) requested exclusive resource access " - "of this resource. Did you mean to call execute with keyword " - "argument exclusive_resource_access=False?" % - (list(resource_intersection), self._handle.name, - sg.op.name, sg.handle.name)) - - def identity(x): # pylint: disable=invalid-name - if isinstance(x, tensor_array_ops.TensorArray): - return x.identity() - elif isinstance(x, ops.Operation): - return control_flow_ops.group(x) - elif context.in_eager_mode() and x is None: - return None - else: - return array_ops.identity(x) - - r_flat = [identity(x) for x in nest.flatten(r)] + self._check_multiple_access_to_resources( + captured_resources, exclusive_resource_access) + + r_flat = [_identity(x) for x in nest.flatten(r)] with ops.control_dependencies(r_flat): # The identity must run on the same machine as self._handle @@ -266,23 +269,105 @@ class CriticalSection(object): # Make sure that if any element of r is accessed, all of # them are executed together. - r = nest.pack_sequence_as( - r, control_flow_ops.tuple(nest.flatten(r))) + r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r))) with ops.control_dependencies([ensure_lock_exists]): - outputs = nest.map_structure(identity, r) + outputs = nest.map_structure(_identity, r) - if context.in_graph_mode(): + if not context.executing_eagerly(): signature = _ExecutionSignature( op=lock.op, handle=self._handle, - resources=list(resource_inputs), + resources=list(captured_resources), exclusive_resource_access=exclusive_resource_access) ops.add_to_collections( CRITICAL_SECTION_EXECUTIONS, signature) return outputs + def _add_control_dependencies_to_lock(self, created_ops, lock_op): + """To avoid deadlocks, all args must be executed before lock_op.""" + # Get all arguments (explicit and captured) of all ops created by fn(). + all_args = set([input_.op for op in created_ops for input_ in op.inputs]) + all_args.update( + input_op for op in created_ops for input_op in op.control_inputs) + # Unfortunately, we can't use sets throughout because TF seems to + # create new Operation objects for the same op sometimes; and we + # can't rely on id(op). + + # pylint: disable=protected-access + all_args_dict = dict((op._id, op) for op in all_args) + + # Remove ops created within fn, or that lock_op already has a + # control dependency on. Also remove a possible self-loop. + for op in created_ops: + all_args_dict.pop(op._id, None) + for op in lock_op.control_inputs: + all_args_dict.pop(op._id, None) + for input_ in lock_op.inputs: + all_args_dict.pop(input_.op._id, None) + all_args_dict.pop(lock_op._id, None) + + all_args = all_args_dict.values() + + if not all_args: + # No control dependencies to add; return early. + return + + # This group is important: it ensures that any ops in all_args + # outside the control context of the lock_op (and this fn, which + # runs in the same context) are added to this context before + # being added to the control dependencies of lock_op. + all_args = control_flow_ops.group(*all_args) + + lock_op._add_control_input(all_args) + # pylint: enable=protected-access + + def _is_self_handle(self, x): + """Check if the tensor `x` is the same Mutex as `self._handle`.""" + return (x.op.type == "MutexV2" + # blank shared_name means the op will create a unique one. + and x.op.get_attr("shared_name") + and (x.op.get_attr("shared_name") == + self._handle.op.get_attr("shared_name")) + and (x.op.device == self._handle.op.device + or _get_colocation(x.op) == _get_colocation(self._handle.op))) + + def _check_multiple_access_to_resources( + self, captured_resources, exclusive_resource_access): + """Raise if captured_resources are accessed by another CriticalSection. + + Args: + captured_resources: Set of tensors of type resource. + exclusive_resource_access: Whether this execution requires exclusive + resource access. + + Raises: + ValueError: If any tensors in `captured_resources` are also accessed + by another `CriticalSection`, and at least one of them requires + exclusive resource access. + """ + # Collections and op introspection does not work in eager + # mode. This is generally ok; since eager mode (as of + # writing) executes sequentially anyway. + for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): + if self._is_self_handle(sg.handle): + # Other executions in the same critical section are allowed. + continue + if not (exclusive_resource_access or sg.exclusive_resource_access): + # Neither execution requested exclusive access. + continue + resource_intersection = captured_resources.intersection(sg.resources) + if resource_intersection: + raise ValueError( + "This execution would access resources: %s. Either this " + "lock (CriticalSection: %s) or lock '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource. Did you mean to call execute with keyword " + "argument exclusive_resource_access=False?" % + (list(resource_intersection), self._handle.name, + sg.op.name, sg.handle.name)) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # def to_proto(self, export_scope=None): diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index c916592ce1979fe3a79cf28ad4bdac44284cce97..ba660295cb3c97d26da7bf892c78bceee53cf2d4 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging # TODO(ebrevdo): Re-enable once CriticalSection is in core. # from tensorflow.python.training import saver as saver_lib @@ -37,7 +38,7 @@ class CriticalSectionTest(test.TestCase): v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn(a, b): - c = v.read_value() + c = v.value() with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): @@ -140,15 +141,151 @@ class CriticalSectionTest(test.TestCase): ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) def testRecursiveCriticalSectionAccessIsIllegal(self): + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection() + def fn(x): + return cs.execute(lambda y: y + 1, x) + with self.assertRaisesRegexp( + ValueError, + r"attempts to directly access the CriticalSection in which it " + r"would be running"): + cs.execute(fn, 1.0) + + def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self): + # This one is subtle; and we're being overly cautious here. The + # deadlock we are ensuring we catch is: + # + # to_capture = CS[lambda x: x + 1](1.0) + # deadlocked = CS[lambda x: x + to_capture](1.0) + # + # This would have caused a deadlock because executing `deadlocked` will + # lock the mutex on CS; but then due to dependencies, will attempt + # to compute `to_capture`. This computation requires locking CS, + # but that is not possible now because CS is already locked by + # `deadlocked`. + # + # We check that CriticalSection.execute properly inserts new + # control dependencies to its lock to ensure all captured + # operations are finished before anything runs within the critical section. + cs = critical_section_ops.CriticalSection(shared_name="cs") + fn = array_ops.identity + to_capture = cs.execute(fn, 1.0) + fn_captures = lambda x: x + to_capture + to_capture_too = array_ops.identity(to_capture) + + ex_0 = cs.execute(fn_captures, 1.0) + + with ops.control_dependencies([to_capture]): + # This is OK because to_capture will execute before this next call + ex_1 = cs.execute(fn_captures, 1.0) + + dependency = array_ops.identity(to_capture) + + fn_captures_dependency = lambda x: x + dependency + + ex_2 = cs.execute(fn_captures_dependency, 1.0) + + with ops.control_dependencies([to_capture_too]): + ex_3 = cs.execute(fn_captures_dependency, 1.0) + + # Ensure there's no actual deadlock on to_execute. + self.assertEquals(2.0, self.evaluate(ex_0)) + self.assertEquals(2.0, self.evaluate(ex_1)) + self.assertEquals(2.0, self.evaluate(ex_2)) + self.assertEquals(2.0, self.evaluate(ex_3)) + + def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self): + cs = critical_section_ops.CriticalSection(shared_name="cs") + + def body_implicit_capture(i, j): + # This would have caused a deadlock if not for logic in execute + # that inserts additional control dependencies onto the lock op: + # * Loop body argument j is captured by fn() + # * i is running in parallel to move forward the execution + # * j is not being checked by the predicate function + # * output of cs.execute() is returned as next j. + fn = lambda: j + 1 + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + + def body_implicit_capture_protected(i, j): + # This version is ok because we manually add a control + # dependency on j, which is an argument to the while_loop body + # and captured by fn. + fn = lambda: j + 1 + with ops.control_dependencies([j]): + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture_protected, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + + def body_args_capture(i, j): + # This version is ok because j is an argument to fn and we can + # ensure there's a control dependency on j. + fn = lambda x: x + 1 + return (i + 1, cs.execute(fn, j)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_args_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + + def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self): # This does not work properly in eager mode. Eager users will # just hit a deadlock if they do this. But at least it'll be easier # to debug. cs = critical_section_ops.CriticalSection(shared_name="cs") + cs_same = critical_section_ops.CriticalSection(shared_name="cs") def fn(x): - return cs.execute(lambda x: x+1, x) + return cs_same.execute(lambda x: x+1, x) with self.assertRaisesRegexp( ValueError, - r"attempts to access the CriticalSection in which it would be running"): + r"attempts to directly access the CriticalSection in which it " + r"would be running"): cs.execute(fn, 1.0) def testMultipleCSExecutionsRequestSameResource(self): @@ -179,6 +316,20 @@ class CriticalSectionTest(test.TestCase): ValueError, "requested exclusive resource access"): cs1.execute(lambda: v2 + 1) + def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0, name="v") + # Make sure that the control dependencies on v do not cause issues + # in the lock_op's automatic control dependency adder. + # + # Note, here v must be a resource variable (or something similar), + # otherwise it gets hoisted into the while_loop by the time we add + # control dependencies to the lock_op. + out = control_flow_ops.while_loop( + lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0]) + self.evaluate(v.initializer) + self.assertEqual(10, self.evaluate(out)) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 8f62f0ea7b9b561f235b9496ffda97a9f378d530..1921a77c1e96ee3531d1ed0f98e41c27c9d427ac 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -14,6 +14,7 @@ # ============================================================================== """Support for sorting tensors. +@@argsort @@sort """ @@ -21,6 +22,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -47,64 +51,141 @@ def sort(values, axis=-1, direction='ASCENDING', name=None): ValueError: If axis is not a constant scalar, or the direction is invalid. """ with framework_ops.name_scope(name, 'sort'): - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. - axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error + return _sort_or_argsort(values, axis, direction, return_argsort=False) + + +def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): + """Returns the indices of a tensor that give its sorted order along an axis. + + For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to + `tf.sort(values)`. For higher dimensions, the output has the same shape as + `values`, but along the given axis, values represent the index of the sorted + element in that slice of the tensor at the given position. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + stable: If True, equal elements in the original tensor will not be + re-ordered in the returned order. Unstable sort is not yet implemented, + but will eventually be the default for performance reasons. If you + require a stable order, pass `stable=True` for forwards compatibility. + name: Optional name for the operation. + + Returns: + An int32 `Tensor` with the same shape as `values`. The indices that would + sort each slice of the given `values` along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + del stable # Unused. + with framework_ops.name_scope(name, 'argsort'): + return _sort_or_argsort(values, axis, direction, return_argsort=True) + + +def _sort_or_argsort(values, axis, direction, return_argsort): + """Internal sort/argsort implementation. + + Args: + values: The input values. + axis: The axis along which to sort. + direction: 'ASCENDING' or 'DESCENDING'. + return_argsort: Whether to return the argsort result. + + Returns: + Either the sorted values, or the indices of the sorted values in the + original tensor. See the `sort` and `argsort` docstrings. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error - values = framework_ops.convert_to_tensor(values, name='values') + values = framework_ops.convert_to_tensor(values, name='values') - return _SORT_IMPL[direction](values, axis_static) + return _SORT_IMPL[direction](values, axis_static, return_argsort) -def _descending_sort(values, axis): +def _descending_sort(values, axis, return_argsort=False): """Sorts values in reverse using `top_k`. Args: values: Tensor of numeric values. axis: Index of the axis which values should be sorted along. + return_argsort: If False, return the sorted values. If True, return the + indices that would sort the values. Returns: The sorted values. """ k = array_ops.shape(values)[axis] rank = array_ops.rank(values) + static_rank = values.shape.ndims # Fast path: sorting the last axis. if axis == -1 or axis + 1 == values.get_shape().ndims: - return nn_ops.top_k(values, k)[0] - - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Make axis a Tensor with the real axis index if needed. - axis += rank - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - values, unused_indices = nn_ops.top_k(top_k_input, k) - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return array_ops.transpose(values, transposition) - - -def _ascending_sort(values, axis): + top_k_input = values + transposition = None + else: + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Calculate the actual axis index if counting from the end. Use the static + # rank if available, or else make the axis back into a tensor. + axis += static_rank or rank + if static_rank is not None: + # Prefer to calculate the transposition array in NumPy and make it a + # constant. + transposition = constant_op.constant( + np.r_[ + # Axes up to axis are unchanged. + np.arange(axis), + # Swap axis and rank - 1. + [static_rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + np.arange(axis + 1, static_rank - 1), + # Swap axis and rank - 1. + [axis]], + name='transposition') + else: + # Generate the transposition array from the tensors. + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + + values, indices = nn_ops.top_k(top_k_input, k) + return_value = indices if return_argsort else values + if transposition is not None: + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return_value = array_ops.transpose(return_value, transposition) + return return_value + + +def _ascending_sort(values, axis, return_argsort=False): # Negate the values to get the ascending order from descending sort. - values_or_indices = _descending_sort(-values, axis) - return -values_or_indices + values_or_indices = _descending_sort(-values, axis, return_argsort) + # If not argsort, negate the values again. + return values_or_indices if return_argsort else -values_or_indices _SORT_IMPL = { diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py index d08ae502f10d98ee14d8bea2f76b18bedb935cea..a8fb94b245dccc8c7cf0e94cef9b436f881fe408 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -24,6 +24,8 @@ from tensorflow.contrib.framework.python.ops import sort_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -90,6 +92,38 @@ class SortTest(test.TestCase): axis=0, direction='DESCENDING').eval()) + def testSort_staticallyKnownRank_constantTransposition(self): + # The transposition array should be a constant if the rank of "values" is + # statically known. + tensor = random_ops.random_uniform( + # Rank is statically known to be 5, but the dimension lengths are not + # known. + random_ops.random_uniform( + shape=(5,), minval=0, maxval=10, dtype=dtypes.int32)) + sort_ops.sort(tensor, axis=1) + transposition = ( + ops.get_default_graph().get_tensor_by_name('sort/transposition:0')) + self.assertFalse(tensor_util.constant_value(transposition) is None) + self.assertAllEqual( + # Swaps "1" and "4" to put "1" at the end. + tensor_util.constant_value(transposition), + [0, 4, 2, 3, 1]) + + def testArgsort_1d(self): + arr = np.random.random(42) + with self.test_session(): + self.assertAllEqual( + np.sort(arr), + array_ops.gather(arr, sort_ops.argsort(arr)).eval()) + + def testArgsort(self): + arr = np.random.random((5, 6, 7, 8)) + for axis in range(4): + with self.test_session(): + self.assertAllEqual( + np.argsort(arr, axis=axis), + sort_ops.argsort(arr, axis=axis).eval()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index fdfabd07c13f689d075ecbb8786d725fa8a62d01..47e51415fd9e7daa360ca06a11078f6edcf63b5b 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -44,11 +44,11 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import resource_loader - __all__ = [ 'get_graph_def_from_disk', 'get_graph_def_from_resource', @@ -62,10 +62,11 @@ __all__ = [ 'frechet_inception_distance', 'frechet_classifier_distance', 'frechet_classifier_distance_from_activations', + 'mean_only_frechet_classifier_distance_from_activations', + 'diagonal_only_frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] - INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' INCEPTION_INPUT = 'Mul:0' @@ -77,8 +78,7 @@ INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): images = ops.convert_to_tensor(images) images.shape.with_rank(4) - images.shape.assert_is_compatible_with( - [None, image_size, image_size, None]) + images.shape.assert_is_compatible_with([None, image_size, image_size, None]) return images @@ -109,9 +109,10 @@ def _symmetric_matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -def preprocess_image( - images, height=INCEPTION_DEFAULT_IMAGE_SIZE, - width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): +def preprocess_image(images, + height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, + scope=None): """Prepare a batch of images for evaluation. This is the preprocessing portion of the graph from @@ -272,8 +273,11 @@ def run_inception(images, return activations -def run_image_classifier(tensor, graph_def, input_tensor, - output_tensor, scope='RunClassifier'): +def run_image_classifier(tensor, + graph_def, + input_tensor, + output_tensor, + scope='RunClassifier'): """Runs a network from a frozen graph. Args: @@ -317,7 +321,7 @@ def classifier_score(images, classifier_fn, num_batches=1): NOTE: This function consumes images, computes their logits, and then computes the classifier score. If you would like to precompute many logits for - large batches, use clasifier_score_from_logits(), which this method also + large batches, use classifier_score_from_logits(), which this method also uses. Args: @@ -433,8 +437,8 @@ def trace_sqrt_product(sigma, sigma_v): sqrt_sigma = _symmetric_matrix_square_root(sigma) # This is sqrt(A sigma_v A) above - sqrt_a_sigmav_a = math_ops.matmul( - sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma)) + sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma, + math_ops.matmul(sigma_v, sqrt_sigma)) return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) @@ -450,9 +454,9 @@ def frechet_classifier_distance(real_images, This technique is described in detail in https://arxiv.org/abs/1706.08500. Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates + C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the @@ -463,7 +467,7 @@ def frechet_classifier_distance(real_images, Frechet distance is biased. It is more biased for small sample sizes. (e.g. even if the two distributions are the same, for a small sample size, the expected Frechet distance is large). It is important to use the same - sample size to compute frechet classifier distance when comparing two + sample size to compute Frechet classifier distance when comparing two generative models. NOTE: This function consumes images, computes their activations, and then @@ -511,10 +515,142 @@ def frechet_classifier_distance(real_images, return frechet_classifier_distance_from_activations(real_a, gen_a) -def frechet_classifier_distance_from_activations( +def mean_only_frechet_classifier_distance_from_activations( real_activations, generated_activations): """Classifier distance for evaluating a generative model from activations. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + In this variant, we only compute the difference between the means of the + fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet + still retains much of the same information as FID. + + Args: + real_activations: 2D array of activations of real images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + generated_activations: 2D array of activations of generated images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + + Returns: + The mean-only Frechet Inception distance. A floating-point scalar of the + same type as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute means of activations. + m = math_ops.reduce_mean(real_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + mofid = mean + if activations_dtype != dtypes.float64: + mofid = math_ops.cast(mofid, activations_dtype) + + return mofid + + +def diagonal_only_frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. In this variant, we compute diagonal-only covariance matrices. + As a result, instead of computing an expensive matrix square root, we can do + something much simpler, and has O(n) vs O(n^2) space complexity. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The diagonal-only Frechet Inception distance. A floating-point scalar of + the same type as the output of the activations. + + Raises: + ValueError: If the shape of the variance and mean vectors are not equal. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute mean and covariance matrices of activations. + m, var = nn_impl.moments(real_activations, axes=[0]) + m_w, var_w = nn_impl.moments(generated_activations, axes=[0]) + + actual_shape = var.get_shape() + expected_shape = m.get_shape() + + if actual_shape != expected_shape: + raise ValueError('shape: {} must match expected shape: {}'.format( + actual_shape, expected_shape)) + + # Compute the two components of FID. + + # First the covariance component. + # Here, note that trace(A + B) = trace(A) + trace(B) + trace = math_ops.reduce_sum( + (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + dofid = trace + mean + if activations_dtype != dtypes.float64: + dofid = math_ops.cast(dofid, activations_dtype) + + return dofid + + +def frechet_classifier_distance_from_activations(real_activations, + generated_activations): + """Classifier distance for evaluating a generative model. + This methods computes the Frechet classifier distance from activations of real images and generated images. This can be used independently of the frechet_classifier_distance() method, especially in the case of using large @@ -523,15 +659,22 @@ def frechet_classifier_distance_from_activations( This technique is described in detail in https://arxiv.org/abs/1706.08500. Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates + C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the Inception score, this is a true distance and utilizes information about real world images. + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + Args: real_activations: 2D Tensor containing activations of real data. Shape is [batch_size, activation_size]. @@ -553,36 +696,38 @@ def frechet_classifier_distance_from_activations( # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) - m_v = math_ops.reduce_mean(generated_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( - real_centered, real_centered, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / ( + num_examples - 1) - gen_centered = generated_activations - m_v - sigma_v = math_ops.matmul( - gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_w + sigma_w = math_ops.matmul( + gen_centered, gen_centered, transpose_a=True) / ( + num_examples - 1) - # Find the Tr(sqrt(sigma sigma_v)) component of FID - sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) + # Find the Tr(sqrt(sigma sigma_w)) component of FID + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) # Compute the two components of FID. # First the covariance component. # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component + trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) return fid - frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 61dc8646ddc10605561ae6b19e90f4739c346608..663e49bdca3cb2dd9257da326488c877fcc4256d 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -50,6 +50,26 @@ def _expected_inception_score(logits): return np.exp(np.mean(per_example_logincscore)) +def _expected_mean_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + mean = np.square(m - m_v).sum() + mofid = mean + return mofid + + +def _expected_diagonal_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + var = np.var(real_imgs, axis=0) + var_v = np.var(gen_imgs, axis=0) + sqcc = np.sqrt(var * var_v) + mean = (np.square(m - m_v)).sum() + trace = (var + var_v - 2 * sqcc).sum() + dofid = mean + trace + return dofid + + def _expected_fid(real_imgs, gen_imgs): m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) @@ -285,6 +305,46 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(_expected_inception_score(logits), incscore_np) + def test_mean_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_mofid = sess.run(mofid_op) + + expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_mofid, actual_mofid, 0.0001) + + def test_diagonal_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_dofid = sess.run(dofid_op) + + expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_dofid, actual_dofid, 0.0001) + def test_frechet_classifier_distance_value(self): """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 9bebcacbe46d85fc4226c4275b71b3ecbde57a97..4b10bc0f8e607c02763d8ea622d6f8f2572c586d 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -212,7 +212,7 @@ def sliced_wasserstein_distance(real_images, Args: real_images: (tensor) Real images (batch, height, width, channels). fake_images: (tensor) Fake images (batch, height, width, channels). - resolution_min: (int) Minimum resolution for the Laplacion pyramid. + resolution_min: (int) Minimum resolution for the Laplacian pyramid. patches_per_image: (int) Number of patches to extract per image per Laplacian level. patch_size: (int) Width of a square patch. @@ -221,7 +221,7 @@ def sliced_wasserstein_distance(real_images, use_svd: experimental method to compute a more accurate distance. Returns: List of tuples (distance_real, distance_fake) for each level of the - Laplacian pyramid from the highest resoluion to the lowest. + Laplacian pyramid from the highest resolution to the lowest. distance_real is the Wasserstein distance between real images distance_fake is the Wasserstein distance between real and fake images. Raises: diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index cd31c62667fc048b1003d334377405b284f32af5..e2594faf85bcf91cbe09f266e4d4211d20bdee17 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Miscellanous utilities for TFGAN code and examples. +"""Miscellaneous utilities for TFGAN code and examples. Includes: 1) Conditioning the value of a Tensor, based on techniques from diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 4cfae0de4451880cf8229903b0eb74b1c6e2e04d..9e4ec59e7098443efc53506a4ba159e84b5c1618 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -17,7 +17,7 @@ We use this to keep a history of values created by a generator, such that a discriminator can randomly be trained on some older samples, not just the current one. This can help to not let the discriminator get too far ahead of the -generator and also to keep the system from oscilating, if the discriminator +generator and also to keep the system from oscillating, if the discriminator forgets too fast what past samples from the generator looked like. See the following papers for more details. @@ -97,7 +97,7 @@ def tensor_pool(input_values, dtypes=[v.dtype for v in input_values], shapes=None) - # In pseudeo code this code does the following: + # In pseudo code this code does the following: # if not pool_full: # enqueue(input_values) # return input_values diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py index 845f89827b6e60eda41a55a80671f43460247b05..2fe06a287284ff994326d5a977a2e4d4634268ae 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -148,7 +148,7 @@ class VirtualBatchnormTest(test.TestCase): self.assertAllClose(bn_np[i, ...], vb_np) def test_minibatch_independent(self): - """Test that virtual batch normalized exampels are independent. + """Test that virtual batch normalized examples are independent. Unlike batch normalization, virtual batch normalization has the property that the virtual batch normalized value of an example is independent of the diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 7ffdbb7139281734917fdb715601b317eb58b82f..95c02a64d47c26e731ef2628fb551529e9bc3f4d 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -471,9 +471,10 @@ def remove_control_inputs(op, cops): if cop not in op.control_inputs: raise ValueError("{} is not a control_input of {}".format(op.name, cop.name)) + control_inputs = [cop for cop in op.control_inputs if cop not in cops] # pylint: disable=protected-access - op._control_inputs = [cop for cop in op._control_inputs if cop not in cops] - op._recompute_node_def() + op._remove_all_control_inputs() + op._add_control_inputs(control_inputs) # pylint: enable=protected-access @@ -496,9 +497,6 @@ def add_control_inputs(op, cops): if cop in op.control_inputs: raise ValueError("{} is already a control_input of {}".format(cop.name, op.name)) - # pylint: disable=protected-access - op._control_inputs += cops - op._recompute_node_def() - # pylint: enable=protected-access + op._add_control_inputs(cops) # pylint: disable=protected-access remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ca00394388f67e2ed9508684a47b23c3ee9e79e8..2603de640735a612cbd883cc6227fe3cd9f11fca 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -84,9 +85,9 @@ class TransformTest(test.TestCase): def test_transform(self): transformer = ge.Transformer() - def my_transform_op_handler(info, op): + def my_transform_op_handler(info, op, new_inputs): add_noise = op.name.startswith("Add") - op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs) if not add_noise: return op_, op_outputs_ # add some noise to op @@ -201,15 +202,56 @@ class TransformTest(test.TestCase): get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. - self.assertEquals(original_mul1_grad._original_op.name, u"mul1") - self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1") - self.assertNotEquals(res.name, g.name) + self.assertEqual(original_mul1_grad._original_op.name, u"mul1") + self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEqual(res.name, g.name) with session.Session() as sess: sess.run(variables.global_variables_initializer()) g_val, res_val = sess.run([g, res]) self.assertNear(g_val, 0.0, ERROR_TOLERANCE) self.assertNear(res_val, 0.0, ERROR_TOLERANCE) + def test_graph_while_loop(self): + graph = ops.Graph() + with graph.as_default(): + max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple()) + index_start = constant_op.constant(1) + sum_start = constant_op.constant(0) + _, result = control_flow_ops.while_loop( + cond=lambda i, unused_s: i <= max_index, + body=lambda i, s: (i + 1, s + i), + loop_vars=[index_start, sum_start]) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_max_index = copy_info.transformed(max_index) + with copied_graph.as_default(): + with session.Session() as sess: + n = 10 + sum_val = sess.run(copied_result, feed_dict={copied_max_index: n}) + self.assertEqual(sum_val, 55) + + def test_graph_cond(self): + graph = ops.Graph() + with graph.as_default(): + choice = array_ops.placeholder(shape=(), dtype=dtypes.bool) + result = control_flow_ops.cond( + choice, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_choice = copy_info.transformed(choice) + with copied_graph.as_default(): + with session.Session() as sess: + res = sess.run(copied_result, feed_dict={copied_choice: True}) + self.assertEqual(res, 1) + res = sess.run(copied_result, feed_dict={copied_choice: False}) + self.assertEqual(res, 2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 14ac5296657d48c7f9e94d220c9e7e28af4d4353..d8a48387a745e7d88cc6a74c96cb21a2ba1cfa1f 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -129,20 +129,26 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, copy_shape=True): +def copy_op_handler(info, op, new_inputs, copy_shape=True): """Copy a `tf.Operation`. Args: info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. + new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ + # The `new_inputs` was added to this function. For compatibility reason, + # let's raise an error if `new_inputs` is a boolean. + if isinstance(new_inputs, bool): + raise TypeError("the `new_inputs` argument must be an iterable.") + # pylint: disable=protected-access # Clone the node def: - node_def_ = deepcopy(op._node_def) + node_def_ = deepcopy(op.node_def) # Transform name: name_ = info.new_name(op.name) @@ -155,10 +161,10 @@ def copy_op_handler(info, op, copy_shape=True): # Make a copy of the op_def too. # Its unique to every _type_ of Operation. - op_def_ = deepcopy(op._op_def) + op_def_ = deepcopy(op.op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, + op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, [], input_types_, None, op_def_) # copy the shape over @@ -170,6 +176,7 @@ def copy_op_handler(info, op, copy_shape=True): # attribute to exist, we will create a dummy original_op first and then # later finalise it with the actual original_op when all the ops have # been copied. + # TODO(fkp): Stop worrying about _original_op and remove this code? if op._original_op: op_._original_op = op._original_op @@ -328,6 +335,14 @@ class _TmpInfo(object): for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler + # The graph is transformed op by op, in the same order the original ops + # were created. However, this is sometimes not possible due to cycles + # (i.e. while loops). So when the transformer creates a new op whose + # inputs do not exist yet, temporary placeholders are created and stored + # in this `tmp_cyclic_ts` container. During a second pass, + # those temporary tensors are replaced by the proper transformed tensors + # (see the function `_finalize_cycles`). + self.tmp_cyclic_ts = [] def new_name(self, name): """Compute a destination name from a source name. @@ -428,10 +443,10 @@ class Transformer(object): # Create temporary info used during this transform call info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) - info.transform_original_op_handler = self.transform_original_op_handler self._copy_ops(info) - self._connect_ops(info) + self._finalize_cycles(info) + self._connect_control_inputs(info) # Compute information about the transformation res_info = TransformerInfo(info) @@ -440,10 +455,10 @@ class Transformer(object): def _copy_ops(self, info): """Copy ops without connecting them.""" - for op in info.sgv.ops: - logging.debug("Copying op: %s", op.name) - # TODO(fkp): return a subgraph? - op_, op_outputs_ = self.transform_op_handler(info, op) + sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access + for op in sorted_ops: + new_inputs = [self._transformed_t(info, t, op) for t in op.inputs] + op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) if op is op_: raise ValueError("In-place transformation not allowed.") @@ -456,27 +471,36 @@ class Transformer(object): info.transformed_ts[op_output] = op_output_ self.assign_collections_handler(info, op_output, op_output_) - def _connect_ops(self, info): + def _finalize_cycles(self, info): + """Reconnects the cyclic tensors.""" + for t, tmp_t_, consumer_op in info.tmp_cyclic_ts: + if t not in info.transformed_ts: + raise ValueError("The tensor {} should be transformed by now.".format( + t.name)) + if consumer_op not in info.transformed_ops: + raise ValueError("The op {} should be transformed by now.".format( + consumer_op.name)) + t_ = info.transformed_ts[t] + consumer_op_ = info.transformed_ops[consumer_op] + t_index_ = list(consumer_op_.inputs).index(tmp_t_) + consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + + def _connect_control_inputs(self, info): """Connect the previously copied ops.""" for op in info.sgv.ops: - logging.debug("Finalizing op: %s", op.name) + logging.debug("Connecting control inputs of op: %s", op.name) op_ = info.transformed_ops[op] - # pylint: disable=protected-access - if op_.inputs: - raise ValueError("The newly transformed op should not have " - "any inputs yet: {}".format(op_.name)) - inputs_ = [self._transformed_t(info, t) for t in op.inputs] - for t in inputs_: - op_._add_input(t) - # Finalize original op. + # TODO(fkp): Stop worrying about _original_op and remove this code? + # pylint: disable=protected-access if op._original_op: - original_op = info.transform_original_op_handler(info, op._original_op) + original_op = self.transform_original_op_handler(info, op._original_op) if original_op is None: logging.debug("Could not find original op for: %s", op_.name) else: op_._original_op = original_op + # pylint: enable=protected-access # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) @@ -525,19 +549,38 @@ class Transformer(object): return sgv_.remap(input_map_, output_map_) - def _transformed_t(self, info, t): + def _transformed_t(self, info, t, consumer_op): """Return tre transformed tensor of `t`.""" - if t not in info.transformed_ts: - # If op is not in the subgraph. - if t in info.sgv_inputs_set: - # t is an input of the subgraph. - return self.transform_external_input_handler(info, t) + if t in info.transformed_ts: + # If op is in the subgraph, just return its transformed counterpart. + return info.transformed_ts[t] + + if t in info.sgv_inputs_set: + # `t` is an input of the subgraph. + return self.transform_external_input_handler(info, t) + elif t.op in info.ops: + # `t` is an internal tensor but is not transformed yet because it + # belongs to a graph cycle. + logging.debug("Cyclic tensor: t.name = %s", t.name) + # Try to find an existing tensor we can use for now, + # otherwise create one. We'll rewire this later. + if consumer_op.type == "Merge": + first_input = consumer_op.inputs[0] + tmp_t_ = self._transformed_t(info, first_input, consumer_op) + elif t.op.type == "Enter": + enter_input = t.op.inputs[0] + tmp_t_ = self._transformed_t(info, enter_input, consumer_op) else: - # t is a hidden input of the subgraph. - return self.transform_external_hidden_input_handler(info, t) + with info.graph_.as_default(): + tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_, + prefix="geph_tmp") + logging.debug("Created temporary placeholder: %s.", tmp_t_.name) + # Register as temporary and return. + info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op)) + return tmp_t_ else: - # If op is in the subgraph, just return its transformed. - return info.transformed_ts[t] + # `t` is a hidden input of the subgraph. + return self.transform_external_hidden_input_handler(info, t) def copy(sgv, dst_graph=None, dst_scope="", src_scope="", @@ -624,6 +667,40 @@ def copy_with_input_replacements(sgv, replacement_ts, sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) +def _add_control_flow_ops(ops, control_ios): + """Complete `ops` so that the tranformed graph is valid. + + Partially copying a graph can lead to a malformed graph. For instance, + copying half of a while construct is likely to result in an invalid graph. + This function attempts to add missing ops so that the transformation result + in a valid graph. + + Args: + ops: list of ops (modifed in-place). + control_ios: object created by a call to `util.ControlOutputs`. + """ + # Find while contexts. + control_flow_contexts = set() + for op in ops: + cfc = op._control_flow_context # pylint: disable=protected-access + if cfc: + control_flow_contexts.add(cfc) + # Find new ops. + new_ops = [] + for cfc in control_flow_contexts: + if cfc.IsWhileContext(): + new_ops += select.get_walks_intersection_ops( + [enter_t.op for enter_t in cfc.loop_enters], + [exit_t.op for exit_t in cfc.loop_exits], + control_ios=control_ios) + # Add new ops. + new_ops_set = set(new_ops) + ops_set = frozenset(ops) + for op in new_ops_set: + if op not in ops_set: + ops.append(op) + + def graph_replace(target_ts, replacement_ts, dst_scope="", src_scope="", reuse_dst_scope=False): """Create a new graph which compute the targets from the replaced Tensors. @@ -657,8 +734,13 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") + + # Complete ops to avoid malformed control flow. + # TODO(fkp): Consider moving this function deeper (in the transformer?). + _add_control_flow_ops(ops, control_ios) + # Create a copy of the relevant subgraph - _, info = copy_with_input_replacements( + unused_sgv_, info = copy_with_input_replacements( ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) # Return the transformed targets but keep the original if the transformed # counterpart cannot be found diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 30bc33b9ee42ba78bc7307c67c0fc0af9f3356ef..584f4509ccc0aab30edc2be3bad7a9cb938d6e6a 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -38,6 +38,11 @@ __all__ = [ ] +# The graph editor sometimes need to create placeholders, they are named +# "geph_*". "geph" stands for Graph-Editor PlaceHolder. +_DEFAULT_PLACEHOLDER_PREFIX = "geph" + + def concatenate_unique(la, lb): """Add all the elements of `lb` to `la` if they are not there already. @@ -405,7 +410,7 @@ def scope_basename(scope): return scope[slash + 1:] -def placeholder_name(t=None, scope=None): +def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create placeholder name for the graph editor. Args: @@ -413,6 +418,7 @@ def placeholder_name(t=None, scope=None): on scope: absolute scope with which to prefix the placeholder's name. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A new placeholder name prefixed by "geph". Note that "geph" stands for Graph Editor PlaceHolder. This convention allows to quickly identify the @@ -430,19 +436,20 @@ def placeholder_name(t=None, scope=None): if scope is None: scope = op_dirname - if op_basename.startswith("geph__"): + if op_basename.startswith("{}__".format(prefix)): ph_name = op_basename else: - ph_name = "geph__{}_{}".format(op_basename, t.value_index) + ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) return scope + ph_name else: if scope is None: scope = "" - return scope + "geph" + return "{}{}".format(scope, prefix) -def make_placeholder_from_tensor(t, scope=None): +def make_placeholder_from_tensor(t, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a `tf.placeholder` for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -452,17 +459,19 @@ def make_placeholder_from_tensor(t, scope=None): (see function placeholder_name). scope: absolute scope within which to create the placeholder. None means that the scope of `t` is preserved. `""` means the root scope. + prefix: placeholder name prefix. Returns: A newly created `tf.placeholder`. Raises: TypeError: if `t` is not `None` or a `tf.Tensor`. """ return tf_array_ops.placeholder( - dtype=t.dtype, shape=t.get_shape(), name=placeholder_name( - t, scope=scope)) + dtype=t.dtype, shape=t.get_shape(), + name=placeholder_name(t, scope=scope, prefix=prefix)) -def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): +def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a tf.placeholder for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -474,11 +483,13 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): shape: the tensor shape (optional). scope: absolute scope within which to create the placeholder. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A newly created tf.placeholder. """ return tf_array_ops.placeholder( - dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) + dtype=dtype, shape=shape, + name=placeholder_name(scope=scope, prefix=prefix)) _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 252788140f8c1906718c150574b963385b6ecfa1..bcd2a34c4e791a2ab66a439109145d6b78c14e22 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -110,7 +110,7 @@ class GridRNNCell(rnn.RNNCell): logging.warning('%s: Using a concatenated state is slower and will ' 'soon be deprecated. Use state_is_tuple=True.', self) if not output_is_tuple: - logging.warning('%s: Using a concatenated output is slower and will' + logging.warning('%s: Using a concatenated output is slower and will ' 'soon be deprecated. Use output_is_tuple=True.', self) if num_dims < 1: diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 3ff02e085ee63fabf42b3cc4389f4605455f3800..79eb3762edbc17e5c4682ac42dff87ae423bddfe 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -78,7 +78,10 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + ":dense_image_warp_py", ":image_ops", + ":interpolate_spline_py", + ":sparse_image_warp_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", @@ -194,6 +197,117 @@ cuda_py_test( ], ) +py_library( + name = "dense_image_warp_py", + srcs = [ + "python/ops/dense_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_library( + name = "interpolate_spline_py", + srcs = [ + "python/ops/interpolate_spline.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_library( + name = "sparse_image_warp_py", + srcs = [ + "python/ops/sparse_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dense_image_warp_py", + ":interpolate_spline_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "sparse_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/sparse_image_warp_test.py"], + additional_deps = [ + ":sparse_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], + data = [":sparse_image_warp_test_data"], + tags = ["no_pip"], +) + +filegroup( + name = "sparse_image_warp_test_data", + srcs = glob(["python/kernel_tests/test_data/*.png"]), +) + +cuda_py_test( + name = "dense_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/dense_image_warp_test.py"], + additional_deps = [ + ":dense_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + +cuda_py_test( + name = "interpolate_spline_test", + size = "medium", + srcs = ["python/kernel_tests/interpolate_spline_test.py"], + additional_deps = [ + ":interpolate_spline_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + tf_py_test( name = "segmentation_test", size = "medium", diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index cc8ed117ba2edcc7a53e609381166f17a2fbb45e..e982030bc8959309e72d0f4e02b9755c48535a10 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -30,6 +30,9 @@ projective transforms (including rotation) are supported. @@transform @@translate @@translations_to_projective_transforms +@@dense_image_warp +@@interpolate_spline +@@sparse_image_warp ## Image Segmentation `Ops` @@ -47,6 +50,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.python.ops.dense_image_warp import dense_image_warp + from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq @@ -57,7 +62,9 @@ from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms +from tensorflow.contrib.image.python.ops.interpolate_spline import interpolate_spline from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms +from tensorflow.contrib.image.python.ops.sparse_image_warp import sparse_image_warp from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc index fe8bf6e21c7b7310527668324571774e8bc50893..93722896233f0278c6cbb44af7203345e58c3172 100644 --- a/tensorflow/contrib/image/kernels/segmentation_ops.cc +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -101,8 +101,8 @@ struct ImageConnectedComponentsFunctor { int cost = (union_find.block_height() + union_find.block_width()) * 20; Shard(worker_threads->num_threads, worker_threads->workers, num_images * num_blocks_vertically * num_blocks_horizontally, cost, - [&union_find, num_images, num_blocks_vertically, - num_blocks_horizontally](int64 start_block, int64 limit_block) { + [&union_find, num_blocks_vertically, num_blocks_horizontally]( + int64 start_block, int64 limit_block) { for (int64 i = start_block; i < limit_block; i++) { int64 block_x = i % num_blocks_horizontally; int64 block_y = diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a58b6a247ed6ae252db25a12f1e47c08c9a5c147 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py @@ -0,0 +1,267 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import adam + + +class DenseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def test_interpolate_small_grid_ij(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_xy(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear( + grid, query_points, indexing='xy') + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_batched(self): + grid = constant_op.constant( + [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1]) + query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]], + [[0.5, 0.], [1., 0.], [1., 1.]]]) + expected_results = np.reshape( + np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def get_image_and_flow_placeholders(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + flow_shape = [batch_size, height, width, 2] + + tf_type = { + 'float16': dtypes.half, + 'float32': dtypes.float32, + 'float64': dtypes.float64 + } + + image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape) + + flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape) + return image, flows + + def get_random_image_and_flows(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + image = np.random.normal(size=image_shape) + flow_shape = [batch_size, height, width, 2] + flows = np.random.normal(size=flow_shape) * 3 + return image.astype(image_type), flows.astype(flow_type) + + def assert_correct_interpolation_value(self, + image, + flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=False): + """Assert that the tf interpolation matches hand-computed value.""" + + height = image.shape[1] + width = image.shape[2] + displacement = flows[batch_index, y_index, x_index, :] + float_y = y_index - displacement[0] + float_x = x_index - displacement[1] + floor_y = max(min(height - 2, math.floor(float_y)), 0) + floor_x = max(min(width - 2, math.floor(float_x)), 0) + ceil_y = floor_y + 1 + ceil_x = floor_x + 1 + + alpha_y = min(max(0.0, float_y - floor_y), 1.0) + alpha_x = min(max(0.0, float_x - floor_x), 1.0) + + floor_y = int(floor_y) + floor_x = int(floor_x) + ceil_y = int(ceil_y) + ceil_x = int(ceil_x) + + top_left = image[batch_index, floor_y, floor_x, :] + top_right = image[batch_index, floor_y, ceil_x, :] + bottom_left = image[batch_index, ceil_y, floor_x, :] + bottom_right = image[batch_index, ceil_y, ceil_x, :] + + interp_top = alpha_x * (top_right - top_left) + top_left + interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left + interp = alpha_y * (interp_bottom - interp_top) + interp_top + atol = 1e-6 + rtol = 1e-6 + if low_precision: + atol = 1e-2 + rtol = 1e-3 + self.assertAllClose( + interp, + pred_interpolation[batch_index, y_index, x_index, :], + atol=atol, + rtol=rtol) + + def check_zero_flow_correctness(self, shape, image_type, flow_type): + """Assert using zero flows doesn't change the input image.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + rand_flows *= 0 + + predicted_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + self.assertAllClose(rand_image, predicted_interpolation) + + def test_zero_flows(self): + """Apply check_zero_flow_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]] + for shape in shapes_to_try: + self.check_zero_flow_correctness( + shape, image_type='float32', flow_type='float32') + + def check_interpolation_correctness(self, + shape, + image_type, + flow_type, + num_probes=5): + """Interpolate, and then assert correctness for a few query locations.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + low_precision = image_type == 'float16' or flow_type == 'float16' + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + + pred_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + + for _ in range(num_probes): + batch_index = np.random.randint(0, shape[0]) + y_index = np.random.randint(0, shape[1]) + x_index = np.random.randint(0, shape[2]) + + self.assert_correct_interpolation_value( + rand_image, + rand_flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=low_precision) + + def test_interpolation(self): + """Apply check_interpolation_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]] + for im_type in ['float32', 'float64', 'float16']: + for flow_type in ['float32', 'float64', 'float16']: + for shape in shapes_to_try: + self.check_interpolation_correctness(shape, im_type, flow_type) + + def test_gradients_exist(self): + """Check that backprop can run. + + The correctness of the gradients is assumed, since the forward propagation + is tested to be correct and we only use built-in tf ops. + However, we perform a simple test to make sure that backprop can actually + run. We treat the flows as a tf.Variable and optimize them to minimize + the difference between the interpolated image and the input image. + """ + + batch_size, height, width, numchannels = [4, 5, 6, 7] + image_shape = [batch_size, height, width, numchannels] + image = random_ops.random_normal(image_shape) + flow_shape = [batch_size, height, width, 2] + init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25) + flows = variables.Variable(init_flows) + + interp = dense_image_warp.dense_image_warp(image, flows) + loss = math_ops.reduce_mean(math_ops.square(interp - image)) + + optimizer = adam.AdamOptimizer(1.0) + grad = gradients.gradients(loss, [flows]) + opt_func = optimizer.apply_gradients(zip(grad, [flows])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(10): + sess.run(opt_func) + + def test_size_exception(self): + """Make sure it throws an exception for images that are too small.""" + + shape = [1, 2, 1, 1] + msg = 'Should have raised an exception for invalid image size' + with self.assertRaises(ValueError, msg=msg): + self.check_interpolation_correctness(shape, 'float32', 'float32') + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1939caaa2d8586413cf9ecba6ce73cf64910d6fc --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for interpolate_spline.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import interpolate as sc_interpolate + +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util + +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import momentum + + +class _InterpolationProblem(object): + """Abstract class for interpolation problem descriptions.""" + + def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'): + """Make data for an interpolation problem where all x vectors are n-d. + + Args: + optimizable: If True, then make train_points a tf.Variable. + extrapolate: If False, then clamp the query_points values to be within + the max and min of train_points. + dtype: The data type to use. + + Returns: + query_points, query_values, train_points, train_values: training and + test tensors for interpolation problem + """ + + # The values generated here depend on a seed of 0. + np.random.seed(0) + + batch_size = 1 + num_training_points = 10 + num_query_points = 4 + + init_points = np.random.uniform( + size=[batch_size, num_training_points, self.DATA_DIM]) + + init_points = init_points.astype(dtype) + train_points = ( + variables.Variable(init_points) + if optimizable else constant_op.constant(init_points)) + train_values = self.tf_function(train_points) + + query_points_np = np.random.uniform( + size=[batch_size, num_query_points, self.DATA_DIM]) + query_points_np = query_points_np.astype(dtype) + if not extrapolate: + query_points_np = np.clip(query_points_np, np.min(init_points), + np.max(init_points)) + + query_points = constant_op.constant(query_points_np) + query_values = self.np_function(query_points_np) + + return query_points, query_values, train_points, train_values + + +class _QuadraticPlusSinProblem1D(_InterpolationProblem): + """1D interpolation problem used for regression testing.""" + DATA_DIM = 1 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387], + (1.0, + 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693], + (2.0, + 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059], + (2.0, + 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741], + (3.0, + 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122], + (3.0, + 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857], + (4.0, + 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256], + (4.0, 0.01): [ + -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725 + ] + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_mean( + math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10), + 2, + keepdims=True) + + +class _QuadraticPlusSinProblemND(_InterpolationProblem): + """3D interpolation problem used for regression testing.""" + + DATA_DIM = 3 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885], + (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569], + (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215], + (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312], + (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491], + (3.0, + 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866], + (4.0, + 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833], + (4.0, + 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382], + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_sum( + math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15), + 2, + keepdims=True) + + +class InterpolateSplineTest(test_util.TensorFlowTestCase): + + def test_1d_linear_interpolation(self): + """For 1d linear interpolation, we can compare directly to scipy.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, train_values) = tp.get_problem( + extrapolate=False, dtype='float64') + interpolation_order = 1 + + with ops.name_scope('interpolator'): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order) + with self.test_session() as sess: + fetches = [query_points, train_points, train_values, interpolator] + query_points_, train_points_, train_values_, interp_ = sess.run(fetches) + + # Just look at the first element of the minibatch. + # Also, trim the final singleton dimension. + interp_ = interp_[0, :, 0] + query_points_ = query_points_[0, :, 0] + train_points_ = train_points_[0, :, 0] + train_values_ = train_values_[0, :, 0] + + # Compute scipy interpolation. + scipy_interp_function = sc_interpolate.interp1d( + train_points_, train_values_, kind='linear') + + scipy_interpolation = scipy_interp_function(query_points_) + scipy_interpolation_on_train = scipy_interp_function(train_points_) + + # Even with float64 precision, the interpolants disagree with scipy a + # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc. + tol = 1e-3 + + self.assertAllClose( + train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol) + self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol) + + def test_1d_interpolation(self): + """Regression test for interpolation with 1-D points.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_nd_linear_interpolation(self): + """Regression test for interpolation with N-D points.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_interpolation_gradient(self): + """Make sure that backprop can run. Correctness of gradients is assumed. + + Here, we create a use a small 'training' set and a more densely-sampled + set of query points, for which we know the true value in advance. The goal + is to choose x locations for the training data such that interpolating using + this training data yields the best reconstruction for the function + values at the query points. The training data locations are optimized + iteratively using gradient descent. + """ + tp = _QuadraticPlusSinProblemND() + (query_points, query_values, train_points, + train_values) = tp.get_problem(optimizable=True) + + regularization = 0.001 + for interpolation_order in (1, 2, 3, 4): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order, + regularization) + + loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator)) + + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [train_points]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [train_points])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(100): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0135c66e293693345c3da7fdb21e28ca6d160154 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py @@ -0,0 +1,254 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sparse_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import sparse_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import image_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + +from tensorflow.python.training import momentum + + +class SparseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def testGetBoundaryLocations(self): + image_height = 11 + image_width = 11 + num_points_per_edge = 4 + locs = sparse_image_warp._get_boundary_locations(image_height, image_width, + num_points_per_edge) + num_points = locs.shape[0] + self.assertEqual(num_points, 4 + 4 * num_points_per_edge) + locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)] + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (2, 4, 6, 8): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (0, image_height - 1): + for j in (2, 4, 6, 8): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + def testGetGridLocations(self): + image_height = 5 + image_width = 3 + grid = sparse_image_warp._get_grid_locations(image_height, image_width) + for i in range(image_height): + for j in range(image_width): + self.assertEqual(grid[i, j, 0], i) + self.assertEqual(grid[i, j, 1], j) + + def testZeroShift(self): + """Run assertZeroShift for various hyperparameters.""" + for order in (1, 2): + for regularization in (0, 0.01): + for num_boundary_points in (0, 1): + self.assertZeroShift(order, regularization, num_boundary_points) + + def assertZeroShift(self, order, regularization, num_boundary_points): + """Check that warping with zero displacements doesn't change the image.""" + batch_size = 1 + image_height = 4 + image_width = 4 + channels = 3 + + image = np.random.uniform( + size=[batch_size, image_height, image_width, channels]) + + input_image_op = constant_op.constant(np.float32(image)) + + control_point_locations = [[1., 1.], [2., 2.], [2., 1.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + + control_point_displacements = np.zeros( + control_point_locations.shape.as_list()) + control_point_displacements = constant_op.constant( + np.float32(control_point_displacements)) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + regularization_weight=regularization, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, _ = sess.run( + [warped_image_op, input_image_op, flow_field]) + + self.assertAllClose(warped_image, input_image) + + def testMoveSinglePixel(self): + """Run assertMoveSinglePixel for various hyperparameters and data types.""" + for order in (1, 2): + for num_boundary_points in (1, 2): + for type_to_use in (dtypes.float32, dtypes.float64): + self.assertMoveSinglePixel(order, num_boundary_points, type_to_use) + + def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use): + """Move a single block in a small grid using warping.""" + batch_size = 1 + image_height = 7 + image_width = 7 + channels = 3 + + image = np.zeros([batch_size, image_height, image_width, channels]) + image[:, 3, 3, :] = 1.0 + input_image_op = constant_op.constant(image, dtype=type_to_use) + + # Place a control point at the one white pixel. + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0)), + dtype=type_to_use) + # Shift it one pixel to the right. + control_point_displacements = [[0., 1.0]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0)), + dtype=type_to_use) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, flow = sess.run( + [warped_image_op, input_image_op, flow_field]) + # Check that it moved the pixel correctly. + self.assertAllClose( + warped_image[0, 4, 5, :], + input_image[0, 4, 4, :], + atol=1e-5, + rtol=1e-5) + + # Test that there is no flow at the corners. + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertAllClose( + flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5) + + def load_image(self, image_file, sess): + image_op = image_ops.decode_png( + io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3] + return sess.run(image_op) + + def testSmileyFace(self): + """Check warping accuracy by comparing to hardcoded warped images.""" + + test_data_dir = test.test_src_dir_path('contrib/image/python/' + 'kernel_tests/test_data/') + input_file = test_data_dir + 'Yellow_Smiley_Face.png' + with self.test_session() as sess: + input_image = self.load_image(input_file, sess) + control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111], + [180 - 39, 111], [90, 143], [58, 134], + [180 - 58, 134]]) # pyformat: disable + control_point_displacements = np.asarray( + [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25], + [10, 10.75]]) + control_points_op = constant_op.constant( + np.expand_dims(np.float32(control_points[:, [1, 0]]), 0)) + control_point_displacements_op = constant_op.constant( + np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0)) + float_image = np.expand_dims(np.float32(input_image) / 255, 0) + input_image_op = constant_op.constant(float_image) + + for interpolation_order in (1, 2, 3): + for num_boundary_points in (0, 1, 4): + warp_op, _ = sparse_image_warp.sparse_image_warp( + input_image_op, + control_points_op, + control_points_op + control_point_displacements_op, + interpolation_order=interpolation_order, + num_boundary_points=num_boundary_points) + with self.test_session() as sess: + warped_image = sess.run(warp_op) + out_image = np.uint8(warped_image[0, :, :, :] * 255) + target_file = ( + test_data_dir + + 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format( + interpolation_order, num_boundary_points)) + + target_image = self.load_image(target_file, sess) + + # Check that the target_image and out_image difference is no + # bigger than 2 (on a scale of 0-255). Due to differences in + # floating point computation on different devices, the float + # output in warped_image may get rounded to a different int + # than that in the saved png file loaded into target_image. + self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3) + + def testThatBackpropRuns(self): + """Run optimization to ensure that gradients can be computed.""" + + batch_size = 1 + image_height = 9 + image_width = 12 + image = variables.Variable( + np.float32( + np.random.uniform(size=[batch_size, image_height, image_width, 3]))) + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + control_point_displacements = [[0.25, -0.5]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0))) + warped_image, _ = sparse_image_warp.sparse_image_warp( + image, + control_point_locations, + control_point_locations + control_point_displacements, + num_boundary_points=3) + + loss = math_ops.reduce_mean(math_ops.abs(warped_image - image)) + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [image]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [image])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png new file mode 100644 index 0000000000000000000000000000000000000000..7e303881e213a82e412d18de9d9d86f368726f06 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..7fd9e4e6d69f3120428d1d778846d495cea1a989 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..86d225e5d2158804f88dca881f69ed3ab287d866 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..37e8ffae114625d0cc6a07ab2b8dbbb7413a3829 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..e49b5816120d43a669264915f1b6747606e080e0 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..df3cf2004312ed0ed0ebf1f0340cbfec7fd9ac46 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..e1799a87c8542d7e515b6185d7e8f6f75fe73f3e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..2c346e0ce5487e21d41aa4e6306fd83a7b4ffdb4 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..6f8b65451cc08a463e4305ddc4be0dbe2879fae9 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..8e78146d955ae8f02230121e6314f3285e87611e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b219ada492466919c615d8978e462e6c619d33 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using per-pixel flow vectors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _interpolate_bilinear(grid, + query_points, + name='interpolate_bilinear', + indexing='ij'): + """Similar to Matlab's interp2 function. + + Finds values for query points on a grid using bilinear interpolation. + + Args: + grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. + query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. + name: a name for the operation (optional). + indexing: whether the query points are specified as row and column (ij), + or Cartesian coordinates (xy). + + Returns: + values: a 3-D `Tensor` with shape `[batch, N, channels]` + + Raises: + ValueError: if the indexing mode is invalid, or if the shape of the inputs + invalid. + """ + if indexing != 'ij' and indexing != 'xy': + raise ValueError('Indexing mode must be \'ij\' or \'xy\'') + + with ops.name_scope(name): + grid = ops.convert_to_tensor(grid) + query_points = ops.convert_to_tensor(query_points) + shape = grid.get_shape().as_list() + if len(shape) != 4: + msg = 'Grid must be 4 dimensional. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + batch_size, height, width, channels = shape + query_type = query_points.dtype + grid_type = grid.dtype + + if (len(query_points.get_shape()) != 3 or + query_points.get_shape()[2].value != 2): + msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received ' + 'size: ') + raise ValueError(msg + str(query_points.get_shape())) + + _, num_queries, _ = query_points.get_shape().as_list() + + if height < 2 or width < 2: + msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + alphas = [] + floors = [] + ceils = [] + + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) + + for dim in index_order: + with ops.name_scope('dim-' + str(dim)): + queries = unstacked_query_points[dim] + + size_in_indexing_dimension = shape[dim + 1] + + # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 + # is still a valid index into the grid. + max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type) + min_floor = constant_op.constant(0.0, dtype=query_type) + floor = math_ops.minimum( + math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor) + int_floor = math_ops.cast(floor, dtypes.int32) + floors.append(int_floor) + ceil = int_floor + 1 + ceils.append(ceil) + + # alpha has the same type as the grid, as we will directly use alpha + # when taking linear combinations of pixel values from the image. + alpha = math_ops.cast(queries - floor, grid_type) + min_alpha = constant_op.constant(0.0, dtype=grid_type) + max_alpha = constant_op.constant(1.0, dtype=grid_type) + alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha) + + # Expand alpha to [b, n, 1] so we can use broadcasting + # (since the alpha values don't depend on the channel). + alpha = array_ops.expand_dims(alpha, 2) + alphas.append(alpha) + + if batch_size * height * width > np.iinfo(np.int32).max / 8: + error_msg = """The image size or batch size is sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""" + raise ValueError(error_msg) + + flattened_grid = array_ops.reshape(grid, + [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) + + # This wraps array_ops.gather. We reshape the image data such that the + # batch, y, and x coordinates are pulled into the first dimension. + # Then we gather. Finally, we reshape the output back. It's possible this + # code would be made simpler by using array_ops.gather_nd. + def gather(y_coords, x_coords, name): + with ops.name_scope('gather-' + name): + linear_coordinates = batch_offsets + y_coords * width + x_coords + gathered_values = array_ops.gather(flattened_grid, linear_coordinates) + return array_ops.reshape(gathered_values, + [batch_size, num_queries, channels]) + + # grab the pixel values in the 4 corners around each query point + top_left = gather(floors[0], floors[1], 'top_left') + top_right = gather(floors[0], ceils[1], 'top_right') + bottom_left = gather(ceils[0], floors[1], 'bottom_left') + bottom_right = gather(ceils[0], ceils[1], 'bottom_right') + + # now, do the actual interpolation + with ops.name_scope('interpolate'): + interp_top = alphas[1] * (top_right - top_left) + top_left + interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left + interp = alphas[0] * (interp_bottom - interp_top) + interp_top + + return interp + + +def dense_image_warp(image, flow, name='dense_image_warp'): + """Image warping using per-pixel flow vectors. + + Apply a non-linear warp to the image, where the warp is specified by a dense + flow field of offset vectors that define the correspondences of pixel values + in the output image back to locations in the source image. Specifically, the + pixel value at output[b, j, i, c] is + images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. + + The locations specified by this formula do not necessarily map to an int + index. Therefore, the pixel value is obtained by bilinear + interpolation of the 4 nearest pixels around + (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside + of the image, we use the nearest pixel values at the image boundary. + + + Args: + image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. + flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. + name: A name for the operation (optional). + + Note that image and flow can be of type tf.half, tf.float32, or tf.float64, + and do not necessarily have to be the same type. + + Returns: + A 4-D float `Tensor` with shape`[batch, height, width, channels]` + and same type as input image. + + Raises: + ValueError: if height < 2 or width < 2 or the inputs have the wrong number + of dimensions. + """ + with ops.name_scope(name): + batch_size, height, width, channels = image.get_shape().as_list() + # The flow is defined on the image grid. Turn the flow into a list of query + # points in the grid space. + grid_x, grid_y = array_ops.meshgrid( + math_ops.range(width), math_ops.range(height)) + stacked_grid = math_ops.cast( + array_ops.stack([grid_y, grid_x], axis=2), flow.dtype) + batched_grid = array_ops.expand_dims(stacked_grid, axis=0) + query_points_on_grid = batched_grid - flow + query_points_flattened = array_ops.reshape(query_points_on_grid, + [batch_size, height * width, 2]) + # Compute values at the query points, then reshape the result back to the + # image grid. + interpolated = _interpolate_bilinear(image, query_points_flattened) + interpolated = array_ops.reshape(interpolated, + [batch_size, height, width, channels]) + return interpolated diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py new file mode 100644 index 0000000000000000000000000000000000000000..daf8c56456327f102f1409296a91f9f7b68ec799 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py @@ -0,0 +1,291 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Polyharmonic spline interpolation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + +EPSILON = 0.0000000001 + + +def _cross_squared_distance_matrix(x, y): + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). + + Computes the pairwise distances between rows of x and rows of y + Args: + x: [batch_size, n, d] float `Tensor` + y: [batch_size, m, d] float `Tensor` + + Returns: + squared_dists: [batch_size, n, m] float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 + """ + x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2) + y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2) + + # Expand so that we can broadcast. + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1) + + x_y_transpose = math_ops.matmul(x, y, adjoint_b=True) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile + + return squared_dists + + +def _pairwise_squared_distance_matrix(x): + """Pairwise squared distance among a (batch) matrix's rows (2nd dim). + + This saves a bit of computation vs. using _cross_squared_distance_matrix(x,x) + + Args: + x: `[batch_size, n, d]` float `Tensor` + + Returns: + squared_dists: `[batch_size, n, n]` float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2 + """ + + x_x_transpose = math_ops.matmul(x, x, adjoint_b=True) + x_norm_squared = array_ops.matrix_diag_part(x_x_transpose) + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + + # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose( + x_norm_squared_tile, [0, 2, 1]) + + return squared_dists + + +def _solve_interpolation(train_points, train_values, order, + regularization_weight): + """Solve for interpolation coefficients. + + Computes the coefficients of the polyharmonic interpolant for the 'training' + data defined by (train_points, train_values) using the kernel phi. + + Args: + train_points: `[b, n, d]` interpolation centers + train_values: `[b, n, k]` function values + order: order of the interpolation + regularization_weight: weight to place on smoothness regularization term + + Returns: + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + """ + + b, n, d = train_points.get_shape().as_list() + _, _, k = train_values.get_shape().as_list() + + # First, rename variables so that the notation (c, f, w, v, A, B, etc.) + # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. + # To account for python style guidelines we use + # matrix_a for A and matrix_b for B. + + c = train_points + f = train_values + + # Next, construct the linear system. + with ops.name_scope('construct_linear_system'): + + matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] + if regularization_weight > 0: + batch_identity_matrix = np.expand_dims(np.eye(n), 0) + batch_identity_matrix = constant_op.constant( + batch_identity_matrix, dtype=train_points.dtype) + + matrix_a += regularization_weight * batch_identity_matrix + + # Append ones to the feature values for the bias term in the linear model. + ones = array_ops.ones([b, n, 1], train_points.dtype) + matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = array_ops.concat( + [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1) + + num_b_cols = matrix_b.get_shape()[2] # d + 1 + lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype) + right_block = array_ops.concat([matrix_b, lhs_zeros], + 1) # [b, n + d + 1, d + 1] + lhs = array_ops.concat([left_block, right_block], + 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype) + rhs = array_ops.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + with ops.name_scope('solve_linear_system'): + w_v = linalg_ops.matrix_solve(lhs, rhs) + w = w_v[:, :n, :] + v = w_v[:, n:, :] + + return w, v + + +def _apply_interpolation(query_points, train_points, w, v, order): + """Apply polyharmonic interpolation model to data. + + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at + train_points: `[b, n, d]` x values that act as the interpolation centers + ( the c variables in the wikipedia article) + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + order: order of the interpolation + + Returns: + Polyharmonic interpolation evaluated at points defined in query_points. + """ + + batch_size = train_points.get_shape()[0].value + num_query_points = query_points.get_shape()[1].value + + # First, compute the contribution from the rbf term. + pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) + phi_pairwise_dists = _phi(pairwise_dists, order) + + rbf_term = math_ops.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + query_points_pad = array_ops.concat([ + query_points, + array_ops.ones([batch_size, num_query_points, 1], train_points.dtype) + ], 2) + linear_term = math_ops.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def _phi(r, order): + """Coordinate-wise nonlinearity used to define the order of the interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + + Args: + r: input op + order: interpolation order + + Returns: + phi_k evaluated coordinate-wise on r, for k = r + """ + + # using EPSILON prevents log(0), sqrt0), etc. + # sqrt(0) is well-defined, but its gradient is not + with ops.name_scope('phi'): + if order == 1: + r = math_ops.maximum(r, EPSILON) + r = math_ops.sqrt(r) + return r + elif order == 2: + return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON)) + elif order == 4: + return 0.5 * math_ops.square(r) * math_ops.log( + math_ops.maximum(r, EPSILON)) + elif order % 2 == 0: + r = math_ops.maximum(r, EPSILON) + return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r) + else: + r = math_ops.maximum(r, EPSILON) + return math_ops.pow(r, 0.5 * order) + + +def interpolate_spline(train_points, + train_values, + query_points, + order, + regularization_weight=0.0, + name='interpolate_spline'): + r"""Interpolate signal using polyharmonic interpolation. + + The interpolant has the form + $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ + + This is a sum of two terms: (1) a weighted sum of radial basis function (RBF) + terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias. + The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v + by appending 1 as a final dimension to x. The coefficients w and v are + estimated such that the interpolant exactly fits the value of the function at + the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the + vector w sums to 0. With these constraints, the coefficients can be obtained + by solving a linear system. + + \\(\phi\\) is an RBF, parametrized by an interpolation + order. Using order=2 produces the well-known thin-plate spline. + + We also provide the option to perform regularized interpolation. Here, the + interpolant is selected to trade off between the squared loss on the training + data and a certain measure of its curvature + ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). + Using a regularization weight greater than zero has the effect that the + interpolant will no longer exactly fit the training data. However, it may be + less vulnerable to overfitting, particularly for high-order interpolation. + + Note the interpolation procedure is differentiable with respect to all inputs + besides the order parameter. + + Args: + train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional + locations. These do not need to be regularly-spaced. + train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional values + evaluated at train_points. + query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations + where we will output the interpolant's values. + order: order of the interpolation. Common values are 1 for + \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline), + or 3 for \\(\phi(r) = r^3\\). + regularization_weight: weight placed on the regularization term. + This will depend substantially on the problem, and it should always be + tuned. For many problems, it is reasonable to use no regularization. + If using a non-zero value, we recommend a small value like 0.001. + name: name prefix for ops created by this function + + Returns: + `[b, m, k]` float `Tensor` of query values. We use train_points and + train_values to perform polyharmonic interpolation. The query values are + the values of the interpolant evaluated at the locations specified in + query_points. + """ + with ops.name_scope(name): + train_points = ops.convert_to_tensor(train_points) + train_values = ops.convert_to_tensor(train_values) + query_points = ops.convert_to_tensor(query_points) + + # First, fit the spline to the observed data. + with ops.name_scope('solve'): + w, v = _solve_interpolation(train_points, train_values, order, + regularization_weight) + + # Then, evaluate the spline at the query locations. + with ops.name_scope('predict'): + query_values = _apply_interpolation(query_points, train_points, w, v, + order) + + return query_values diff --git a/tensorflow/contrib/image/python/ops/sparse_image_warp.py b/tensorflow/contrib/image/python/ops/sparse_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..54a215d6db6ded56a1a4a018a7e176f35fe6397e --- /dev/null +++ b/tensorflow/contrib/image/python/ops/sparse_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using sparse flow defined at control points.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +def _get_grid_locations(image_height, image_width): + """Wrapper for np.meshgrid.""" + + y_range = np.linspace(0, image_height - 1, image_height) + x_range = np.linspace(0, image_width - 1, image_width) + y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') + return np.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(np_array, batch_size): + """Tile arbitrarily-sized np_array to include new batch dimension.""" + tiles = [batch_size] + [1] * np_array.ndim + return np.tile(np.expand_dims(np_array, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) + x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) + ys, xs = np.meshgrid(y_range, x_range, indexing='ij') + is_boundary = np.logical_or( + np.logical_or(xs == 0, xs == image_width - 1), + np.logical_or(ys == 0, ys == image_height - 1)) + return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). + + Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = control_point_locations.get_shape()[0].value + + boundary_point_locations = _get_boundary_locations(image_height, image_width, + boundary_points_per_edge) + + boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) + + type_to_use = control_point_locations.dtype + boundary_point_locations = constant_op.constant( + _expand_to_minibatch(boundary_point_locations, batch_size), + dtype=type_to_use) + + boundary_point_flows = constant_op.constant( + _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use) + + merged_control_point_locations = array_ops.concat( + [control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = array_ops.concat( + [control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (@{tf.contrib.image.interpolate_spline}) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (@{tf.contrib.image.dense_image_warp}). + + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See @{tf.contrib.image.interpolate_spline} for further documentation of the + interpolation_order and regularization_weight arguments. + + + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense + flow field produced by the interpolation. + """ + + image = ops.convert_to_tensor(image) + source_control_point_locations = ops.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = ops.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with ops.name_scope(name): + + batch_size, image_height, image_width, _ = image.get_shape().as_list() + + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = np.reshape(grid_locations, + [image_height * image_width, 2]) + + flattened_grid_locations = constant_op.constant( + _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary( + dest_control_point_locations, control_point_flows, image_height, + image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline( + dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, regularization_weight) + + dense_flows = array_ops.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp.dense_image_warp(image, dense_flows) + + return warped_image, dense_flows diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD index efb403462a6e5df5b69ac0735ffc03f40d4a252c..1c3974871c62911c0cb47677eb92d28286837142 100644 --- a/tensorflow/contrib/kafka/BUILD +++ b/tensorflow/contrib/kafka/BUILD @@ -1,66 +1,93 @@ -package( - default_visibility = ["//visibility:private"], -) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_py_test", +) -tf_kernel_library( - name = "kafka_kernels", +py_library( + name = "kafka", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", srcs = ["kernels/kafka_dataset_ops.cc"], - visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels:bounds_check_lib", - "//tensorflow/core/kernels:dataset", + "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@kafka", + "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) -tf_gen_op_libs( - op_lib_names = ["kafka_ops"], +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/kafka_dataset_ops.py", + ], + srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:lib", + ":kafka_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", ], ) tf_gen_op_wrapper_py( - name = "gen_kafka_ops", - out = "python/ops/gen_kafka_ops.py", - require_shape_functions = True, - deps = [":kafka_ops_op_lib"], + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/kafka:dataset_ops_op_lib"], ) -py_library( - name = "kafka", - srcs = [ - "__init__.py", - "python/ops/kafka_dataset_ops.py", +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "kafka_op_loader", + srcs = ["python/ops/kafka_op_loader.py"], + dso = ["//tensorflow/contrib/kafka:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ - ":gen_kafka_ops", + ":gen_dataset_ops", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", ], ) @@ -88,6 +115,7 @@ tf_py_test( ], tags = [ "manual", + "no_windows", "notap", ], ) @@ -95,7 +123,9 @@ tf_py_test( filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index 88ef5f357113372b0a2d0cb13382ac980a61252d..a4cd4a2cc4b99b5906185bd2b942ed15c1ddf5e4 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/dataset.h" - -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/dataset.h" #include "src-cpp/rdkafkacpp.h" diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/dataset_ops.cc similarity index 100% rename from tensorflow/contrib/kafka/ops/kafka_ops.cc rename to tensorflow/contrib/kafka/ops/dataset_ops.cc diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 8e51d27a342359881de072c3979a2b5a7fc034ea..a1624614d1ab1be31463c5cdc0b4cfb653165a0c 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -17,8 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.kafka.python.ops import gen_kafka_ops -from tensorflow.python.data.ops.readers import Dataset +from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import +from tensorflow.contrib.kafka.python.ops import gen_dataset_ops +from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -58,8 +59,8 @@ class KafkaDataset(Dataset): timeout, dtype=dtypes.int64, name="timeout") def _as_variant_tensor(self): - return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, - self._eof, self._timeout) + return gen_dataset_ops.kafka_dataset(self._topics, self._servers, + self._group, self._eof, self._timeout) @property def output_classes(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py similarity index 75% rename from tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py rename to tensorflow/contrib/kafka/python/ops/kafka_op_loader.py index 690a44ff4368663306733300a1ea70397fb93e1e..ec2fdea962ef946d3f8f32b9e630b92649d612fe 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experimental methods for tf.feature_column sequential input.""" - +"""Python helper for loading kafka ops and kernels.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index f4ed978174a9ddd8b54a88e60bfb48a67a2e76d2..146ae8b7e2a3b2b479d5b8db7b8bffaca59a358f 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -36,6 +36,7 @@ py_test( srcs = ["fisher_factors_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index b12f7be76907dc206667eb8ee0c750f3b8db57fc..f22dbcf21566297340f3b4158a810f6d03af12f5 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -40,30 +39,6 @@ from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] -class DeviceContextGeneratorTest(test.TestCase): - - def testNoDevice(self): - device_context_generator = estimator._DeviceContextGenerator(None) - with ops.device("/device:CPU:0"): # This is what will be used - with device_context_generator(): # Does nothing - a = constant_op.constant([2.0], name="a") - self.assertEqual("/device:CPU:0", a.op.device) - - def testTwoDevices(self): - device_context_generator = estimator._DeviceContextGenerator( - ["/device:GPU:0", "/device:GPU:1"]) - with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes - with device_context_generator(): - a = constant_op.constant([2.0], name="a") - with device_context_generator(): - b = constant_op.constant([2.0], name="b") - with device_context_generator(): - c = constant_op.constant([2.0], name="c") - self.assertEqual("/device:GPU:0", a.op.device) - self.assertEqual("/device:GPU:1", b.op.device) - self.assertEqual("/device:GPU:0", c.op.device) - - class EstimatorTest(test.TestCase): def setUp(self): @@ -90,59 +65,113 @@ class EstimatorTest(test.TestCase): def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, - self.layer_collection) + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - estimator.FisherEstimator(lambda: 0.2, [self.weights, self.bias], 0.1, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + est.make_ops_and_vars() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - estimator.FisherEstimator(lambda: 0.2, [], 0.1, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_ops_and_vars() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_ops_and_vars() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, - self.layer_collection, "not_a_real_mode") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") + est.make_ops_and_vars() + + def testGradientsModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") + est.make_ops_and_vars() - def testModeListCorrect(self): + def testEmpiricalModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, - self.layer_collection) - self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys()) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") + est.make_ops_and_vars() - def testAllModesBuild(self): - for mode in _ALL_ESTIMATION_MODES: - with self._graph.as_default(): - estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, - self.layer_collection, mode) + def testCurvaturePropModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") + est.make_ops_and_vars() + + def testExactModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") + est.make_ops_and_vars() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( - damping_fn=lambda: 0.2, + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, + damping=0.2, cov_ema_decay=0.0) # Construct an op that executes one covariance update per step. global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() cov_matrices = [ fisher_factor.get_cov() for fisher_factor in self.layer_collection.get_factors() ] - cov_update_op_thunks = fisher_estimator.cov_update_thunks cov_update_op = control_flow_ops.case( [(math_ops.equal(global_step, i), thunk) for i, thunk in enumerate(cov_update_op_thunks)]) @@ -174,23 +203,61 @@ class EstimatorTest(test.TestCase): sess.run(cov_update_op) sess.run(increment_global_step) + def test_round_robin_placement(self): + """Check if the ops and variables are placed on devices correctly.""" + with self._graph.as_default(): + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0, + cov_devices=["/cpu:{}".format(i) for i in range(2)], + inv_devices=["/cpu:{}".format(i) for i in range(2)]) + + # Construct an op that executes one covariance update per step. + (cov_update_ops, _, inv_update_ops, _, _, + _) = fisher_estimator.make_ops_and_vars(scope="test") + self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") + self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") + self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") + self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + self.assertEqual(cov_matrices[0].device, "/device:CPU:0") + self.assertEqual(cov_matrices[1].device, "/device:CPU:1") + # Inverse matrices need to be explicitly placed. + self.assertEqual(inv_matrices[0].device, "") + self.assertEqual(inv_matrices[1].device, "") + def test_inv_update_thunks(self): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( - damping_fn=lambda: 0.2, + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, + damping=0.2, cov_ema_decay=0.0) # Construct op that updates one inverse per global step. global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, _, inv_variable_thunks, + inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() + for thunk in inv_variable_thunks: + thunk() inv_matrices = [ matrix for fisher_factor in self.layer_collection.get_factors() - for matrix in fisher_factor._inverses_by_damping.values() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() ] - inv_update_op_thunks = fisher_estimator.inv_update_thunks inv_update_op = control_flow_ops.case( [(math_ops.equal(global_step, i), thunk) for i, thunk in enumerate(inv_update_op_thunks)]) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index fb4b3a241c1e9fd82e7bf630fd57295917048fbd..6eda6c31e34370fd2bea1192ebf777924824c8e3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -63,7 +63,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -72,7 +72,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -81,7 +81,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -91,9 +91,12 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -109,9 +112,12 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -127,10 +133,13 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) damping = 0.5 block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) @@ -154,7 +163,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -163,7 +172,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -172,7 +181,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -182,9 +191,10 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -200,9 +210,10 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -217,10 +228,11 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) damping = 0.5 block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) sess.run(state_ops.assign(block._factor._cov, cov)) @@ -312,8 +324,8 @@ class FullyConnectedDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -364,9 +376,10 @@ class FullyConnectedDiagonalFBTest(test.TestCase): block = fb.FullyConnectedDiagonalFB( lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) sess.run(block._factor.make_covariance_update_op(0.0)) @@ -389,12 +402,12 @@ class EmbeddingKFACFBTest(test.TestCase): # Add some examples. inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) # Instantiate factor's variables. Ensure it doesn't fail. grads = outputs**2. damping = array_ops.constant(0.) - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) def testMultiplyInverse(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -407,12 +420,17 @@ class EmbeddingKFACFBTest(test.TestCase): # Add some examples. inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) # Instantiate factor's variables. Ensure it doesn't fail. grads = outputs**2. damping = array_ops.constant(0.) - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Create a sparse update. indices = array_ops.constant([1, 3, 4]) @@ -443,7 +461,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) @@ -453,10 +471,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): @@ -464,10 +482,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -475,9 +493,15 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -501,9 +525,14 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -524,13 +553,20 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): outputs = array_ops.zeros([32, output_dim]) params = array_ops.zeros([input_dim, output_dim]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) @@ -653,8 +689,8 @@ class ConvDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result, atol=1e-3) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -715,9 +751,10 @@ class ConvDiagonalFBTest(test.TestCase): block = fb.ConvDiagonalFB( lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) sess.run(block._factor.make_covariance_update_op(0.0)) @@ -727,6 +764,54 @@ class ConvDiagonalFBTest(test.TestCase): return multiply_result, multiply_inverse_result +class DepthwiseConvKFCBasicFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Ensure inverse update op doesn't crash. + sess.run(tf_variables.global_variables_initializer()) + sess.run([ + factor.make_inverse_update_ops() + for factor in layer_collection.get_factors() + ]) + + # Ensure inverse-vector multiply doesn't crash. + output = block.multiply_inverse(params) + sess.run(output) + + # Ensure same shape. + self.assertAllEqual(output.shape, params.shape) + + class ConvKFCBasicFBTest(test.TestCase): def _testConvKFCBasicFBInitParams(self, params): @@ -738,16 +823,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.constant(params) inputs = random_ops.random_normal((2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) def testConvKFCBasicFBInitParamsParamsTuple(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) def testConvKFCBasicFBInitParamsParamsSingle(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.])]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -755,11 +841,16 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -781,12 +872,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertFalse(block._has_bias) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -804,12 +900,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = [random_ops.random_normal((2, 2, 2, 2))] inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertTrue(block._has_bias) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -827,12 +928,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.zeros((2, 2, 2, 2)) inputs = array_ops.zeros((2, 2, 2, 2)) outputs = array_ops.zeros((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) @@ -857,9 +963,9 @@ class FullyConnectedSeriesFBTest(test.TestCase): random_seed.set_random_seed(200) inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) + block.register_additional_tower([inputs], [outputs]) + self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) def testInstantiateFactorsHasBias(self): with ops.Graph().as_default(): @@ -868,11 +974,10 @@ class FullyConnectedSeriesFBTest(test.TestCase): outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), - inputs=[inputs], - outputs=[outputs], has_bias=True) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) + block.instantiate_factors((((grads,),),), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): @@ -881,11 +986,10 @@ class FullyConnectedSeriesFBTest(test.TestCase): outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), - inputs=[inputs], - outputs=[outputs], has_bias=False) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) + block.instantiate_factors((((grads,),),), 0.5) def as_tensors(tensor_or_tuple): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 66e18974abfadaad5d7a20b40d0b1352bfda67ee..2a3592c53fdda488561e504ba2712aadc3214cc4 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np import numpy.random as npr +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,36 +30,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test -class MaybeColocateTest(test.TestCase): - - def setUp(self): - self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS - - def tearDown(self): - ff.set_global_constants( - colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs) - - def testFalse(self): - ff.set_global_constants(colocate_cov_ops_with_inputs=False) - with tf_ops.Graph().as_default(): - a = constant_op.constant([2.0], name='a') - with ff.maybe_colocate_with(a): - b = constant_op.constant(3.0, name='b') - self.assertEqual([b'loc:@a'], a.op.colocation_groups()) - self.assertEqual([b'loc:@b'], b.op.colocation_groups()) - - def testTrue(self): - ff.set_global_constants(colocate_cov_ops_with_inputs=True) - with tf_ops.Graph().as_default(): - a = constant_op.constant([2.0], name='a') - with ff.maybe_colocate_with(a): - b = constant_op.constant(3.0, name='b') - self.assertEqual([b'loc:@a'], a.op.colocation_groups()) - self.assertEqual([b'loc:@a'], b.op.colocation_groups()) +def make_damping_func(damping): + return fb._package_func(lambda: damping, damping) class FisherFactorTestingDummy(ff.FisherFactor): @@ -98,12 +76,21 @@ class FisherFactorTestingDummy(ff.FisherFactor): def right_multiply(self, x, damping): return NotImplementedError - def left_multiply_inverse(self, x, damping): + def left_multiply_matpower(self, x, exp, damping): return NotImplementedError - def right_multiply_inverse(self, x, damping): + def right_multiply_matpower(self, x, exp, damping): return NotImplementedError + def instantiate_inv_variables(self): + return NotImplementedError + + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -135,6 +122,12 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def instantiate_covariance(self): pass + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class NumericalUtilsTest(test.TestCase): @@ -246,21 +239,24 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' - dampings = 0.1, 1e-1, 0.00001, 1e-5 + damping_funcs = [make_damping_func(0.1), + make_damping_func(0.1), + make_damping_func(1e-5), + make_damping_func(1e-5)] + for damping_func in damping_funcs: + factor.register_inverse(damping_func) - for damping in dampings: - factor.register_damped_inverse(damping) + factor.instantiate_inv_variables() - self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys())) - inv = factor._inverses_by_damping[dampings[0]] - self.assertEqual(inv, factor._inverses_by_damping[dampings[1]]) - self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]]) - self.assertEqual(factor._inverses_by_damping[dampings[2]], - factor._inverses_by_damping[dampings[3]]) + inv = factor.get_inverse(damping_funcs[0]) + self.assertEqual(inv, factor.get_inverse(damping_funcs[1])) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2])) + self.assertEqual(factor.get_inverse(damping_funcs[2]), + factor.get_inverse(damping_funcs[3])) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]], - factor_vars) + self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]), + set(factor_vars)) self.assertEqual(shape, inv.get_shape()) def testRegisterMatpower(self): @@ -270,17 +266,22 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' - factor.register_matpower(1, 0.5) - factor.register_matpower(2, 0.5) + # TODO(b/74201126): Change to using the same func for both once + # Topohash is in place. + damping_func_1 = make_damping_func(0.5) + damping_func_2 = make_damping_func(0.5) + + factor.register_matpower(-0.5, damping_func_1) + factor.register_matpower(2, damping_func_2) + + factor.instantiate_inv_variables() - self.assertEqual( - set([(1, 0.5), (2, 0.5)]), - set(factor._matpower_by_exp_and_damping.keys())) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - matpower1 = factor.get_matpower(1, 0.5) - matpower2 = factor.get_matpower(2, 0.5) - self.assertListEqual([matpower1, matpower2], factor_vars) + matpower1 = factor.get_matpower(-0.5, damping_func_1) + matpower2 = factor.get_matpower(2, damping_func_2) + + self.assertEqual(set([matpower1, matpower2]), set(factor_vars)) self.assertEqual(shape, matpower1.get_shape()) self.assertEqual(shape, matpower2.get_shape()) @@ -299,17 +300,24 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + damping_funcs = [] for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): - factor.register_damped_inverse(1. / i) + damping_funcs.append(make_damping_func(1./i)) + + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + factor.register_inverse(damping_funcs[i]) + + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] sess.run(ops) - for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) + new_invs.append(sess.run(factor.get_inverse(damping_funcs[i]))) + # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): for j in range(i + 1, len(new_invs)): @@ -324,14 +332,16 @@ class InverseProvidingFactorTest(test.TestCase): factor._cov = array_ops.constant(cov, dtype=dtypes.float32) exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power damping = 0.5 + damping_func = make_damping_func(damping) - factor.register_matpower(exp, damping) + factor.register_matpower(exp, damping_func) + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) sess.run(ops[0]) - matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)]) + matpower = sess.run(factor.get_matpower(exp, damping_func)) matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) self.assertAllClose(matpower, matpower_np) @@ -342,18 +352,21 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - factor.register_damped_inverse(0) + damping_func = make_damping_func(0) + + factor.register_inverse(damping_func) + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) # The inverse op will assign the damped inverse of cov to the inv var. - old_inv = sess.run(factor._inverses_by_damping[0]) + old_inv = sess.run(factor.get_inverse(damping_func)) self.assertAllClose( sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) sess.run(ops) - new_inv = sess.run(factor._inverses_by_damping[0]) + new_inv = sess.run(factor.get_inverse(damping_func)) self.assertAllClose(new_inv, np.linalg.inv(cov)) @@ -364,6 +377,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) def testFullFactorInitFloat64(self): @@ -372,6 +386,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 6], cov.get_shape().as_list()) @@ -381,6 +396,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.constant([1., 2.], name='a/b/c') factor = ff.FullFactor((tensor,), 2) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -394,6 +410,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) + factor.instantiate_cov_variables() self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): @@ -402,6 +419,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) + factor.instantiate_cov_variables() cov = factor.get_cov_var() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -411,6 +429,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.constant([1., 2.], name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 2) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -424,6 +443,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): input_ids = array_ops.constant([[0], [1], [4]]) vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() cov = factor.get_cov_var() self.assertEqual(cov.shape.as_list(), [vocab_size]) @@ -432,6 +452,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): input_ids = array_ops.constant([[0], [1], [4]]) vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() cov_update_op = factor.make_covariance_update_op(0.0) with self.test_session() as sess: @@ -440,6 +461,118 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) +class ConvDiagonalFactorTest(test.TestCase): + + def setUp(self): + self.batch_size = 10 + self.height = self.width = 32 + self.in_channels = 3 + self.out_channels = 1 + self.kernel_height = self.kernel_width = 3 + self.strides = [1, 2, 2, 1] + self.data_format = 'NHWC' + self.padding = 'SAME' + self.kernel_shape = [ + self.kernel_height, self.kernel_width, self.in_channels, + self.out_channels + ] + + def testInit(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format) + factor.instantiate_cov_variables() + + # Ensure covariance matrix's shape makes sense. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels, + self.out_channels + ], + factor.get_cov_var().shape.as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + # Construct all arguments such that convolution kernel is applied in + # exactly one spatial location. + inputs = np.random.randn( + 1, # batch_size + self.kernel_height, + self.kernel_width, + self.in_channels) # in_channels + outputs_grad = np.random.randn( + 1, # batch_size + 1, # output_height + 1, # output_width + self.out_channels) + + factor = ff.ConvDiagonalFactor( + (constant_op.constant(inputs),), + ((constant_op.constant(outputs_grad),),), + self.kernel_shape, + strides=[1, 1, 1, 1], + padding='VALID') + factor.instantiate_cov_variables() + + # Completely forget initial value on first update. + cov_update_op = factor.make_covariance_update_op(0.0) + + # Ensure new covariance value is same as outer-product of inputs/outputs + # vectorized, squared. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + cov = sess.run(cov_update_op) + expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 + self.assertAllClose(expected_cov, cov) + + def testHasBias(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format, + has_bias=True) + factor.instantiate_cov_variables() + + # Ensure shape accounts for bias. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels + 1, + self.out_channels + ], + factor.get_cov_var().shape.as_list()) + + # Ensure update op doesn't crash. + cov_update_op = factor.make_covariance_update_op(0.0) + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(cov_update_op) + + class FullyConnectedKroneckerFactorTest(test.TestCase): def _testFullyConnectedKroneckerFactorInit(self, @@ -449,7 +582,8 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual(final_shape, cov.get_shape().as_list()) @@ -466,7 +600,8 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -476,40 +611,171 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,)) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) -class ConvInputKroneckerFactorTest(test.TestCase): +class ConvFactorTestCase(test.TestCase): + + def assertMatrixRank(self, rank, matrix, atol=1e-5): + assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' + eigvals = np.linalg.eigvals(matrix) + nnz_eigvals = np.sum(eigvals > atol) + self.assertEqual( + rank, + nnz_eigvals, + msg=('Found %d of %d expected non-zero eigenvalues: %s.' % + (nnz_eigvals, rank, eigvals))) + + +class ConvInputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**3 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, width, in_channels), seed=0),), + filter_shape=(width, width, width, in_channels, out_channels), + padding='SAME', + strides=(2, 2, 2), + extract_patches_fn='extract_convolution_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + input_size = in_channels * (width**3) + self.assertEqual([input_size, input_size], + factor.get_cov_var().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank-8, as the filter will be applied at each corner of + # the 4-D cube. + self.assertMatrixRank(8, cov) + + def testPointwiseConv2d(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 1, 1, 1), + extract_patches_fn='extract_pointwise_conv2d_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + self.assertEqual([in_channels, in_channels], + factor.get_cov_var().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank-9, as the filter will be applied at each location. + self.assertMatrixRank(9, cov) + + def testStrides(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 2, 1, 1), + extract_patches_fn='extract_image_patches', + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be the sum of 3 * 2 = 6 outer products. + self.assertMatrixRank(6, cov) + + def testDilationRate(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(3, 3, in_channels, out_channels), + padding='SAME', + extract_patches_fn='extract_image_patches', + strides=(1, 1, 1, 1), + dilation_rate=(1, width, width, 1), + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank = in_channels, as only the center of the filter + # receives non-zero input for each input channel. + self.assertMatrixRank(in_channels, cov) def testConvInputKroneckerFactorInitNoBias(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=False) + inputs=(tensor,), + filter_shape=(1, 2, 3, 4), + padding='SAME', + has_bias=False) + factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3, 1 * 2 * 3], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInitFloat64(self): with tf_ops.Graph().as_default(): dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], @@ -517,37 +783,67 @@ class ConvInputKroneckerFactorTest(test.TestCase): def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True) + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]], - new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose( + [ + [(1. + 4.) / 2., (1. + 2.) / 2.], # + [(1. + 2.) / 2., (1. + 1.) / 2.] + ], # + new_cov) def testMakeCovarianceUpdateOpNoBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) - factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1], - 'SAME') + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37], [37, 41]], new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose([[(1. + 4.) / 2.]], new_cov) + + +class ConvOutputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + out_channels = width**3 + + factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ + random_ops.random_uniform( + (batch_size, width, width, width, out_channels), seed=0) + ],)) + factor.instantiate_cov_variables() + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) -class ConvOutputKroneckerFactorTest(test.TestCase): + # Cov should be rank 3^3, as each spatial position donates a rank-1 + # update. + self.assertMatrixRank(width**3, cov) def testConvOutputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) def testConvOutputKroneckerFactorInitFloat64(self): @@ -555,23 +851,18 @@ class ConvOutputKroneckerFactorTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([5, 5], cov.get_shape().as_list()) - def testConvOutputKroneckerFactorInitNotEnoughDims(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - with self.assertRaises(IndexError): - ff.ConvOutputKroneckerFactor(tensor) - def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) - factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),)) + factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -584,8 +875,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) def testFullyConnectedMultiKFInitFloat64(self): @@ -593,8 +884,8 @@ class FullyConnectedMultiKFTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([3, 3], cov.get_shape().as_list()) @@ -603,8 +894,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -614,8 +905,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,)) + factor = ff.FullyConnectedMultiKF(((tensor,),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index b8ccbeadd0a9d69edb41fef50e3edb090457adf2..cb80fca3705308f92e308e2a840336fb72d0fa62 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -35,7 +35,7 @@ from tensorflow.python.platform import test class MockFisherBlock(object): """A fake FisherBlock.""" - num_registered_minibatches = 2 + num_registered_towers = 2 def __init__(self, name='MockFisherBlock'): self.name = name @@ -104,22 +104,53 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5))) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], - 'SAME', - array_ops.ones((1, 1, 1, 1)), - array_ops.constant(3), + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5)), approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_separable_conv2d( + depthwise_params=array_ops.ones((3, 3, 1, 2)), + pointwise_params=array_ops.ones((1, 1, 2, 4)), + inputs=array_ops.ones((32, 5, 5, 1)), + depthwise_outputs=array_ops.ones((32, 5, 5, 2)), + pointwise_outputs=array_ops.ones((32, 5, 5, 4)), + strides=[1, 1, 1, 1], + padding='SAME') + lc.register_convolution( + params=array_ops.ones((3, 3, 1, 8)), + inputs=array_ops.ones((32, 5, 5, 1)), + outputs=array_ops.ones((32, 5, 5, 8)), + padding='SAME') lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) lc.register_generic( array_ops.constant(6), 16, approx=layer_collection.APPROX_DIAGONAL_NAME) - - self.assertEqual(6, len(lc.get_blocks())) + lc.register_fully_connected_multi( + array_ops.constant(1), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + lc.register_conv2d_multi( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), + outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) + lc.register_embedding_multi( + array_ops.constant((1,)), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + + self.assertEqual(12, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with ops.Graph().as_default(): @@ -237,16 +268,16 @@ class LayerCollectionTest(test.TestCase): # Create a new loss function by name. lc.register_categorical_predictive_distribution(logits, name='loss1') - self.assertEqual(1, len(lc.losses)) + self.assertEqual(1, len(lc.towers_by_loss)) # Add logits to same loss function. lc.register_categorical_predictive_distribution( logits, name='loss1', reuse=True) - self.assertEqual(1, len(lc.losses)) + self.assertEqual(1, len(lc.towers_by_loss)) # Add another new loss function. lc.register_categorical_predictive_distribution(logits, name='loss2') - self.assertEqual(2, len(lc.losses)) + self.assertEqual(2, len(lc.towers_by_loss)) def testLossFunctionWithoutName(self): """Ensure loss functions get unique names if 'name' not specified.""" @@ -298,13 +329,9 @@ class LayerCollectionTest(test.TestCase): name='loss1', reuse=layer_collection.VARIABLE_SCOPE) - self.assertEqual(len(lc.losses), 1) - loss = lc.losses[0] - + self.assertEqual(len(lc.towers_by_loss), 1) # Three successful registrations. - self.assertEqual(loss.params.shape.as_list(), - [3 * batch_size, output_size]) - self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size]) + self.assertEqual(len(lc.towers_by_loss[0]), 3) def testRegisterCategoricalPredictiveDistributionBatchSize1(self): with ops.Graph().as_default(): @@ -441,13 +468,13 @@ class LayerCollectionTest(test.TestCase): b = variable_scope.get_variable('b', [3]) lc = layer_collection.LayerCollection() lc.register_fully_connected(w, inputs, outputs) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) with self.assertRaises(KeyError): lc.register_fully_connected((w, b), inputs, outputs, reuse=True) self.assertNotIn((w, b), lc.fisher_blocks) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) lc.register_fully_connected(w, inputs, outputs, reuse=True) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) def testMakeOrGetFactor(self): with ops.Graph().as_default(): @@ -479,17 +506,6 @@ class LayerCollectionTest(test.TestCase): variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertTrue(all([var.name.startswith(scope) for var in variables])) - def testGetUseCountMap(self): - """Ensure get_use_count_map() sums 'num_registered_minibatches'.""" - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - 'a': MockFisherBlock(), - ('a', 'c'): MockFisherBlock(), - ('b', 'c'): MockFisherBlock() - } - use_count_map = lc.get_use_count_map() - self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) - def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', shape=()) y = variable_scope.get_variable('y', shape=()) @@ -550,6 +566,32 @@ class LayerCollectionTest(test.TestCase): self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + def testDefaultLayerCollection(self): + with ops.Graph().as_default(): + # Can't get default if there isn't one set. + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # Can't set default twice. + lc = layer_collection.LayerCollection() + layer_collection.set_default_layer_collection(lc) + with self.assertRaises(ValueError): + layer_collection.set_default_layer_collection(lc) + + # Same as one set. + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + + # Can set to None. + layer_collection.set_default_layer_collection(None) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # as_default() is the same as setting/clearing. + with lc.as_default(): + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index ae787b6f1ac90218f2ac73d37fb270df0b822de2..c00af5593f085e3b1f3e030a24f4b821115cc869 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -24,7 +24,6 @@ from tensorflow.contrib.kfac.python.ops import loss_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -97,22 +96,6 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) - def testMultiMinibatchRegistration(self): - """Ensure this loss function supports registering multiple minibatches.""" - with ops.Graph().as_default(): - tower_logits = [] - loss = None - num_towers = 5 - for _ in range(num_towers): - logits = random_ops.random_uniform(shape=[2, 3]) - tower_logits.append(logits) - if loss is None: - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - else: - loss.register_additional_minibatch(logits) - self.assertListEqual(loss.input_minibatches, tower_logits) - self.assertEqual(loss.num_registered_minibatches, num_towers) - def testMultiplyFisherSingleVector(self): with ops.Graph().as_default(), self.test_session() as sess: logits = np.array([1., 2., 3.]) @@ -203,23 +186,5 @@ class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) - def testMultiMinibatchRegistration(self): - """Ensure this loss function supports registering multiple minibatches.""" - with ops.Graph().as_default(): - tower_logits = [] - loss = None - num_towers = 5 - for _ in range(num_towers): - logits = random_ops.random_uniform(shape=[2, 3]) - tower_logits.append(logits) - if loss is None: - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - logits) - else: - loss.register_additional_minibatch(logits) - self.assertListEqual(loss.input_minibatches, tower_logits) - self.assertEqual(loss.num_registered_minibatches, num_towers) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 97a97adbf5577cd2694d3055acaa59258ad27964..2cee01212a11595669e9df0fc95a5657926c1038 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -325,6 +327,84 @@ class UtilsTest(test.TestCase): ], values) + def testExtractConvolutionPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_spatial_shape = [9, 10, 11] + in_channels = out_channels = 32 + kernel_spatial_shape = [5, 3, 3] + spatial_strides = [1, 2, 1] + spatial_dilation = [1, 1, 1] + padding = 'SAME' + + images = random_ops.random_uniform( + [batch_size] + image_spatial_shape + [in_channels], seed=0) + kernel_shape = kernel_spatial_shape + [in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_convolution_patches( + images, + kernel_shape, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + result_spatial_shape = ( + patches.shape.as_list()[1:1 + len(image_spatial_shape)]) + self.assertEqual(patches.shape.as_list(), + [batch_size] + result_spatial_shape + + kernel_spatial_shape + [in_channels]) + + # Ensure extract...patches() + matmul() and convolution() implementation + # give the same answer. + outputs = nn_ops.convolution( + images, + kernel, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + + patches_flat = array_ops.reshape( + patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + def testExtractPointwiseConv2dPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_height = image_width = 8 + in_channels = out_channels = 3 + kernel_height = kernel_width = 1 + strides = [1, 1, 1, 1] + padding = 'VALID' + + images = random_ops.random_uniform( + [batch_size, image_height, image_width, in_channels], seed=0) + kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) + self.assertEqual(patches.shape.as_list(), [ + batch_size, image_height, image_width, kernel_height, kernel_width, + in_channels + ]) + + # Ensure extract...patches() + matmul() and conv2d() implementation + # give the same answer. + outputs = nn_ops.conv2d(images, kernel, strides, padding) + + patches_flat = array_ops.reshape( + patches, [-1, kernel_height * kernel_width * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index c26230c2a82ae9529ab13b523b9ec287d17debaf..d721ad08afaa416f86ce881d4cdd968cd1809b5a 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -171,6 +171,7 @@ py_library( name = "fisher_estimator", srcs = [ "estimator.py", + "placement.py", ], srcs_version = "PY2AND3", deps = [ @@ -180,6 +181,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index a7e268c48ae326a4d8fa5fe4a4ed15b8b83a0ed9..ced1110676754b6c8bba813ace743b3f3daddb26 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,91 +18,91 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import itertools - +import abc import numpy as np +import six +from tensorflow.contrib.kfac.python.ops import placement from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -class _DeviceContextGenerator(object): - """Class for generating device contexts in a round-robin fashion.""" - - def __init__(self, devices): - """Creates a _DeviceContextGenerator object. +# The linter is confused. +# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. - Example usage: + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. - ```python - dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) - with dcg(): - # All operations in this context will be placed on GPU 0 - ... - with dcg(): - # All operations in this context will be placed on GPU 1 - ... - ``` + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. - Args: - devices: An iterable of device strings (or None). Successive calls to - __call__ will give contexts which place devices on these devices in - a round-robin fashion. - """ - self._cycle = None if devices is None else itertools.cycle(devices) + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. - @contextlib.contextmanager - def __call__(self): - """Returns a context manager specifying the default device.""" - if self._cycle is None: - yield - else: - with tf_ops.device(next(self._cycle)): - yield + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. + """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops placement strategy : %s", + placement_strategy) +# pylint: enable=abstract-class-instantiated +@six.add_metaclass(abc.ABCMeta) class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher. - Attributes: - cov_update_thunks: list of no-arg functions. Executing a function adds - covariance update ops for a single FisherFactor to the graph. - cov_update_ops: List of Ops. Running an op updates covariance matrices for a - single FisherFactor. - cov_update_op: Op. Running updates covariance matrices for all - FisherFactors. - inv_update_thunks: list of no-arg functions. Executing a function adds - inverse update ops for a single FisherFactor to the graph. - inv_update_ops: List of Ops. Running an op updates inverse matrices for a - single FisherFactor. - inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. """ def __init__(self, - damping_fn, variables, cov_ema_decay, + damping, layer_collection, + exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, - cov_devices=None, - inv_devices=None): + name="FisherEstimator"): """Create a FisherEstimator object. Args: - damping_fn: Function, accepts no arguments and returns damping value. variables: A list of the variables for which to estimate the Fisher. This must match the variables registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. + damping: float. The damping factor used to stabilize training due to + errors in the local approximation with the Fisher information matrix, + and to regularize the update direction by making it closer to the + gradient. (Higher damping means the update looks more like a standard + gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher blocks, kronecker factors, and losses associated with the graph. + exps: List of floats or ints. These represent the different matrix + powers of the approximate Fisher that the FisherEstimator will be able + to multiply vectors by. If the user asks for a matrix power other + one of these (or 1, which is always supported), there will be a + failure. (Default: (-1,)) estimation_mode: The type of estimator to use for the Fishers. Can be 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach @@ -121,23 +121,17 @@ class FisherEstimator(object): equal to the output dimension, roughly speaking. colocate_gradients_with_ops: Whether we should request gradients be colocated with their respective ops. (Default: True) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - + name: A string. A name given to this estimator, which is added to the + variable scope when constructing variables and ops. + (Default: "FisherEstimator") Raises: ValueError: If no losses have been registered with layer_collection. """ - self._damping_fn = damping_fn - self._cov_ema_decay = cov_ema_decay self._variables = variables + self._cov_ema_decay = cov_ema_decay + self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection - self._layers.create_subgraph() - self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, @@ -146,30 +140,10 @@ class FisherEstimator(object): } self._colocate_gradients_with_ops = colocate_gradients_with_ops - # TODO(b/70674513): Factor device placement outside of this class. - self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) - if inv_devices == cov_devices: - self._inv_device_context_generator = self._cov_device_context_generator - else: - self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) + self._made_vars = False + self._exps = exps - self._instantiate_factors() - - self.cov_update_thunks = [ - self._create_cov_update_thunk(factor) - for factor in self._layers.get_factors() - ] - self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks] - self.cov_update_op = control_flow_ops.group( - self.cov_update_ops, name="cov_update_op") - - self.inv_update_thunks = [ - self._create_inv_update_thunk(factor) - for factor in self._layers.get_factors() - ] - self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks] - self.inv_update_op = control_flow_ops.group( - self.inv_update_ops, name="inv_update_op") + self._name = name @property def variables(self): @@ -177,7 +151,92 @@ class FisherEstimator(object): @property def damping(self): - return self._damping_fn() + return self._damping + + @property + def blocks(self): + """All registered FisherBlocks.""" + return self._layers.get_blocks() + + @property + def factors(self): + """All registered FisherFactors.""" + return self._layers.get_factors() + + @property + def name(self): + return self._name + + @abc.abstractmethod + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. For example in case of + round robin placement a new device is chosen for each factor by cycling + through list of devices in the cov_devices argument. If cov_devices is None + then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + pass + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass def _apply_transformation(self, vecs_and_vars, transform): """Applies an block-wise transformation to the corresponding vectors. @@ -212,9 +271,7 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ - - return self._apply_transformation(vecs_and_vars, - lambda fb, vec: fb.multiply_inverse(vec)) + return self.multiply_matpower(-1, vecs_and_vars) def multiply(self, vecs_and_vars): """Multiplies the vectors by the corresponding (damped) blocks. @@ -226,9 +283,22 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ + return self.multiply_matpower(1, vecs_and_vars) + + def multiply_matpower(self, exp, vecs_and_vars): + """Multiplies the vecs by the corresponding matrix powers of the blocks. + + Args: + exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + vecs_and_vars: List of (vector, variable) pairs. - return self._apply_transformation(vecs_and_vars, - lambda fb, vec: fb.multiply(vec)) + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) + return self._apply_transformation(vecs_and_vars, fcn) def _instantiate_factors(self): """Instantiates FisherFactors' variables. @@ -236,9 +306,9 @@ class FisherEstimator(object): Raises: ValueError: If estimation_mode was improperly specified at construction. """ - fisher_blocks_list = self._layers.get_blocks() + blocks = self.blocks tensors_to_compute_grads = [ - fb.tensors_to_compute_grads() for fb in fisher_blocks_list + block.tensors_to_compute_grads() for block in blocks ] try: @@ -248,45 +318,131 @@ class FisherEstimator(object): raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) - # TODO(b/68033310): This loop round-robins the "concat" operations which - # gather the inputs for the cov_updates. In future, we might do these - # computations locally then communicate the results, which would require a - # modification to this code. - for grads_list, fb in zip(grads_lists, fisher_blocks_list): - with self._cov_device_context_generator(): - fb.instantiate_factors(grads_list, self.damping) + for grads_list, block in zip(grads_lists, blocks): + block.instantiate_factors(grads_list, self.damping) + + def _check_vars_unmade_and_set_made_flag(self): + if self._made_vars: + raise Exception("Already made variables.") + self._made_vars = True + + def made_vars(self): + return self._made_vars + + def _register_matrix_functions(self): + for exp in self._exps: + for block in self.blocks: + block.register_matpower(exp) + + def _finalize_layer_collection(self): + self._layers.create_subgraph() + self._layers.check_registration(self.variables) + self._instantiate_factors() + self._register_matrix_functions() + + def create_ops_and_vars_thunks(self, scope=None): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All thunks will execute inside + of a variable scope of the given name. (Default: None) + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + self._check_vars_unmade_and_set_made_flag() - def _create_cov_update_thunk(self, factor): + self._finalize_layer_collection() + + scope = self.name if scope is None else scope + + cov_variable_thunks = [ + self._create_cov_variable_thunk(factor, scope) + for factor in self.factors + ] + cov_update_thunks = [ + self._create_cov_update_thunk(factor, scope) for factor in self.factors + ] + inv_variable_thunks = [ + self._create_inv_variable_thunk(factor, scope) + for factor in self.factors + ] + inv_update_thunks = [ + self._create_inv_update_thunk(factor, scope) for factor in self.factors + ] + + return (cov_variable_thunks, cov_update_thunks, + inv_variable_thunks, inv_update_thunks) + + def _create_cov_variable_thunk(self, factor, scope): + """Constructs a covariance variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_cov_variables() + + return thunk + + def _create_cov_update_thunk(self, factor, scope): """Constructs a covariance update thunk for a single FisherFactor.""" def thunk(): - with tf_ops.name_scope( - "create_cov_update_thunk", values=[self._cov_ema_decay]): + with variable_scope.variable_scope(scope): return factor.make_covariance_update_op(self._cov_ema_decay) return thunk - def _create_inv_update_thunk(self, factor): + def _create_inv_variable_thunk(self, factor, scope): + """Constructs a inverse variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_inv_variables() + + return thunk + + def _create_inv_update_thunk(self, factor, scope): """Constructs an inverse update thunk for a single FisherFactor.""" def thunk(): - with tf_ops.name_scope("create_inv_update_thunk"): - with self._inv_device_context_generator(): - return control_flow_ops.group(factor.make_inverse_update_ops()) + with variable_scope.variable_scope(scope): + return control_flow_ops.group(factor.make_inverse_update_ops()) return thunk def _get_grads_lists_gradients(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device grads_flat = gradients_impl.gradients( - self._layers.total_sampled_loss(), + self._layers.eval_losses_on_samples(), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device grads_flat = gradients_impl.gradients( - self._layers.total_loss(), + self._layers.eval_losses(), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) @@ -295,9 +451,10 @@ class FisherEstimator(object): def _get_transformed_random_signs(self): transformed_random_signs = [] for loss in self._layers.losses: - transformed_random_signs.append( - loss.multiply_fisher_factor( - utils.generate_random_signs(loss.fisher_factor_inner_shape))) + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) return transformed_random_signs def _get_grads_lists_curvature_prop(self, tensors): @@ -316,13 +473,20 @@ class FisherEstimator(object): # Loop over all coordinates of all losses. grads_all = [] for loss in self._layers.losses: - for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): - transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( - index) - grads_flat = gradients_impl.gradients( - loss.inputs, - nest.flatten(tensors), - grad_ys=transformed_one_hot, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index cf38d28b43836dced8babe2ffa7853b1c4b1b369..b04bf76a886049e876a8dde647dc7b718d03da9d 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -40,12 +40,15 @@ from __future__ import print_function import abc import enum # pylint: disable=g-bad-import-order +import numpy as np import six from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest # For blocks corresponding to convolutional layers, or any type of block where # the parameters can be thought of as being replicated in time or space, @@ -121,12 +124,44 @@ def compute_pi_adjusted_damping(left_cov, right_cov, damping): return (damping, damping) +class PackagedFunc(object): + """A Python thunk with a stable ID. + + Enables stable names for lambdas. + """ + + def __init__(self, func, func_id): + """Initializes PackagedFunc. + + Args: + func: a zero-arg Python function. + func_id: a hashable, function that produces a hashable, or a list/tuple + thereof. + """ + self._func = func + func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) + self._func_id = func_id + + def __call__(self): + return self._func() + + @property + def func_id(self): + """A hashable identifier for this function.""" + return tuple(elt() if callable(elt) else elt for elt in self._func_id) + + +def _package_func(func, func_id): + return PackagedFunc(func, func_id) + + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. - Subclasses must implement multiply_inverse(), instantiate_factors(), and - tensors_to_compute_grads() methods. + Subclasses must implement register_matpower, multiply_matpower, + instantiate_factors, tensors_to_compute_grads, and num_registered_towers + methods. """ def __init__(self, layer_collection): @@ -145,6 +180,32 @@ class FisherBlock(object): pass @abc.abstractmethod + def register_matpower(self, exp): + """Registers a matrix power to be computed by the block. + + Args: + exp: A float representing the power to raise the block by. + """ + pass + + def register_inverse(self): + """Registers a matrix inverse to be computed by the block.""" + self.register_matpower(-1) + + @abc.abstractmethod + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + exp: A float representing the power to raise the block by before + multiplying it by the vector. + + Returns: + The vector left-multiplied by the (damped) matrix-power of the block. + """ + pass + def multiply_inverse(self, vector): """Multiplies the vector by the (damped) inverse of the block. @@ -154,9 +215,8 @@ class FisherBlock(object): Returns: The vector left-multiplied by the (damped) inverse of the block. """ - pass + return self.multiply_matpower(vector, -1) - @abc.abstractmethod def multiply(self, vector): """Multiplies the vector by the (damped) block. @@ -166,7 +226,7 @@ class FisherBlock(object): Returns: The vector left-multiplied by the (damped) block. """ - pass + return self.multiply_matpower(vector, 1) @abc.abstractmethod def tensors_to_compute_grads(self): @@ -175,8 +235,8 @@ class FisherBlock(object): pass @abc.abstractproperty - def num_registered_minibatches(self): - """Number of minibatches registered for this FisherBlock. + def num_registered_towers(self): + """Number of towers registered for this FisherBlock. Typically equal to the number of towers in a multi-tower setup. """ @@ -207,21 +267,18 @@ class FullFB(FisherBlock): super(FullFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._damping = damping + self._damping_func = _package_func(lambda: damping, (damping,)) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullFactor, (grads_list, self._batch_size)) - self._factor.register_damped_inverse(damping) - def multiply_inverse(self, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_inverse( - vector_flat, self._damping) - return utils.column_to_tensors(vector, out_flat) + def register_matpower(self, exp): + self._factor.register_matpower(exp, self._damping_func) - def multiply(self, vector): + def multiply_matpower(self, vector, exp): vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply( - vector_flat, self._damping) + out_flat = self._factor.left_multiply_matpower( + vector_flat, exp, self._damping_func) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): @@ -231,8 +288,8 @@ class FullFB(FisherBlock): def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -240,7 +297,7 @@ class FullFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -271,22 +328,20 @@ class NaiveDiagonalFB(FisherBlock): super(NaiveDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._damping = damping + self._damping_func = _package_func(lambda: damping, (damping,)) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - def multiply_inverse(self, vector): - vector_flat = utils.tensors_to_column(vector) - print("vector_flat: %s" % vector_flat) - out_flat = self._factor.left_multiply_inverse( - vector_flat, self._damping) - print("out_flat: %s" % out_flat) - return utils.column_to_tensors(vector, out_flat) + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass - def multiply(self, vector): + def multiply_matpower(self, vector, exp): vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply( - vector_flat, self._damping) + out_flat = self._factor.left_multiply_matpower( + vector_flat, exp, self._damping_func) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): @@ -295,8 +350,8 @@ class NaiveDiagonalFB(FisherBlock): def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -304,7 +359,7 @@ class NaiveDiagonalFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -312,7 +367,92 @@ class NaiveDiagonalFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class FullyConnectedDiagonalFB(FisherBlock): +class InputOutputMultiTower(object): + """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" + + def __init__(self, *args, **kwargs): + self.__inputs = [] + self.__outputs = [] + super(InputOutputMultiTower, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + The initial format of self._inputs is expected to be a list of Tensors + over towers. Similarly grads_lists is expected to be a list over sources + of such lists. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single + tensor (represented as a PartitionedTensor object) equal to the + concatenation (across towers) of all of the elements of self._inputs. And + similarly grads_list is formatted into a tuple (over sources) of such + tensors (also represented as PartitionedTensors). + + If TOWER_STRATEGY is "separate", formatting of inputs and grads_list + remains unchanged from the initial format (although possibly converting + from lists into tuples). + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". + """ + inputs = self._inputs + # inputs is a list over towers of Tensors + # grads_list is a list of list with the first index being sources and the + # second being towers. + if fisher_factors.TOWER_STRATEGY == "concat": + # Merge towers together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + # Do the same for grads_list but preserve leading sources dimension + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + elif fisher_factors.TOWER_STRATEGY == "separate": + inputs = tuple(inputs) + grads_list = tuple(grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + return inputs, grads_list + + def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" + return tuple(self._outputs) + + def register_additional_tower(self, inputs, outputs): + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_towers(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + + @property + def _inputs(self): + return self.__inputs + + @property + def _outputs(self): + return self.__outputs + + +class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -344,80 +484,46 @@ class FullyConnectedDiagonalFB(FisherBlock): has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = [] - self._outputs = [] self._has_bias = has_bias super(FullyConnectedDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) - self._damping = damping self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedDiagonalFactor, (inputs, grads_list, self._has_bias)) - def multiply_inverse(self, vector): - """Approximate damped inverse Fisher-vector product. - - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. + self._damping_func = _package_func(lambda: damping, (damping,)) - Returns: - Tensor of the same shape, corresponding to the inverse Fisher-vector - product. - """ - reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_inverse( - reshaped_vec, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass - def multiply(self, vector): - """Approximate damped Fisher-vector product. + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. Args: vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape [input_size, output_size] corresponding to layer's weights. If not, a 2-tuple of the former and a Tensor of shape [output_size] corresponding to the layer's bias. + exp: A scalar representing the power to raise the block before multiplying + it by the vector. Returns: - Tensor of the same shape, corresponding to the Fisher-vector product. + The vector left-multiplied by the (damped) matrix-power of the block. """ reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply( - reshaped_vec, self._damping) + reshaped_out = self._factor.left_multiply_matpower( + reshaped_vec, exp, self._damping_func) return utils.mat2d_to_layer_params(vector, reshaped_out) - def tensors_to_compute_grads(self): - """Tensors to compute derivative of loss with respect to.""" - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - result = len(self._inputs) - assert result == len(self._outputs) - return result - -class ConvDiagonalFB(FisherBlock): - """FisherBlock for convolutional layers using a diagonal approx. +class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): + """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" @@ -441,7 +547,13 @@ class ConvDiagonalFB(FisherBlock): to the layer's parameters 'w'. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + strides, + padding, + data_format=None, + dilations=None): """Creates a ConvDiagonalFB block. Args: @@ -453,92 +565,115 @@ class ConvDiagonalFB(FisherBlock): containing the previous and a Tensor of shape [out_channels]. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (e.g. "SAME"). + data_format: str or None. Format of input data. + dilations: List of 4 ints or None. Rate for dilation along all dimensions. + + Raises: + ValueError: if strides is not length-4. + ValueError: if dilations is not length-4. + ValueError: if channel is not last dimension. """ - self._inputs = [] - self._outputs = [] - self._strides = tuple(strides) if isinstance(strides, list) else strides + if len(strides) != 4: + raise ValueError("strides must contain 4 numbers.") + + if dilations is None: + dilations = [1, 1, 1, 1] + + if len(dilations) != 4: + raise ValueError("dilations must contain 4 numbers.") + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + self._strides = maybe_tuple(strides) self._padding = padding + self._data_format = data_format + self._dilations = maybe_tuple(dilations) self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) + if len(self._filter_shape) != 4: + raise ValueError( + "Convolution filter must be of shape" + " [filter_height, filter_width, in_channels, out_channels].") + super(ConvDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # Concatenate inputs, grads_list into single Tensors. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - inputs_shape = tuple(inputs.shape.as_list()) - self._num_locations = ( - inputs_shape[1] * inputs_shape[2] // - (self._strides[1] * self._strides[2])) - - self._damping = (self._num_locations - * normalize_damping(damping, self._num_locations)) + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, (inputs, grads_list, self._filter_shape, self._strides, self._padding, - self._has_bias)) - - def multiply_inverse(self, vector): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_inverse( - reshaped_vect, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply(self, vector): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply( - reshaped_vect, self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) + self._data_format, self._dilations, self._has_bias)) - def tensors_to_compute_grads(self): - return self._outputs + def damping_func(): + return self._num_locations * normalize_damping(damping, + self._num_locations) - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. + damping_id = (self._num_locations, "mult", "normalize_damping", damping, + self._num_locations) + self._damping_func = _package_func(damping_func, damping_id) - Args: - inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to - the convolution. - outputs: Tensor of shape [batch_size, height, width, output_size]. Layer - preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass - @property - def num_registered_minibatches(self): - return len(self._inputs) + def multiply_matpower(self, vector, exp): + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply_matpower( + reshaped_vect, exp, self._damping_func) + return utils.mat2d_to_layer_params(vector, reshaped_out) class KroneckerProductFB(FisherBlock): - """A base class for FisherBlocks with separate input and output factors. + """A base class for blocks with separate input and output Kronecker factors. The Fisher block is approximated as a Kronecker product of the input and output factors. """ - def _register_damped_input_and_output_inverses(self, damping): - """Registers damped inverses for both the input and output factors. - - Sets the instance members _input_damping and _output_damping. Requires the - instance members _input_factor and _output_factor. + def __init__(self, layer_collection): + super(KroneckerProductFB, self).__init__(layer_collection) + + def _setup_damping(self, damping, normalization=None): + """Makes functions that compute the damping values for both factors.""" + def compute_damping(): + if normalization is not None: + maybe_normalized_damping = normalize_damping(damping, normalization) + else: + maybe_normalized_damping = damping + + return compute_pi_adjusted_damping(self._input_factor.get_cov(), + self._output_factor.get_cov(), + maybe_normalized_damping**0.5) + + if normalization is not None: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + "normalize_damping", damping, normalization, "power", 0.5) + else: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + damping, "power", 0.5) - Args: - damping: The base damping factor (float or Tensor) for the damped inverse. - """ - self._input_damping, self._output_damping = compute_pi_adjusted_damping( - self._input_factor.get_cov(), - self._output_factor.get_cov(), - damping**0.5) + self._input_damping_func = _package_func(lambda: compute_damping()[0], + damping_id + ("ref", 0)) + self._output_damping_func = _package_func(lambda: compute_damping()[1], + damping_id + ("ref", 1)) - self._input_factor.register_damped_inverse(self._input_damping) - self._output_factor.register_damped_inverse(self._output_damping) + def register_matpower(self, exp): + self._input_factor.register_matpower(exp, self._input_damping_func) + self._output_factor.register_matpower(exp, self._output_damping_func) @property def _renorm_coeff(self): @@ -552,28 +687,15 @@ class KroneckerProductFB(FisherBlock): """ return 1.0 - def multiply_inverse(self, vector): - reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply_inverse( - reshaped_vector, - self._output_damping) - reshaped_out = self._input_factor.left_multiply_inverse( - reshaped_out, self._input_damping) - if self._renorm_coeff != 1.0: - reshaped_out /= math_ops.cast( - self._renorm_coeff, dtype=reshaped_out.dtype) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply(self, vector): + def multiply_matpower(self, vector, exp): reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply( - reshaped_vector, - self._output_damping) - reshaped_out = self._input_factor.left_multiply( - reshaped_out, self._input_damping) + reshaped_out = self._output_factor.right_multiply_matpower( + reshaped_vector, exp, self._output_damping_func) + reshaped_out = self._input_factor.left_multiply_matpower( + reshaped_out, exp, self._input_damping_func) if self._renorm_coeff != 1.0: - reshaped_out *= math_ops.cast( - self._renorm_coeff, dtype=reshaped_out.dtype) + renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype) + reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) def full_fisher_block(self): @@ -590,10 +712,10 @@ class KroneckerProductFB(FisherBlock): right_factor) -class EmbeddingKFACFB(KroneckerProductFB): +class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for embedding layers. - This FisherBlock is similar to EmbeddingKFACFB, except that its + This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its input factor is approximated by a diagonal matrix. In the case that each example references exactly one embedding, this approximation is exact. @@ -608,8 +730,6 @@ class EmbeddingKFACFB(KroneckerProductFB): Fisher information matrix to which this FisherBlock belongs. vocab_size: int. Size of vocabulary for this embedding layer. """ - self._inputs = [] - self._outputs = [] self._vocab_size = vocab_size super(EmbeddingKFACFB, self).__init__(layer_collection) @@ -624,41 +744,17 @@ class EmbeddingKFACFB(KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.EmbeddingInputKroneckerFactor, # - ((inputs,), self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # - (grads_list,)) - self._register_damped_input_and_output_inverses(damping) - - def tensors_to_compute_grads(self): - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) + inputs, grads_list = self._process_data(grads_list) - @property - def num_registered_minibatches(self): - return len(self._inputs) + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) + self._setup_damping(damping) -class FullyConnectedKFACBasicFB(KroneckerProductFB): +class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. This uses the Kronecker-factorized approximation from the original @@ -674,8 +770,6 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = [] - self._outputs = [] self._has_bias = has_bias super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) @@ -690,42 +784,19 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, ((inputs,), self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) - self._register_damped_input_and_output_inverses(damping) + self._setup_damping(damping) - def tensors_to_compute_grads(self): - return self._outputs - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - return len(self._inputs) - - -class ConvKFCBasicFB(KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. +class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): + """FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -748,23 +819,40 @@ class ConvKFCBasicFB(KroneckerProductFB): See equation 23 in https://arxiv.org/abs/1602.01407 for details. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None): """Creates a ConvKFCBasicFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [kernel_height, kernel_width, + kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (1-D of Tensor length 4). + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". """ - self._inputs = [] - self._outputs = [] - self._strides = tuple(strides) if isinstance(strides, list) else strides self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params @@ -773,145 +861,606 @@ class ConvKFCBasicFB(KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(), self._strides) + inputs, grads_list = self._process_data(grads_list) + self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._strides, self._padding, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - damping = normalize_damping(damping, self._num_locations) - self._register_damped_input_and_output_inverses(damping) - self._damping = damping + self._setup_damping(damping, normalization=self._num_locations) @property def _renorm_coeff(self): return self._num_locations - def tensors_to_compute_grads(self): - return self._outputs - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. +class DepthwiseConvDiagonalFB(ConvDiagonalFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvDiagonalFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. Args: - inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to - the convolution. - outputs: Tensor of shape [batch_size, height, width, output_size]. Layer - preactivations. + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. """ - self._inputs.append(inputs) - self._outputs.append(outputs) + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvDiagonalFB, self).__init__( + layer_collection=layer_collection, + params=params, + strides=strides, + padding=padding, + dilations=rate, + data_format=data_format) + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def multiply_matpower(self, vector, exp): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower( + conv2d_vector, exp) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvKFCBasicFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvKFCBasicFB, self).__init__( + layer_collection=layer_collection, + params=params, + padding=padding, + strides=strides, + dilation_rate=rate, + data_format=data_format, + extract_patches_fn="extract_image_patches") + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def multiply_matpower(self, vector, exp): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower( + conv2d_vector, exp) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - @property - def num_registered_minibatches(self): - return len(self._inputs) +def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with conv2d. -def _concat_along_batch_dim(tensor_list): - """Concatenate tensors along batch (first) dimension. + Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's + compatible with tf.nn.conv2d(). Args: - tensor_list: list of Tensors or list of tuples of Tensors. + filter: Tensor of shape [height, width, in_channels, channel_multiplier]. + name: None or str. Name of Op. Returns: - Tensor or tuple of Tensors. + Tensor of shape [height, width, in_channels, out_channels]. - Raises: - ValueError: If 'tensor_list' is empty. + """ + with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, channel_multiplier = ( + filter.shape.as_list()) + + results = [] + for i in range(in_channels): + # Slice out one in_channel's filter. Insert zeros around it to force it + # to affect that channel and that channel alone. + elements = [] + if i > 0: + elements.append( + array_ops.zeros( + [filter_height, filter_width, i, channel_multiplier])) + elements.append(filter[:, :, i:(i + 1), :]) + if i + 1 < in_channels: + elements.append( + array_ops.zeros([ + filter_height, filter_width, in_channels - (i + 1), + channel_multiplier + ])) + + # Concat along in_channel. + results.append( + array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) + + # Concat along out_channel. + return array_ops.concat(results, axis=-1, name="out_channel") + + +def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with depthwise_conv2d. + + Transforms a filter for use with tf.nn.conv2d() to one that's + compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along + the diagonal. + + Args: + filter: Tensor of shape [height, width, in_channels, out_channels]. + name: None or str. Name of Op. + + Returns: + Tensor of shape, + [height, width, in_channels, channel_multiplier] + Raises: + ValueError: if out_channels is not evenly divisible by in_channels. """ - if not tensor_list: - raise ValueError( - "Cannot concatenate Tensors if there are no Tensors to concatenate.") - - if isinstance(tensor_list[0], (tuple, list)): - # [(tensor1a, tensor1b), - # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b) - return tuple( - array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list)) - else: - # [tensor1, tensor2] --> tensor - return array_ops.concat(tensor_list, axis=0) + with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, out_channels = ( + filter.shape.as_list()) + + if out_channels % in_channels != 0: + raise ValueError("out_channels must be evenly divisible by in_channels.") + channel_multiplier = out_channels // in_channels + + results = [] + filter = array_ops.reshape(filter, [ + filter_height, filter_width, in_channels, in_channels, + channel_multiplier + ]) + for i in range(in_channels): + # Slice out output corresponding to the correct filter. + filter_slice = array_ops.reshape( + filter[:, :, i, i, :], + [filter_height, filter_width, 1, channel_multiplier]) + results.append(filter_slice) + + # Concat along out_channel. + return array_ops.concat(results, axis=-2, name="in_channels") + + +def maybe_tuple(obj): + if not isinstance(obj, list): + return obj + return tuple(obj) def num_conv_locations(input_shape, strides): """Returns the number of spatial locations a 2D Conv kernel is applied to. Args: - input_shape: list representing shape of inputs to the Conv layer. - strides: list representing strides for the Conv kernel. + input_shape: List of ints representing shape of inputs to + tf.nn.convolution(). + strides: List of ints representing strides along spatial dimensions as + passed in to tf.nn.convolution(). Returns: A scalar |T| denoting the number of spatial locations for the Conv layer. """ - return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + spatial_input_locations = np.prod(input_shape[1:-1]) + + if strides is None: + spatial_strides_divisor = 1 + else: + spatial_strides_divisor = np.prod(strides) + return spatial_input_locations // spatial_strides_divisor -class FullyConnectedMultiIndepFB(KroneckerProductFB): + +class InputOutputMultiTowerMultiUse(InputOutputMultiTower): + """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" + + def __init__(self, num_uses=None, *args, **kwargs): + self._num_uses = num_uses + super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process temporal/multi-use data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + It accepts the data in one of two initial formats. The first possible + format is where self._inputs is a list of list of Tensors. The first index + is tower, the second is use/time-step. grads_list, meanwhile, is a list + over sources of such lists of lists. + + The second possible data format is where self._inputs is a Tensor with + uses/times-steps folded into the batch dimension. i.e. it is a Tensor + of shape [num_uses * size_batch, ...] which represents a reshape of a + Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is + a list over sources of such Tensors. + + There are two possible formats which inputs and grads_list are transformed + into. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing + a single tensor (represented as a PartitionedTensor object) with all of + the data from the towers, as well as the uses/time-steps, concatenated + together. In this tensor the leading dimension is the batch and + use/time-step dimensions folded together (with 'use' being the major of + these two, so that the tensors can be thought of as reshapes of ones of + shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a + tuple over sources of such tensors. + + If TOWER_STRATEGY is "separate" the inputs are formatted into lists of + tensors over towers. Each of these tensors has a similar format to + the tensor produced by the "concat" option, except that each contains + only the data from a single tower. grads_list is similarly formatted + into a tuple over sources of such tuples. + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". + ValueError: If the given/initial format of self._inputs and grads_list + isn't recognized, or doesn't agree with self._num_uses. + """ + + inputs = self._inputs + + if isinstance(inputs[0], (list, tuple)): + num_uses = len(inputs[0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of inputs.") + else: + self._num_uses = num_uses + + # Check that all mini-batches/towers have the same number of uses + if not all(len(input_) == num_uses for input_ in inputs): + raise ValueError("Length of inputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + inputs = tuple(zip(*inputs)) + + # Flatten the two dimensions + inputs = nest.flatten(inputs) + + # Merge everything together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + # Now we perform the analogous processing for grads_list + if isinstance(grads_list[0][0], (list, tuple)): + num_uses = len(grads_list[0][0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of outputs, " + "or length of outputs is inconsistent with length of " + "inputs.") + else: + self._num_uses = num_uses + + if not all(len(grad) == num_uses for grads in grads_list + for grad in grads): + raise ValueError("Length of outputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) + + # Flatten the two dimensions, leaving the leading dimension (source) + # intact + grads_list = tuple(nest.flatten(grads) for grads in grads_list) + + # Merge inner dimensions together into PartitionedTensors. We package + # them in a singleton tuple since the factors will expect a list over + # towers + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + grads_list = tuple(tuple(utils.PartitionedTensor(grad) + for grad in grads) + for grads in grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + if self._num_uses is None: + raise ValueError("You must supply a value for the num_uses argument if " + "the number of uses cannot be inferred from inputs or " + "outputs arguments (e.g. if they are both given in the " + "single Tensor format, instead of as lists of Tensors.") + + return inputs, grads_list + + +class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters. + + This class implements the "independence across time" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb """ - def __init__(self, layer_collection, inputs, outputs, has_bias=False): + def __init__(self, layer_collection, has_bias=False, num_uses=None): """Creates a FullyConnectedMultiIndepFB block. Args: layer_collection: LayerCollection instance. - inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, - inputs_size]. - outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, - outputs_size]. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's parameters. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) """ - - assert len(inputs) == len(outputs) - # We need to make sure inputs and outputs are tuples and not lists so that - # they get hashed by layer_collection.make_or_get_factor properly. - self._inputs = tuple(inputs) - self._outputs = tuple(outputs) self._has_bias = has_bias - self._num_uses = len(inputs) - super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) - - @property - def num_registered_minibatches(self): - # TODO(b/69411207): Add support for registering additional minibatches. - return 1 + super(FullyConnectedMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, - ((self._inputs,), self._has_bias)) + ((inputs,), self._num_uses, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - damping = normalize_damping(damping, self._num_uses) - self._register_damped_input_and_output_inverses(damping) + self._setup_damping(damping, normalization=self._num_uses) @property def _renorm_coeff(self): - return self._num_uses + return float(self._num_uses) + + +class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Similar to ConvKFCBasicFB except that this version supports multiple + uses/time-steps via a standard independence approximation. Similar to the + "independence across time" used in FullyConnectedMultiIndepFB but generalized + in the obvious way to conv layers. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + num_uses=None): + """Creates a ConvKFCBasicMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization= + (self._num_locations * self._num_uses)) + + @property + def _renorm_coeff(self): + return self._num_locations * self._num_uses - def tensors_to_compute_grads(self): - return self._outputs - def num_inputs(self): - return len(self._inputs) +class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """K-FAC FisherBlock for embedding layers used multiple times in the graph. + + Similar to EmbeddingKFACFB except that this version supports multiple uses + of the parameter within a single model. These uses could correspond to time + steps in an RNN architecture, but they don't have to. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size, num_uses=None): + """Creates a EmbeddingKFACMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with time folded into the batch + dimension (instead of time being a list dimension). (Default: None) + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of list of Tensors. grads_list[i][j][k] is the + gradient of the loss with respect to 'outputs' from source 'i', + tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape + [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) class SeriesFBApproximation(enum.IntEnum): @@ -920,34 +1469,35 @@ class SeriesFBApproximation(enum.IntEnum): option2 = 2 -class FullyConnectedSeriesFB(FisherBlock): +class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters across time. - See the following preprint for details: + This class implements the "Option 1" and "Option 2" approximation from the + following paper: https://openreview.net/pdf?id=HyMTkQZAb See the end of the appendix of the paper for a pseudo-code of the - algorithm being implemented by multiply_inverse here. Note that we are + algorithm being implemented by multiply_matpower here. Note that we are using pre-computed versions of certain matrix-matrix products to speed things up. This is explicitly explained wherever it is done. """ def __init__(self, layer_collection, - inputs, - outputs, has_bias=False, + num_uses=None, option=SeriesFBApproximation.option2): """Constructs a new `FullyConnectedSeriesFB`. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. - inputs: List of tensors of shape [batch_size, input_size]. - Inputs to the layer. - outputs: List of tensors of shape [batch_size, input_size]. - Outputs of the layer (before activations). has_bias: Whether the layer includes a bias parameter. + num_uses: int or None. Number of time-steps over which the layer + is used. Only required if the data is formatted with time folded into + the batch dimension (instead of time being a list dimension). + (Default: None) option: A `SeriesFBApproximation` specifying the simplifying assumption to be used in this block. `option1` approximates the cross-covariance over time as a symmetric matrix, while `option2` makes @@ -955,48 +1505,58 @@ class FullyConnectedSeriesFB(FisherBlock): 3.5 of the paper for more details. """ - assert len(inputs) == len(outputs) - # We need to make sure inputs and outputs are tuples and not lists so that - # they get hashed by layer_collection.make_or_get_factor properly. - self._inputs = tuple(inputs) - self._outputs = tuple(outputs) self._has_bias = has_bias - self._num_timesteps = len(inputs) self._option = option - super(FullyConnectedSeriesFB, self).__init__(layer_collection) + super(FullyConnectedSeriesFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + @property + def _num_timesteps(self): + return self._num_uses @property - def num_registered_minibatches(self): - # TODO(b/69411207): Add support for registering additional minibatches. - return 1 + def _renorm_coeff(self): + # This should no longer be used since the multiply_X functions from the base + # class have been overridden + assert False def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) + self._input_factor.register_cov_dt1() self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._output_factor.register_cov_dt1() + + self._setup_damping(damping, normalization=self._num_uses) - damping = normalize_damping(damping, self._num_timesteps) - self._damping_input, self._damping_output = compute_pi_adjusted_damping( - self._input_factor.get_cov(), - self._output_factor.get_cov(), - damping**0.5) + def register_matpower(self, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") if self._option == SeriesFBApproximation.option1: - self._input_factor.register_option1quants(self._damping_input) - self._output_factor.register_option1quants(self._damping_output) + self._input_factor.register_option1quants(self._input_damping_func) + self._output_factor.register_option1quants(self._output_damping_func) elif self._option == SeriesFBApproximation.option2: - self._input_factor.register_option2quants(self._damping_input) - self._output_factor.register_option2quants(self._damping_output) + self._input_factor.register_option2quants(self._input_damping_func) + self._output_factor.register_option2quants(self._output_damping_func) else: raise ValueError( "Unrecognized FullyConnectedSeriesFB approximation: {}".format( self._option)) - def multiply_inverse(self, vector): + def multiply_matpower(self, vector, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") + # pylint: disable=invalid-name Z = utils.layer_params_to_mat2d(vector) @@ -1008,8 +1568,10 @@ class FullyConnectedSeriesFB(FisherBlock): if self._option == SeriesFBApproximation.option1: # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. - L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) - L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + L_A, psi_A = self._input_factor.get_option1quants( + self._input_damping_func) + L_G, psi_G = self._output_factor.get_option1quants( + self._output_damping_func) def gamma(x): # We are assuming that each case has the same number of time-steps. @@ -1046,9 +1608,10 @@ class FullyConnectedSeriesFB(FisherBlock): # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. - P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + P_A, K_A, mu_A = self._input_factor.get_option2quants( + self._input_damping_func) P_G, K_G, mu_G = self._output_factor.get_option2quants( - self._damping_output) + self._output_damping_func) # Our approach differs superficially from the pseudo-code in the paper # in order to reduce the total number of matrix-matrix multiplies. @@ -1101,12 +1664,3 @@ class FullyConnectedSeriesFB(FisherBlock): return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name - - def multiply(self, vector): - raise NotImplementedError - - def tensors_to_compute_grads(self): - return self._outputs - - def num_inputs(self): - return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 603d8b8b210279ee6d8f1de0ce10869fde23f4d9..353e1c6abb738cf3ef59d3e188da2727b712b21a 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -36,6 +36,8 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages +from tensorflow.python.util import nest + # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). @@ -53,36 +55,25 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 -# Colocate the covariance ops and variables with the input tensors for each -# factor. -COLOCATE_COV_OPS_WITH_INPUTS = True - - -@contextlib.contextmanager -def maybe_colocate_with(op): - """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`.""" - if COLOCATE_COV_OPS_WITH_INPUTS: - if isinstance(op, (list, tuple)): - with tf_ops.colocate_with(op[0]): - yield - else: - with tf_ops.colocate_with(op): - yield - else: - yield +# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data +# passed to the factors from the blocks will be concatenated across towers +# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over +# towers will be passed in, and the factors will iterate over this and do the +# cov computations separately for each one, averaging the results together. +TOWER_STRATEGY = "concat" def set_global_constants(init_covariances_at_zero=None, zero_debias=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, - colocate_cov_ops_with_inputs=None): + tower_strategy=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD - global COLOCATE_COV_OPS_WITH_INPUTS + global TOWER_STRATEGY if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero @@ -92,8 +83,8 @@ def set_global_constants(init_covariances_at_zero=None, EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold - if colocate_cov_ops_with_inputs is not None: - COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs + if tower_strategy is not None: + TOWER_STRATEGY = tower_strategy def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument @@ -112,6 +103,15 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) +@contextlib.contextmanager +def place_on_device(device): + if device is not None and len(device): + with tf_ops.device(device): + yield + else: + yield + + def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. @@ -181,7 +181,9 @@ def scope_string_from_params(params): name_parts = [] for param in params: - if isinstance(param, (tuple, list)): + if param is None: + name_parts.append("None") + elif isinstance(param, (tuple, list)): if all([isinstance(p, int) for p in param]): name_parts.append("-".join([str(p) for p in param])) else: @@ -190,6 +192,8 @@ def scope_string_from_params(params): name_parts.append(str(param)) elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) + elif isinstance(param, utils.PartitionedTensor): + name_parts.append(scope_string_from_name(param.tensors)) else: raise ValueError("Encountered an unsupported param type {}".format( type(param))) @@ -207,6 +211,22 @@ def scalar_or_tensor_to_string(val): return repr(val) if np.isscalar(val) else scope_string_from_name(val) +def list_to_string(lst): + return "_".join(val if isinstance(val, six.string_types) + else scalar_or_tensor_to_string(val) for val in lst) + + +def graph_func_to_id(func): + """Returns a hashable object that represents func's computation.""" + # TODO(b/74201126): replace with Topohash of func's output + return func.func_id + + +def graph_func_to_string(func): + # TODO(b/74201126): replace with Topohash of func's output + return list_to_string(func.func_id) + + @six.add_metaclass(abc.ABCMeta) class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. @@ -223,13 +243,10 @@ class FisherFactor(object): Note that for blocks that aren't based on approximations, a 'factor' can be the entire block itself, as is the case for the diagonal and full representations. - - Subclasses must implement the _compute_new_cov() method, and the _var_scope - and _cov_shape properties. """ def __init__(self): - self.instantiate_covariance() + self._cov = None @abc.abstractproperty def _var_scope(self): @@ -240,6 +257,10 @@ class FisherFactor(object): """ pass + @property + def name(self): + return self._var_scope + @abc.abstractproperty def _cov_shape(self): """The shape of the variable backing this FisherFactor.""" @@ -257,6 +278,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _num_towers(self): + pass + @abc.abstractproperty def _dtype(self): """dtype for variable backing this factor.""" @@ -267,8 +292,9 @@ class FisherFactor(object): """Function for initializing covariance variable.""" return covariance_initializer - def instantiate_covariance(self): - """Instantiates the covariance Variable as the instance member _cov.""" + def instantiate_cov_variables(self): + """Makes the internal cov variable(s).""" + assert self._cov is None with variable_scope.variable_scope(self._var_scope): self._cov = variable_scope.get_variable( "cov", @@ -278,12 +304,14 @@ class FisherFactor(object): dtype=self._dtype) @abc.abstractmethod - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): """Computes minibatch-estimated covariance for a single source. Args: - idx: int in [0, self._num_sources). Which source to use when estimating - covariance. + source: int in [0, self._num_sources). Which source to use when computing + the cov update. + tower: int in [0, self._num_towers). Which tower to use when computing + the cov update. Returns: Tensor of same shape as self.get_cov_var(). @@ -298,22 +326,33 @@ class FisherFactor(object): Returns: An Op for updating the covariance Variable referenced by _cov. """ - new_cov_contribs = tuple(self._compute_new_cov(idx) - for idx in range(self._num_sources)) - # This gets the job done but we might want a better solution in the future. - # In particular, we could have a separate way of specifying where the - # the cov variables finally end up, independent of where their various - # contributions are computed. Right now these are the same thing, but in - # the future we might want to perform the cov computations on each tower, - # so that each tower will be considered a "source" (allowing us to reuse - # the existing "source" code for this). - with maybe_colocate_with(new_cov_contribs[0]): - new_cov = math_ops.add_n(new_cov_contribs) - # Synchronize value across all TPU cores. - if utils.on_tpu(): - new_cov = utils.cross_replica_mean(new_cov) - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + new_cov_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + device = (self._get_data_device(tower) + if TOWER_STRATEGY == "separate" else None) + with place_on_device(device): + new_cov_contribs.append(self._compute_new_cov(source, tower)) + + new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) + + # I have no idea if the TPU code below is still correct since I don't know + # what it actually does. Also, this code is not present in some of the + # other versions of make_covariance_update_op. Does it matter? + # Synchronize value across all TPU cores. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + + @abc.abstractmethod + def _get_data_device(self, tower): + pass + + @abc.abstractmethod + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + pass @abc.abstractmethod def make_inverse_update_ops(self): @@ -341,70 +380,47 @@ class FisherFactor(object): return self._cov @abc.abstractmethod - def left_multiply(self, x, damping): - """Multiplies 'x' by the damped covariance of this factor. + def left_multiply_matpower(self, x, exp, damping_func): + """Left multiplies 'x' by matrix power of this factor (w/ damping applied). - Let C be the covariance matrix this factor represents, and - D = C + damping * I be its damped variant. This method calculates - matmul(D, vec(x)). + This calculation is essentially: + (C + damping * I)**exp * x + where * is matrix-multiplication, ** is matrix power, I is the identity + matrix, and C is the matrix represented by this factor. - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. - - Returns: - Tensor of same shape as 'x'. - """ - pass - - @abc.abstractmethod - def right_multiply(self, x, damping): - """Multiplies 'x' by the damped covariance of this factor. - - Let C be the covariance matrix this factor represents, and - D = C + damping * I be its damped variant. This method calculates - matmul(vec(x), D). + x can represent either a matrix or a vector. For some factors, 'x' might + represent a vector but actually be stored as a 2D matrix for convenience. Args: x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. + exp: float. The matrix exponent to use. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). Returns: - Tensor of same shape as 'x'. + Tensor of same shape as 'x' representing the result of the multiplication. """ pass @abc.abstractmethod - def left_multiply_inverse(self, x, damping): - """Multiplies 'x' by damped inverse of this factor. - - Let C be the covariance matrix this factor represents and - E = inv(C + damping * I) be its damped inverse. This method calculates - matmul(E, vec(x)). - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. - - Returns: - Tensor of same shape as 'x'. - """ - pass + def right_multiply_matpower(self, x, exp, damping_func): + """Right multiplies 'x' by matrix power of this factor (w/ damping applied). - @abc.abstractmethod - def right_multiply_inverse(self, x, damping): - """Multiplies 'x' by damped inverse of this factor. + This calculation is essentially: + x * (C + damping * I)**exp + where * is matrix-multiplication, ** is matrix power, I is the identity + matrix, and C is the matrix represented by this factor. - Let C be the covariance matrix this factor represents and - E = inv(C + damping * I) be its damped inverse. This method calculates - matmul(vec(x), E). + Unlike left_multiply_matpower, x will always be a matrix. Args: x: Tensor. Represents a single vector. Shape depends on implementation. - damping: 0-D Tensor. Damping to add to C's diagonal. + exp: float. The matrix exponent to use. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). Returns: - Tensor of same shape as 'x'. + Tensor of same shape as 'x' representing the result of the multiplication. """ pass @@ -428,47 +444,52 @@ class InverseProvidingFactor(FisherFactor): # the latter. def __init__(self): - self._inverses_by_damping = {} - self._matpower_by_exp_and_damping = {} + self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } + self._matpower_registrations = set() # { (float, hashable) } self._eigendecomp = None + self._damping_funcs_by_id = {} # {hashable: lambda} super(InverseProvidingFactor, self).__init__() - def register_damped_inverse(self, damping): - """Registers a damped inverse needed by a FisherBlock. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_inverse. + def _register_damping(self, damping_func): + damping_id = graph_func_to_id(damping_func) + if damping_id not in self._damping_funcs_by_id: + self._damping_funcs_by_id[damping_id] = damping_func + return damping_id - Args: - damping: The damping value (float or Tensor) for this factor. - """ - if damping not in self._inverses_by_damping: - damping_string = scalar_or_tensor_to_string(damping) - with variable_scope.variable_scope(self._var_scope): - inv = variable_scope.get_variable( - "inv_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - self._inverses_by_damping[damping] = inv + def register_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + self.register_matpower(-1, damping_func) - def register_matpower(self, exp, damping): - """Registers a matrix power needed by a FisherBlock. + def register_matpower(self, exp, damping_func): + """Registers a matrix power to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_matpower. Args: - exp: The exponent (float or Tensor) to raise the matrix to. - damping: The damping value (float or Tensor). + exp: float. The exponent to use in the matrix power. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). """ - if (exp, damping) not in self._matpower_by_exp_and_damping: + if exp == 1.0: + # We don't register these. The user shouldn't even be calling this + # function with exp = 1.0. + return + + damping_id = self._register_damping(damping_func) + + if (exp, damping_id) not in self._matpower_registrations: + self._matpower_registrations.add((exp, damping_id)) + + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + + for (exp, damping_id) in self._matpower_registrations: exp_string = scalar_or_tensor_to_string(exp) - damping_string = scalar_or_tensor_to_string(damping) + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) with variable_scope.variable_scope(self._var_scope): matpower = variable_scope.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), @@ -476,34 +497,35 @@ class InverseProvidingFactor(FisherFactor): shape=self._cov_shape, trainable=False, dtype=self._dtype) - self._matpower_by_exp_and_damping[(exp, damping)] = matpower + assert (exp, damping_id) not in self._matpower_by_exp_and_damping + self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] - # We do this to ensure that we don't reuse the eigendecomp from old calls - # to make_inverse_update_ops that may be placed on different devices. This - # can happen is the user has both a permanent and lazily constructed - # version of the inverse ops (and only uses one of them). - self.reset_eigendecomp() + num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping + if exp == -1) + + num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses + + other_matrix_power_registered = num_other_matpower >= 1 - num_inverses = len(self._inverses_by_damping) - matrix_power_registered = bool(self._matpower_by_exp_and_damping) use_eig = ( - self._eigendecomp or matrix_power_registered or + self._eigendecomp or other_matrix_power_registered or num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + # We precompute these so we don't need to evaluate them multiple times (for + # each matrix power that uses them) + damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]() + for damping_id in self._damping_funcs_by_id} + if use_eig: eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence - for damping, inv in self._inverses_by_damping.items(): - ops.append( - inv.assign( - math_ops.matmul(eigenvectors / (eigenvalues + damping), - array_ops.transpose(eigenvectors)))) - - for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + damping = damping_value_by_id[damping_id] ops.append( matpower.assign( math_ops.matmul(eigenvectors * @@ -512,28 +534,31 @@ class InverseProvidingFactor(FisherFactor): # These ops share computation and should be run on a single device. ops = [control_flow_ops.group(*ops)] else: - for damping, inv in self._inverses_by_damping.items(): - ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + assert exp == -1 + damping = damping_value_by_id[damping_id] + ops.append(matpower.assign(utils.posdef_inv(self._cov, damping))) + self._eigendecomp = False return ops - def get_damped_inverse(self, damping): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - return self._inverses_by_damping[damping] + def get_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + damping_id = graph_func_to_id(damping_func) + return self._matpower_by_exp_and_damping[(-1, damping_id)] - def get_matpower(self, exp, damping): + def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). - return self._matpower_by_exp_and_damping[(exp, damping)] + damping_id = graph_func_to_id(damping_func) + return self._matpower_by_exp_and_damping[(exp, damping_id)] def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" - # Unlike get_inverse and get_matpower this doesn't retrieve a stored - # variable, but instead always computes a fresh version from the current - # value of get_cov(). + # Unlike get_matpower this doesn't retrieve a stored variable, but instead + # always computes a fresh version from the current value of get_cov(). if not self._eigendecomp: eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) @@ -546,63 +571,42 @@ class InverseProvidingFactor(FisherFactor): return self._eigendecomp - def reset_eigendecomp(self): - self._eigendecomp = None - def get_cov(self): # Variable contains full covariance matrix. return self.get_cov_var() - def left_multiply(self, x, damping): - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping * array_ops.eye(n) - + def left_multiply_matpower(self, x, exp, damping_func): if isinstance(x, tf_ops.IndexedSlices): - raise NotImplementedError( - "Left-multiply not yet supported for IndexedSlices.") + raise ValueError("Left-multiply not yet supported for IndexedSlices.") - if len(x.shape) != 2: + if x.shape.ndims != 2: raise ValueError( "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." % (x,)) - return math_ops.matmul(damped_cov, x) + if exp == 1: + return math_ops.matmul(self.get_cov(), x) + damping_func() * x - def right_multiply(self, x, damping): - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping * array_ops.eye(n) + return math_ops.matmul(self.get_matpower(exp, damping_func), x) + def right_multiply_matpower(self, x, exp, damping_func): if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_sparse_dense(x, damped_cov) - - if len(x.shape) != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) + if exp == 1: + n = self.get_cov().shape[0] + damped_cov = self.get_cov() + damping_func() * array_ops.eye(n) + return utils.matmul_sparse_dense(x, damped_cov) - return math_ops.matmul(x, damped_cov) - - def left_multiply_inverse(self, x, damping): - if isinstance(x, tf_ops.IndexedSlices): - raise ValueError("Left-multiply not yet supported for IndexedSlices.") + return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func)) if x.shape.ndims != 2: raise ValueError( "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." % (x,)) - return math_ops.matmul(self.get_damped_inverse(damping), x) + if exp == 1: + return math_ops.matmul(x, self.get_cov()) + damping_func() * x - def right_multiply_inverse(self, x, damping): - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping)) - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - return math_ops.matmul(x, self.get_damped_inverse(damping)) + return math_ops.matmul(x, self.get_matpower(exp, damping_func)) class FullFactor(InverseProvidingFactor): @@ -622,7 +626,7 @@ class FullFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_full/" + scope_string_from_params( + return "ff_full_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property @@ -635,17 +639,25 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): + assert tower == 0 + # This will be a very basic rank 1 estimate - with maybe_colocate_with(self._params_grads[idx]): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) - return ((params_grads_flat * array_ops.transpose( - params_grads_flat)) / math_ops.cast(self._batch_size, - params_grads_flat.dtype)) + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None class DiagonalFactor(FisherFactor): @@ -656,6 +668,7 @@ class DiagonalFactor(FisherFactor): """ def __init__(self): + self._damping_funcs_by_id = {} # { hashable: lambda } super(DiagonalFactor, self).__init__() @property @@ -665,43 +678,30 @@ class DiagonalFactor(FisherFactor): def make_inverse_update_ops(self): return [] + def instantiate_inv_variables(self): + pass + def get_cov(self): # self.get_cov() could be any shape, but it must have one entry per # parameter. Flatten it into a vector. cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) return array_ops.diag(cov_diag_vec) - def left_multiply(self, x, damping): - damped_cov = self.get_cov_var() + damping - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x) - - if x.shape != damped_cov.shape: - raise ValueError("x (%s) and cov (%s) must have same shape." % - (x, damped_cov)) - - return damped_cov * x - - def right_multiply(self, x, damping): - raise NotImplementedError("Only left-multiply is currently supported.") - - def left_multiply_inverse(self, x, damping): - inverse = 1. / (self.get_cov_var() + damping) + def left_multiply_matpower(self, x, exp, damping_func): + matpower = (self.get_cov_var() + damping_func())**exp if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x) + return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x) - if x.shape != inverse.shape: + if x.shape != matpower.shape: raise ValueError("x (%s) and cov (%s) must have same shape." % - (x, inverse)) - - return inverse * x + (x, matpower)) + return matpower * x - def right_multiply_inverse(self, x, damping): + def right_multiply_matpower(self, x, exp, damping_func): raise NotImplementedError("Only left-multiply is currently supported.") - def register_damped_inverse(self, damping): - # DiagonalFactors don't keep explicit inverses. + def register_matpower(self, exp, damping_func): pass @@ -730,7 +730,7 @@ class NaiveDiagonalFactor(DiagonalFactor): @property def _var_scope(self): - return "ff_naivediag/" + scope_string_from_params( + return "ff_naivediag_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property @@ -743,15 +743,23 @@ class NaiveDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._params_grads[idx]): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) - return (math_ops.square(params_grads_flat) / math_ops.cast( - self._batch_size, params_grads_flat.dtype)) + def _compute_new_cov(self, source, tower): + assert tower == 0 + + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None class EmbeddingInputKroneckerFactor(DiagonalFactor): @@ -772,8 +780,8 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): """Instantiate EmbeddingInputKroneckerFactor. Args: - input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype - int32. Indices into embedding matrix. + input_ids: List of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. List index is tower. vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. dtype: dtype for covariance statistics. Must be a floating point type. Defaults to float32. @@ -786,7 +794,7 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): @property def _var_scope(self): - return "ff_diag_embedding/" + scope_string_from_params(self._input_ids) + return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) @property def _cov_shape(self): @@ -794,42 +802,51 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): @property def _num_sources(self): + return 1 + + @property + def _num_towers(self): return len(self._input_ids) @property def _dtype(self): return self._cov_dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._input_ids): - input_ids = self._input_ids[idx] - if len(input_ids.shape) > 2: - raise ValueError( - "Input to embeddings must have rank <= 2. Found rank %d." % len( - input_ids.shape)) + def _compute_new_cov(self, source, tower): + assert source == 0 + + input_ids = self._input_ids[tower] + + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) + + batch_size = array_ops.shape(input_ids)[0] - batch_size = array_ops.shape(input_ids)[0] + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] - # Transform indices into one-hot vectors. - # - # TODO(b/72714822): There must be a faster way to construct the diagonal - # covariance matrix! This operation is O(batch_size * vocab_size), where - # it should be O(batch_size * input_size). - flat_input_ids = array_ops.reshape(input_ids, [-1]) - one_hots = array_ops.one_hot(flat_input_ids, - self._vocab_size) # [?, vocab_size] + # Take average across examples. Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - # Take average across examples. Note that, because all entries have - # magnitude zero or one, there's no need to square the entries. - # - # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation - # within an example such as average. - # - # TODO(b/72714822): Support for partitioned embeddings. - new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + return new_cov - return new_cov + def _get_data_device(self, tower): + return self._input_ids[tower].device class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -850,58 +867,75 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): """Instantiate FullyConnectedDiagonalFactor. Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to fully - connected layer. - outputs_grads: List of Tensors of shape [batch_size, output_size]. - Gradient of loss with respect to layer's preactivations. + inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this + layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, output_size], + which are the gradients of the loss with respect to the layer's + outputs. First index is source, second is tower. + has_bias: bool. If True, append '1' to each input. """ self._inputs = inputs self._has_bias = has_bias self._outputs_grads = outputs_grads - self._batch_size = array_ops.shape(inputs)[0] self._squared_inputs = None super(FullyConnectedDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_diagfc/" + scope_string_from_params( - (self._inputs,) + tuple(self._outputs_grads)) + return "ff_diagfc_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): - input_size = self._inputs.shape[1] + self._has_bias - output_size = self._outputs_grads[0].shape[1] + input_size = self._inputs[0].shape[1] + self._has_bias + output_size = self._outputs_grads[0][0].shape[1] return [input_size, output_size] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._outputs_grads[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + self._squared_inputs = [] + for tower in range(self._num_towers): + inputs = self._inputs[tower] + + with place_on_device(self._get_data_device(tower)): + if self._has_bias: + inputs = append_homog(inputs) + self._squared_inputs.append(math_ops.square(inputs)) + + return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( + ema_decay) + + def _compute_new_cov(self, source, tower): + batch_size = array_ops.shape(self._squared_inputs[tower])[0] + outputs_grad = self._outputs_grads[source][tower] - def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. - with maybe_colocate_with(self._outputs_grads[idx]): - # We only need to compute squared_inputs once - if self._squared_inputs is None: - inputs = self._inputs - if self._has_bias: - inputs = append_homog(self._inputs) - self._squared_inputs = math_ops.square(inputs) + new_cov = math_ops.matmul( + self._squared_inputs[tower], + math_ops.square(outputs_grad), + transpose_a=True) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + return new_cov - new_cov = math_ops.matmul( - self._squared_inputs, - math_ops.square(self._outputs_grads[idx]), - transpose_a=True) - new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) - return new_cov + def _get_data_device(self, tower): + return self._inputs[tower].device class ConvDiagonalFactor(DiagonalFactor): @@ -913,36 +947,67 @@ class ConvDiagonalFactor(DiagonalFactor): filter_shape, strides, padding, + data_format=None, + dilations=None, has_bias=False): """Creates a ConvDiagonalFactor object. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. - Input activations to this layer. - outputs_grads: Tensor of shape [batch_size, height, width, out_channels]. - Per-example gradients to the loss with respect to the layer's output - preactivations. + inputs: List of Tensors of shape [batch_size, height, width, in_channels]. + Input activations to this layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, + height, width, out_channels], which are the gradients of the loss + with respect to the layer's outputs. First index is source, second + index is tower. filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, out_channels). Represents shape of kernel used in this layer. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (1-D of Tensor length 4). + data_format: None or str. Format of conv2d inputs. + dilations: None or tuple of 4 ints. has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + + Raises: + ValueError: If inputs, output_grads, and filter_shape do not agree on + in_channels or out_channels. + ValueError: If strides, dilations are not length-4 lists of ints. + ValueError: If data_format does not put channel last. """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + if any(input_.shape.ndims != 4 for input_ in inputs): + raise ValueError("inputs must be a list of 4-D Tensors.") + if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): + raise ValueError("inputs and filter_shape must agree on in_channels.") + for i, outputs_grad in enumerate(outputs_grads): + if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): + raise ValueError("outputs[%d] must be 4-D Tensor." % i) + if any(output_grad.shape.as_list()[-1] != filter_shape[-1] + for output_grad in outputs_grad): + raise ValueError( + "outputs[%d] and filter_shape must agree on out_channels." % i) + if len(strides) != 4: + raise ValueError("strides must be length-4 list of ints.") + if dilations is not None and len(dilations) != 4: + raise ValueError("dilations must be length-4 list of ints.") + self._inputs = inputs + self._outputs_grads = outputs_grads self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._data_format = data_format + self._dilations = dilations self._has_bias = has_bias - self._outputs_grads = outputs_grads self._patches = None super(ConvDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_convdiag/" + scope_string_from_name( - (self._inputs,) + tuple(self._outputs_grads)) + return "ff_convdiag_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): @@ -956,43 +1021,50 @@ class ConvDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._inputs[0].dtype def make_covariance_update_op(self, ema_decay): - with maybe_colocate_with(self._inputs): - filter_height, filter_width, _, _ = self._filter_shape - - # TODO(b/64144716): there is potential here for a big savings in terms - # of memory use. - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) - - if self._has_bias: - patches = append_homog(patches) + filter_height, filter_width, _, _ = self._filter_shape - self._patches = patches + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + if self._dilations is None: + rates = (1, 1, 1, 1) + else: + rates = tuple(self._dilations) + + self._patches = [] + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + patches = array_ops.extract_image_patches( + self._inputs[tower], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=rates, + padding=self._padding) - op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + if self._has_bias: + patches = append_homog(patches) - self._patches = None + self._patches.append(patches) - return op + return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - outputs_grad = self._outputs_grads[idx] - batch_size = array_ops.shape(self._patches)[0] + def _compute_new_cov(self, source, tower): + patches = self._patches[tower] + batch_size = array_ops.shape(patches)[0] + outputs_grad = self._outputs_grads[source][tower] - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov + return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". @@ -1002,6 +1074,9 @@ class ConvDiagonalFactor(DiagonalFactor): outputs_grad) return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + def _get_data_device(self, tower): + return self._inputs[tower].device + class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. @@ -1013,8 +1088,9 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Instantiate FullyConnectedKroneckerFactor. Args: - tensors: List of Tensors of shape [batch_size, n]. Represents either a - layer's inputs or its output's gradients. + tensors: List of list of Tensors, each of shape [batch_size, n]. The + Tensors are typically either a layer's inputs or its output's gradients. + The first list index is source, the second is tower. has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of @@ -1025,28 +1101,34 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_fckron/" + scope_string_from_params( - [self._tensors, self._has_bias]) + return "ff_fckron_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + (self._has_bias,)) @property def _cov_shape(self): - size = self._tensors[0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size, size] @property def _num_sources(self): return len(self._tensors) + @property + def _num_towers(self): + return len(self._tensors[0]) + @property def _dtype(self): - return self._tensors[0].dtype + return self._tensors[0][0].dtype + + def _compute_new_cov(self, source, tower): + tensor = self._tensors[source][tower] + if self._has_bias: + tensor = append_homog(tensor) + return compute_cov(tensor) - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._tensors[idx]): - tensor = self._tensors[idx] - if self._has_bias: - tensor = append_homog(tensor) - return compute_cov(tensor) + def _get_data_device(self, tower): + return self._tensors[0][tower].device class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -1062,84 +1144,133 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def __init__(self, inputs, filter_shape, - strides, padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, has_bias=False): """Initializes ConvInputKroneckerFactor. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs - to layer. - filter_shape: 1-D Tensor of length 4. Contains [kernel_height, - kernel_width, in_channels, out_channels]. - strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride, - width_stride, in_channel_stride]. + inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., + in_channels]. Inputs to layer. List index is tower. + filter_shape: List of ints. Contains [..spatial_filter_size.., + in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. """ + self._inputs = inputs self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._dilation_rate = dilation_rate + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = has_bias - self._inputs = inputs + super(ConvInputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convinkron/" + scope_string_from_params([ - self._inputs, self._filter_shape, self._strides, self._padding, - self._has_bias - ]) + return "ff_convinkron_" + scope_string_from_params( + tuple(self._inputs) + + tuple((self._filter_shape, self._strides, self._padding, + self._dilation_rate, self._data_format, self._has_bias))) @property def _cov_shape(self): - filter_height, filter_width, in_channels, _ = self._filter_shape - size = filter_height * filter_width * in_channels + self._has_bias + spatial_filter_shape = self._filter_shape[0:-2] + in_channels = self._filter_shape[-2] + size = np.prod(spatial_filter_shape) * in_channels + self._has_bias return [size, size] @property def _num_sources(self): return 1 + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._inputs.dtype + return self._inputs[0].dtype - def _compute_new_cov(self, idx=0): - if idx != 0: - raise ValueError("ConvInputKroneckerFactor only supports idx = 0") + def _compute_new_cov(self, source, tower): + assert source == 0 - with maybe_colocate_with(self._inputs): - filter_height, filter_width, in_channels, _ = self._filter_shape + inputs = self._inputs[tower] - # TODO(b/64144716): there is potential here for a big savings in terms of - # memory use. + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + if self._extract_patches_fn in [None, "extract_convolution_patches"]: + patches = utils.extract_convolution_patches( + inputs, + self._filter_shape, + padding=self._padding, + strides=self._strides, + dilation_rate=self._dilation_rate, + data_format=self._data_format) + + elif self._extract_patches_fn == "extract_image_patches": + assert inputs.shape.ndims == 4 + assert len(self._filter_shape) == 4 + assert len(self._strides) == 4, self._strides + if self._dilation_rate is None: + rates = [1, 1, 1, 1] + else: + rates = self._dilation_rate + assert len(rates) == 4 + assert rates[0] == rates[-1] == 1 patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], + inputs, + ksizes=[1] + list(self._filter_shape[0:-2]) + [1], strides=self._strides, - rates=[1, 1, 1, 1], + rates=rates, padding=self._padding) - flatten_size = (filter_height * filter_width * in_channels) - # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde - # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), - # where M = minibatch size, |T| = number of spatial locations, - # |Delta| = number of spatial offsets, and J = number of input maps - # for convolutional layer l. - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - # We append a homogenous coordinate to patches_flat if the layer has - # bias parameters. This gives us [[A_l]]_H from the paper. - if self._has_bias: - patches_flat = append_homog(patches_flat) - # We call compute_cov without passing in a normalizer. compute_cov uses - # the first dimension of patches_flat i.e. M|T| as the normalizer by - # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with - # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from - # the paper but has a different scale here for consistency with - # ConvOutputKroneckerFactor. - # (Tilde omitted over A for clarity.) - return compute_cov(patches_flat) + elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": + assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] + assert self._filter_shape[0] == self._filter_shape[1] == 1 + patches = utils.extract_pointwise_conv2d_patches( + inputs, self._filter_shape, data_format=None) + + else: + raise NotImplementedError(self._extract_patches_fn) + + flatten_size = np.prod(self._filter_shape[0:-1]) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. + if self._has_bias: + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) + + def _get_data_device(self, tower): + return self._inputs[tower].device class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -1153,20 +1284,28 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, data_format=None): """Initializes ConvOutputKroneckerFactor. Args: - outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, height, width, out_channels]. + outputs_grads: List of list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. First list index + is source, the second is tower. + data_format: None or str. Format of outputs_grads. + + Raises: + ValueError: If channels are not final dimension. """ - self._out_channels = outputs_grads[0].shape.as_list()[3] + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + self._out_channels = outputs_grads[0][0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convoutkron/" + scope_string_from_params(self._outputs_grads) + return "ff_convoutkron_" + scope_string_from_params( + nest.flatten(self._outputs_grads)) @property def _cov_shape(self): @@ -1177,134 +1316,146 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._outputs_grads[0]) + @property def _dtype(self): - return self._outputs_grads[0].dtype - - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - # reshaped_tensor below is the matrix DS_l defined in the KFC paper - # (tilde omitted over S for clarity). It has shape M|T| x I, where - # M = minibatch size, |T| = number of spatial locations, and - # I = number of output maps for convolutional layer l. - reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], - [-1, self._out_channels]) - # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, - # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l - # as defined in the paper, with shape I x I. - # (Tilde omitted over S for clarity.) - return compute_cov(reshaped_tensor) - - -class FullyConnectedMultiKF(InverseProvidingFactor): - """Kronecker factor for a fully connected recurrent layer.""" + return self._outputs_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + outputs_grad = self._outputs_grads[source][tower] + + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. + reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels]) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + def _get_data_device(self, tower): + return self._outputs_grads[0][tower].device + + +class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): + """Kronecker factor for a fully connected layer used multiple times.""" def __init__(self, - tensor_lists, + tensors, + num_uses=None, has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: - tensor_lists: List of lists of Tensors of shape [batch_size, n]. + tensors: List of list of Tensors of shape, each of shape + [num_uses * batch_size, n], and is a reshape version of a Tensor of + shape [num_uses, batch_size, n]. Each of these tensors is usually a + layer's inputs or its output's gradients. The first list index is + sources, the second is towers. + num_uses: int. The number of time-steps / uses. has_bias: bool. If True, '1' is appended to each row. """ - self._tensor_lists = tensor_lists - self._has_bias = has_bias - self._batch_size = array_ops.shape(tensor_lists[0][0])[0] - self._num_timesteps = len(tensor_lists[0]) - self._tensors = [None] * len(tensor_lists) + self._num_uses = num_uses self._cov_dt1 = None + self._make_cov_dt1 = False self._option1quants_by_damping = {} self._option2quants_by_damping = {} + self._option1quants_registrations = set() + self._option2quants_registrations = set() - super(FullyConnectedMultiKF, self).__init__() - - @property - def _var_scope(self): - return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists) + super(FullyConnectedMultiKF, self).__init__(tensors=tensors, + has_bias=has_bias) @property - def _num_sources(self): - return len(self._tensor_lists) + def _num_timesteps(self): + return self._num_uses @property - def _dtype(self): - return self._tensor_lists[0][0].dtype + def _var_scope(self): + return "ff_fc_multi_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + + (self._num_timesteps, self._has_bias,)) def make_covariance_update_op(self, ema_decay): op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) if self._cov_dt1 is not None: - new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) - for idx in range(self._num_sources)) + new_cov_dt1_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, + tower)) + + new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) + / float(self._num_towers)) + + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) - with maybe_colocate_with(new_cov_dt1_contribs[0]): - new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + return op - op2 = moving_averages.assign_moving_average( - self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring + tensor = self._tensors[source][tower] + if self._has_bias: + # This appending is technically done twice (the other time is for + # _compute_new_cov()) + tensor = append_homog(tensor) - # TODO(b/69112164): - # It's important that _cov and _cov_dt1 remain consistent with each - # other while the inverse ops are happening. How can we ensure this? - # We will need to add explicit synchronization for this to - # work with asynchronous training. - op = control_flow_ops.group(op, op2) + total_len = array_ops.shape(tensor)[0] + batch_size = total_len // self._num_timesteps - return op + tensor_present = tensor[:-batch_size, :] + tensor_future = tensor[batch_size:, :] - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._tensor_lists[idx]): - tensor = array_ops.concat(self._tensor_lists[idx], 0) - if self._has_bias: - tensor = append_homog(tensor) - # We save these so they can be used by _compute_new_cov_dt1 - self._tensors[idx] = tensor - return compute_cov(tensor) - - def _compute_new_cov_dt1(self, idx=0): - tensor = self._tensors[idx] - with maybe_colocate_with(tensor): - # Is there a more elegant way to do this computation? - tensor_present = tensor[:-self._batch_size, :] - tensor_future = tensor[self._batch_size:, :] - # We specify a normalizer for this computation to ensure a PSD Fisher - # block estimate. This is equivalent to padding with zeros, as was done - # in Section B.2 of the appendix. - normalizer = self._num_timesteps * self._batch_size - return compute_cov( - tensor_future, tensor_right=tensor_present, normalizer=normalizer) + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=total_len) - @property - def _cov_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias - return [size, size] + def _get_data_device(self, tower): + return self._tensors[0][tower].device @property def _vec_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size] - def get_option1quants(self, damping): - return self._option1quants_by_damping[damping] + def get_option1quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option1quants_by_damping[damping_id] - def get_option2quants(self, damping): - return self._option2quants_by_damping[damping] + def get_option2quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option2quants_by_damping[damping_id] def get_cov_dt1(self): assert self._cov_dt1 is not None return self._cov_dt1 def register_cov_dt1(self): - """Create a variable representing temporal cross-covariance. + self._make_cov_dt1 = True - (This is technically the second moment, not covariance, since it's - not mean subtracted.) - """ - if self._cov_dt1 is None: + def instantiate_cov_variables(self): + super(FullyConnectedMultiKF, self).instantiate_cov_variables() + assert self._cov_dt1 is None + if self._make_cov_dt1: with variable_scope.variable_scope(self._var_scope): self._cov_dt1 = variable_scope.get_variable( "cov_dt1", @@ -1313,15 +1464,25 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - def register_option1quants(self, damping): + def register_option1quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option1quants_registrations: + self._option1quants_registrations.add(damping_id) + + def register_option2quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option2quants_registrations: + self._option2quants_registrations.add(damping_id) - self.register_cov_dt1() + def instantiate_inv_variables(self): + super(FullyConnectedMultiKF, self).instantiate_inv_variables() - if damping not in self._option1quants_by_damping: + for damping_id in self._option1quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. - damping_string = scalar_or_tensor_to_string(damping) with variable_scope.variable_scope(self._var_scope): Lmat = variable_scope.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), @@ -1336,17 +1497,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - self._option1quants_by_damping[damping] = (Lmat, psi) + assert damping_id not in self._option1quants_by_damping + self._option1quants_by_damping[damping_id] = (Lmat, psi) - def register_option2quants(self, damping): - - self.register_cov_dt1() - - if damping not in self._option2quants_by_damping: + for damping_id in self._option2quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. - damping_string = scalar_or_tensor_to_string(damping) with variable_scope.variable_scope(self._var_scope): Pmat = variable_scope.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), @@ -1367,14 +1526,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + assert damping_id not in self._option2quants_by_damping + self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" # TODO(b/69918258): Add correctness tests for this method. # pylint: disable=invalid-name - ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + ops = [] if (len(self._option1quants_by_damping) + len(self._option2quants_by_damping)): @@ -1395,8 +1555,10 @@ class FullyConnectedMultiKF(InverseProvidingFactor): # consistently, or are somehow read between or during the cov updates. # Can this possibly happen? Is there a way to prevent it? - for damping, (Lmat_var, - psi_var) in self._option1quants_by_damping.items(): + for damping_id, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() invsqrtC0 = math_ops.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) @@ -1421,8 +1583,10 @@ class FullyConnectedMultiKF(InverseProvidingFactor): ops.append(Lmat_var.assign(Lmat)) ops.append(psi_var.assign(psi)) - for damping, (Pmat_var, Kmat_var, - mu_var) in self._option2quants_by_damping.items(): + for damping_id, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() # compute C0^(-1/2) invsqrtC0 = math_ops.matmul( @@ -1463,6 +1627,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor): ops.append(Kmat_var.assign(Kmat)) ops.append(mu_var.assign(mu)) + ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() return [control_flow_ops.group(*ops)] # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index ce9005b9ce99a4efa5f2821c56e199dd2086482e..586a004f880e7bea2a772c53091285c2907ca31a 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,6 +26,7 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from contextlib import contextmanager from functools import partial import math @@ -59,6 +60,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +_EMBEDDING_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB +} + APPROX_KRONECKER_INDEP_NAME = "kron_indep" APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" @@ -71,10 +76,39 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { option=2) } +_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB +} + +_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" +_DEFAULT_LAYER_COLLECTION = None + + +def get_default_layer_collection(): + """Get default LayerCollection.""" + if _DEFAULT_LAYER_COLLECTION is None: + raise ValueError( + "Attempted to retrieve default LayerCollection when none is set. Use " + "LayerCollection.as_default().") + + return _DEFAULT_LAYER_COLLECTION + + +def set_default_layer_collection(layer_collection): + global _DEFAULT_LAYER_COLLECTION + + if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: + raise ValueError("Default LayerCollection is already set.") + + _DEFAULT_LAYER_COLLECTION = layer_collection + class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -130,6 +164,8 @@ class LayerCollection(object): fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. losses: a list of LossFunction objects. The loss to be optimized is their sum. + loss_colocation_ops: ops to colocate loss function evaluations with. These + will typically be the inputs to the losses. """ def __init__(self, @@ -145,17 +181,27 @@ class LayerCollection(object): self._default_generic_approximation = APPROX_FULL_NAME self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_SERIES_2_NAME) + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME + self.loss_colocation_ops = {} + self._vars_to_uses = defaultdict(lambda: 0) with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @property def losses(self): - """LossFunctions registered with this LayerCollection.""" - return list(self._loss_dict.values()) + """Tuple of LossFunction objects registered with this LayerCollection.""" + return nest.flatten(self.towers_by_loss) + + @property + def towers_by_loss(self): + """Tuple across losses of LossFunction objects registered to each tower.""" + return tuple(tuple(lst) for lst in self._loss_dict.values()) @property def registered_variables(self): @@ -214,14 +260,14 @@ class LayerCollection(object): @property def default_conv2d_approximation(self): - return self._default_convolution_2d_approximation + return self._default_conv2d_approximation def set_default_conv2d_approximation(self, value): if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) - self._default_convolution_2d_approximation = value + self._default_conv2d_approximation = value @property def default_fully_connected_multi_approximation(self): @@ -233,6 +279,14 @@ class LayerCollection(object): "multi layer.".format(value)) self._default_fully_connected_multi_approximation = value + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -290,23 +344,74 @@ class LayerCollection(object): self.fisher_blocks[layer_key] = fisher_block return fisher_block - def get_use_count_map(self): - """Returns a dict of variables to their number of registrations.""" - # TODO(b/70283403): Reimplement this in the old way, where each - # registration function would be responsible for incrementing the count. - # Also, this version has a bug: it won't do the right thing for generic - # registration for parameters that are shared. i.e. it won't set the use - # count to infinity. - vars_to_uses = defaultdict(int) - for key, block in six.iteritems(self.fisher_blocks): - n = ( - block.num_inputs()*block.num_registered_minibatches if isinstance( - block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) - else block.num_registered_minibatches) - key = utils.ensure_sequence(key) - for k in key: - vars_to_uses[k] += n - return vars_to_uses + def register_loss_function(self, + loss, + colocation_op, + base_name, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a LossFunction object. + + Args: + loss: The LossFunction object. + colocation_op: The op to colocate the loss function's computations with. + base_name: The name to derive a new unique name from is the name argument + is None. + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. + If False, create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with 'name' found. + KeyError: If reuse == False and existing LossFunction with 'name' found. + """ + + name = name or self._graph.unique_name(base_name) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + + loss_list = self._loss_dict.get(name, None) + + if loss_list is None: + raise KeyError( + "Unable to find loss function named {}. Register a new loss " + "function with reuse=False.".format(name)) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another tower.".format(name)) + + loss_list = [] + self._loss_dict[name] = loss_list + + loss_list.append(loss) + self.loss_colocation_ops[loss] = colocation_op + + def _get_use_count_map(self): + """Returns a dict mapping variables to their number of registrations.""" + return self._vars_to_uses + + def _add_uses(self, params, uses): + """Register additional uses by params in the graph. + + Args: + params: Variable or tuple of Variables. Parameters for a layer. + uses: int or float. Number of additional uses for these parameters. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + for var in params: + self._vars_to_uses[var] += uses def check_registration(self, variables): """Checks that all variable uses have been registered properly. @@ -324,7 +429,7 @@ class LayerCollection(object): # Note that overlapping parameters (i.e. those that share variables) will # be caught by layer_collection.LayerParametersDict during registration. - reg_use_map = self.get_use_count_map() + reg_use_map = self._get_use_count_map() error_messages = [] @@ -414,12 +519,27 @@ class LayerCollection(object): inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) self._subgraph = utils.SubGraph(inputs_to_losses) + def eval_losses(self): + """Return evaluated losses (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate()) + return evals + + def eval_losses_on_samples(self): + """Return losses evaluated on samples (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate_on_sample()) + return evals + def total_loss(self): - return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses)) + return math_ops.add_n(self.eval_losses()) def total_sampled_loss(self): - return math_ops.add_n( - tuple(loss.evaluate_on_sample() for loss in self.losses)) + return math_ops.add_n(self.eval_losses_on_samples()) def _get_linked_approx(self, params): """If params were linked, return their specified approximation.""" @@ -429,45 +549,56 @@ class LayerCollection(object): else: return None + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + def register_embedding(self, params, inputs, outputs, approx=None, reuse=VARIABLE_SCOPE): - """Registers a fully connnected layer. + """Registers an embedding layer. Args: params: Embedding matrix of shape [vocab_size, embedding_size]. inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices into embedding matrix. - outputs: Tensor of shape [batch_size, output_size]. Outputs + outputs: Tensor of shape [batch_size, embedding_size]. Outputs produced by layer. - approx: str. Must be "kron". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_embedding_approximation - - if approx != APPROX_KRONECKER_NAME: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) if isinstance(params, (tuple, list)): raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) block = self.register_block( - params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) def register_fully_connected(self, params, @@ -484,29 +615,31 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_approximation - if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias), + reuse=reuse) + block.register_additional_tower(inputs, outputs) - block = self.register_block(params, block_type(self, has_bias), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + self._add_uses(params, 1) def register_conv2d(self, params, @@ -514,25 +647,33 @@ class LayerCollection(object): padding, inputs, outputs, + data_format=None, + dilations=None, approx=None, reuse=VARIABLE_SCOPE): - """Registers a convolutional layer. + """Registers a call to tf.nn.conv2d(). Args: params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [kernel_height, kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. - strides: 1-D Tensor of length 4. Strides for convolution kernel. + strides: List of 4 ints. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. Output produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. @@ -540,18 +681,228 @@ class LayerCollection(object): ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_conv2d_approximation + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) + + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can't use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? + if approx == APPROX_KRONECKER_NAME: + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches"), + reuse=reuse) + elif approx == APPROX_DIAGONAL_NAME: + assert strides[0] == strides[-1] == 1 + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilations=dilations, + data_format=data_format), + reuse=reuse) + else: + raise NotImplementedError(approx) - if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_convolution(self, + params, + inputs, + outputs, + padding, + strides=None, + dilation_rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.convolution(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [..filter_spatial_size.., + in_channels, out_channels]. Bias should have shape [out_channels]. + inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. + Inputs to layer. + outputs: Tensor of shape [batch_size, ..output_spatial_size.., + out_channels]. Output produced by layer. + padding: string. see tf.nn.conv2d for valid values. + strides: List of ints of length len(..input_spatial_size..). Strides for + convolution kernel in spatial dimensions. + dilation_rate: List of ints of length len(..input_spatial_size..). + Dilations along spatial dimension. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_KRONECKER_NAME - block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block( - params, block_type(self, params, strides, padding), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + params, + fb.ConvKFCBasicFB( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilation_rate=dilation_rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_depthwise_conv2d(self, + params, + inputs, + outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.depthwise_conv2d(). + + Args: + params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Convolutional filter. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_height, output_width, + in_channels * channel_multiplier]. Output produced by depthwise conv2d. + strides: List of ints of length 4. Strides along all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rates in spatial + dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must "diagonal". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_DIAGONAL_NAME + assert data_format in [None, "NHWC"] + + block = self.register_block( + params, + fb.DepthwiseConvDiagonalFB( + layer_collection=self, + params=params, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_separable_conv2d(self, + depthwise_params, + pointwise_params, + inputs, + depthwise_outputs, + pointwise_outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.separable_conv2d(). + + Note: This requires access to intermediate outputs between depthwise and + pointwise convolutions. + + Args: + depthwise_params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Filter for depthwise conv2d. + pointwise_params: 4-D Tensor of shape [1, 1, in_channels * + channel_multiplier, out_channels]. Filter for pointwise conv2d. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + depthwise_outputs: Tensor of shape [batch_size, output_height, + output_width, in_channels * channel_multiplier]. Output produced by + depthwise conv2d. + pointwise_outputs: Tensor of shape [batch_size, output_height, + output_width, out_channels]. Output produced by pointwise conv2d. + strides: List of ints of length 4. Strides for depthwise conv2d kernel in + all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rate of depthwise conv2d + kernel in spatial dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + self.register_depthwise_conv2d( + params=depthwise_params, + inputs=inputs, + outputs=depthwise_outputs, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format, + approx=APPROX_DIAGONAL_NAME, + reuse=reuse) + + self.register_conv2d( + params=pointwise_params, + inputs=depthwise_outputs, + outputs=pointwise_outputs, + strides=[1, 1, 1, 1], + padding="VALID", + data_format=data_format, + approx=approx, + reuse=reuse) def register_generic(self, params, @@ -562,32 +913,32 @@ class LayerCollection(object): Args: params: Tensor or tuple of Tensors corresponding to the parameters. - batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of "full" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + batch_size: 0-D Tensor. Size of the minibatch (for this tower). + approx: str or None. It not None, must be one of "full" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'batch_size' to the total + mini-batch size use when estimating the Fisher block for this layer + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ + block_type, approx = self._get_block_type( + params, approx, self.default_generic_approximation, + _GENERIC_APPROX_TO_BLOCK_TYPES) - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_generic_approximation - - if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - - block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) - block.register_additional_minibatch(batch_size) + block.register_additional_tower(batch_size) + + self._add_uses(params, float("inf")) def register_fully_connected_multi(self, params, inputs, outputs, - approx=None): + num_uses=None, approx=None, + reuse=VARIABLE_SCOPE): """Register fully connected layers with shared parameters. This can handle general fully-connected layers with shared parameters, but @@ -598,34 +949,187 @@ class LayerCollection(object): params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. - inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs - to layer. In the case of RNNs, one Tensor per time step. - outputs: A list of tensors, the same length as 'inputs', each of shape - [batch_size, output_size]. Outputs produced by layer. In the case of - RNNs, one Tensor per time step. - approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs + to layer. The list indexes each use in the graph (which might + correspond to a "time-step" in an RNN). OR, can be single Tensor, of + shape [num_uses * batch_size , input_size], which is a reshaped version + of a Tensor of shape [num_uses, batch_size, input_size]. + outputs: A list of Tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. The list indexes + each use in the graph (which might correspond to a "time-step" in an + RNN). Needs to correspond with the order used in 'inputs'. OR, can be + a single Tensor of shape [num_uses * batch_size, output_size], which is + a reshaped version of a Tensor of shape [num_uses, batch_size, + output_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None, must be of "kron_indep", "kron_series_1" + or "kron_series_2". The Fisher approximation to use. If None the default + value is used. (Default: None) + reuse: bool or str. If True, this adds inputs and outputs as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word 'use' here has a completely different meaning to "use in the graph" + as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.) + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_multi_approximation - has_bias = isinstance(params, (tuple, list)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_multi_approximation, + _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) # TODO(b/70283649): something along the lines of find_canonical_output # should be added back in here (and for the other block types, arguably). - if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias, + num_uses=num_uses), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + + def register_conv2d_multi(self, + params, + strides, + padding, + inputs, + outputs, + num_uses=None, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers convolutional layers with shared parameters. - # For now we don't support multiple minibatches for this type of layer, so - # we set reuse=False - self.register_block(params, - block_type(self, inputs, outputs, has_bias=has_bias), - reuse=False) + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: A list of Tensors, each of shape [batch_size, height, width, + in_channels]. Inputs to layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). OR, can be single + Tensor, of shape [num_uses * batch_size, height, width, in_channels], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + height, width, in_channels]. + outputs: A list of Tensors, each of shape [batch_size, height, width, + out_channels]. Output produced by layer. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + Needs to correspond with the order used in 'inputs'. OR, can be a + single Tensor, of shape [num_uses * batch_size, height, width, + out_channels], which is a reshaped version of a Tensor of shape + [num_uses, batch_size, height, width, out_channels]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds inputs and outputs as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word 'use' here has a completely different meaning to "use in the graph" + as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_multi_approximation, + _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) + + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches", + num_uses=num_uses), + reuse=reuse) + + block.register_additional_tower(inputs, outputs) + + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + + # TODO(b/74108452): change the loss registration functions names to refer + # to "loss functions" instead of distributions. Following naming convention + # of the loss function classes themselves. + + def register_embedding_multi(self, + params, + inputs, + outputs, + num_uses=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers embedding layers with shared parameters. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size] and + dtype int32. Indices into embedding matrix. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + OR, can be single Tensor, of shape [num_uses, batch_size, input_size], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + input_size]. + outputs: A list of Tensors, each of shape [batch_size, embedding_size]. + Outputs produced by layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). Needs to + correspond with the order used in 'inputs'. OR, can be a + single Tensor, of shape [num_uses * batch_size, embedding_size], which + is a reshaped version of a Tensor of shape [num_uses, batch_size, + embedding_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds inputs and outputs as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word 'use' here has a completely different meaning to "use in the graph" + as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, len(inputs)) def register_categorical_predictive_distribution(self, logits, @@ -645,53 +1149,24 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. - If False, create a new FisherBlock. If VARIABLE_SCOPE, use - tf.get_variable_scope().reuse. - - Raises: - ValueError: If reuse == True and name == None. - ValueError: If reuse == True and seed != None. - KeyError: If reuse == True and no existing LossFunction with 'name' found. - KeyError: If reuse == False and existing LossFunction with 'name' found. + reuse: bool or str. If True, this adds 'logits' as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_categorical_predictive_distribution") - - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - if name is None: - raise ValueError( - "If reuse is enabled, loss function's name must be set.") - if seed is not None: - raise ValueError( - "Seed can only be specified at LossFunction instantiation.") - - loss = self._loss_dict.get(name, None) - - if loss is None: - raise KeyError( - "Unable to find loss function named {}. Create a new LossFunction " - "with reuse=False.".format(name)) - - loss.register_additional_minibatch(logits, targets=targets) - else: - if name in self._loss_dict: - raise KeyError( - "Loss function named {} already exists. Set reuse=True to append " - "another minibatch.".format(name)) - loss = lf.CategoricalLogitsNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "categorical_predictive_distribution", + name=name, reuse=reuse) def register_normal_predictive_distribution(self, mean, var=0.5, seed=None, targets=None, - name=None): + name=None, + reuse=VARIABLE_SCOPE): """Registers a normal predictive distribution. Args: @@ -708,21 +1183,23 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) + reuse: bool or str. If True, this adds 'mean' and 'var' as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_normal_predictive_distribution") - if name in self._loss_dict: - raise NotImplementedError( - "Adding logits to an existing LossFunction not yet supported.") - loss = lf.NormalMeanNegativeLogProbLoss( - mean, var, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, + seed=seed) + self.register_loss_function(loss, mean, + "normal_predictive_distribution", + name=name, reuse=reuse) def register_multi_bernoulli_predictive_distribution(self, logits, seed=None, targets=None, - name=None): + name=None, + reuse=VARIABLE_SCOPE): """Registers a multi-Bernoulli predictive distribution. Args: @@ -735,15 +1212,16 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) + reuse: bool or str. If True, this adds 'logits' as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_multi_bernoulli_predictive_distribution") - if name in self._loss_dict: - raise NotImplementedError( - "Adding logits to an existing LossFunction not yet supported.") - loss = lf.MultiBernoulliNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "multi_bernoulli_predictive_distribution", + name=name, reuse=reuse) def make_or_get_factor(self, cls, args): """Insert 'cls(args)' into 'self.fisher_factors' if not already present. @@ -772,3 +1250,10 @@ class LayerCollection(object): with variable_scope.variable_scope(self._var_scope): self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] + + @contextmanager + def as_default(self): + """Sets this LayerCollection as the default.""" + set_default_layer_collection(self) + yield + set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index f8aa230d9ca1f542950f56b1e6cf1ab7ccd3d05f..9f4685380705bd409dbcd7e85d0e3bb4189a6adc 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "get_default_layer_collection", + "set_default_layer_collection", "LayerParametersDict", "LayerCollection", "APPROX_KRONECKER_NAME", diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index cb3e698b9ceab920785adf735f88bd8e535a628f..e7d4243fc3d1c2d860693f2f62447b1c9aeeee03 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -57,30 +57,6 @@ class LossFunction(object): """The inputs to the loss function (excluding the targets).""" pass - @property - def input_minibatches(self): - """A `list` of inputs to the loss function, separated by minibatch. - - Typically there will be one minibatch per tower in a multi-tower setup. - Returns a list consisting of `self.inputs` by default; `LossFunction`s - supporting registering multiple minibatches should override this method. - - Returns: - A `list` of `Tensor`s representing - """ - return [self.inputs] - - @property - def num_registered_minibatches(self): - """Number of minibatches registered for this LossFunction. - - Typically equal to the number of towers in a multi-tower setup. - - Returns: - An `int` representing the number of registered minibatches. - """ - return len(self.input_minibatches) - def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: @@ -474,7 +450,6 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): assert len(variance.shape) == 2, "Expect 2D variance tensor." self._mean = mean self._variance = variance - self._scale = math_ops.sqrt(variance) self._targets = targets super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) @@ -484,7 +459,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def dist(self): - return normal.Normal(loc=self._mean, scale=self._scale) + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) @property def params(self): @@ -502,7 +477,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def _fisher_mean_factor(self): - return 1. / self._scale + return 1. / math_ops.sqrt(self._variance) @property def _fisher_var(self): @@ -611,36 +586,13 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, index in [0, output_size). seed: int or None. Default random seed when sampling. """ - self._logits_components = [] - self._targets_components = [] - self.register_additional_minibatch(logits, targets=targets) + self._logits = logits + self._targets = targets super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) - def register_additional_minibatch(self, logits, targets=None): - """Register an additiona minibatch's worth of parameters. - - Args: - logits: Tensor of shape [batch_size, output_size]. Parameters for - underlying distribution. - targets: None or Tensor of shape [batch_size, output_size]. Each row must - be a one-hot vector. - """ - self._logits_components.append(logits) - self._targets_components.append(targets) - - @property - def _logits(self): - return array_ops.concat(self._logits_components, axis=0) - - @property - def input_minibatches(self): - return self._logits_components - @property def targets(self): - if all(target is None for target in self._targets_components): - return None - return array_ops.concat(self._targets_components, axis=0) + return self._targets @property def dist(self): diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 5d456bcb79ff00cedc1aaa7244cc8722d21f6e98..843aeef7d82df064b757ab4618f2b0ccbbec4cbe 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -50,8 +51,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): name="KFAC", estimation_mode="gradients", colocate_gradients_with_ops=True, - cov_devices=None, - inv_devices=None): + batch_size=None, + placement_strategy=None, + **kwargs): """Initializes the KFAC optimizer with the given settings. Args: @@ -91,12 +93,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. (Default: True) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. + batch_size: The size of the mini-batch. Only needed when momentum_type + == 'qmodel' or when automatic adjustment is used. (Default: None) + placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passesd to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: ValueError: If the momentum type is unsupported. @@ -110,6 +113,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): if variables is None: variables = tf_variables.trainable_variables() + # Parameters to be passed to the Fisher estimator: + self._variables = variables + self._cov_ema_decay = cov_ema_decay + self._layers = layer_collection + self._estimation_mode = estimation_mode + self._colocate_gradients_with_ops = colocate_gradients_with_ops + # The below paramaters are required only if damping needs to be adapated. # These parameters can be set by calling # set_damping_adaptation_params() explicitly. @@ -130,17 +140,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._q_model_change = None self._update_damping_op = None - self._layers = layer_collection - self._fisher_est = est.FisherEstimator( - lambda: self.damping, - variables, - cov_ema_decay, - layer_collection, - estimation_mode=estimation_mode, - colocate_gradients_with_ops=colocate_gradients_with_ops, - cov_devices=cov_devices, - inv_devices=inv_devices) - momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] @@ -148,20 +147,30 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): raise ValueError("Unsupported momentum type {}. Must be one of {}." .format(momentum_type, legal_momentum_types)) if momentum_type != "regular" and norm_constraint is not None: - raise ValueError("Update clipping is only supported with momentum" + raise ValueError("Update clipping is only supported with momentum " "type 'regular'.") if momentum_type not in ["regular", "adam"] and momentum != 0: raise ValueError("Momentum must be unspecified if using a momentum_type " "other than 'regular' or 'adam'.") + # Extra parameters of the optimizer self._momentum = momentum self._momentum_type = momentum_type self._norm_constraint = norm_constraint - - # this is a bit of a hack - # TODO(duckworthd): Handle this in a better way (e.g. pass it in?) - self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] - self._losses = layer_collection.losses + self._batch_size = batch_size + self._placement_strategy = placement_strategy + + with variable_scope.variable_scope(name): + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, + exps=(-1,), + estimation_mode=self._estimation_mode, + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) super(KfacOptimizer, self).__init__(learning_rate, name=name) @@ -178,6 +187,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): style rule described in Section 6.5 of "Optimizing Neural Networks with Kronecker-factored Approximate Curvature". + Note that this function creates Tensorflow variables which store a few + scalars and are accessed by the ops which update the damping (as part + of the training op returned by the minimize() method). + Args: is_chief: `Boolean`, `True` if the worker is chief. prev_train_batch: Training data used to minimize loss in the previous @@ -199,6 +212,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): """ if self._adapt_damping: raise ValueError("Damping adaptation parameters already set.") + with variable_scope.variable_scope(self.get_name()): self._adapt_damping = True self._is_chief = is_chief @@ -219,64 +233,138 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._damping = variable_scope.get_variable( "damping", initializer=self._damping_constant, trainable=False) + @property + def variables(self): + return self._variables + + @property + def damping(self): + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval + @property def cov_update_thunks(self): - return self._fisher_est.cov_update_thunks + self._maybe_make_and_save_everything() + return self._cov_update_thunks @property def cov_update_ops(self): - return self._fisher_est.cov_update_ops + self._maybe_make_and_save_everything() + return self._cov_update_ops @property def cov_update_op(self): - return self._fisher_est.cov_update_op + self._maybe_make_and_save_everything() + return self._cov_update_op @property def inv_update_thunks(self): - return self._fisher_est.inv_update_thunks + self._maybe_make_and_save_everything() + return self._inv_update_thunks @property def inv_update_ops(self): - return self._fisher_est.inv_update_ops + self._maybe_make_and_save_everything() + return self._inv_update_ops @property def inv_update_op(self): - return self._fisher_est.inv_update_op + self._maybe_make_and_save_everything() + return self._inv_update_op - @property - def variables(self): - return self._fisher_est.variables + def _maybe_make_and_save_everything(self): + if not self._fisher_est.made_vars(): + warnings.warn("These convenience properties will be depcrecated soon. " + "Please use explicit op/thunk creation methods instead " + "(e.g. make_ops_and_vars, etc).", + DeprecationWarning) + (self._cov_update_ops, self._cov_update_op, self._inv_update_ops, + self._inv_update_op, self._cov_update_thunks, + self._inv_update_thunks) = self.make_ops_and_vars() - @property - def damping(self): - if self._damping: - return self._damping - else: - return self._damping_constant + def make_ops_and_vars(self): + """Make ops and vars with device placement `self._placement_strategy`. - @property - def damping_adaptation_interval(self): - return self._damping_adaptation_interval + See `FisherEstimator.make_ops_and_vars` for details. + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_op: inv_update_ops grouped into a single op. + """ + return self._fisher_est.make_ops_and_vars(scope=self.get_name()) + + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) + + def create_ops_and_vars_thunks(self): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) def minimize(self, *args, **kwargs): - kwargs["var_list"] = kwargs.get("var_list") or self.variables - if set(kwargs["var_list"]) != set(self.variables): - raise ValueError("var_list doesn't match with set of Fisher-estimating " - "variables.") - if self._adapt_damping and self._is_chief: - global_step = kwargs.get("global_step", None) - if not global_step: - raise KeyError("global_step needs to be passed to optimizer.minimize " - "if damping parameter is adapted.") - update_damping_op = self._update_damping(self._prev_train_batch, - global_step) - with ops.control_dependencies([update_damping_op]): - loss = args[0] - loss_assign_op = state_ops.assign(self._prev_loss, loss) - train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) - return control_flow_ops.group(loss_assign_op, train_op) - else: - return super(KfacOptimizer, self).minimize(*args, **kwargs) + # Should this variable scope encompass everything below? Or will the super- + # class make another copy of the same name scope? + with variable_scope.variable_scope(self.get_name()): + kwargs["var_list"] = kwargs.get("var_list") or self.variables + if set(kwargs["var_list"]) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + if self._adapt_damping and self._is_chief: + global_step = kwargs.get("global_step", None) + if not global_step: + raise KeyError("global_step needs to be passed to optimizer.minimize " + "if damping parameter is adapted.") + update_damping_op = self._update_damping(self._prev_train_batch, + global_step) + with ops.control_dependencies([update_damping_op]): + loss = args[0] + loss_assign_op = state_ops.assign(self._prev_loss, loss) + train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) + return control_flow_ops.group(loss_assign_op, train_op) + else: + return super(KfacOptimizer, self).minimize(*args, **kwargs) def compute_gradients(self, *args, **kwargs): # args[1] could be our var_list @@ -301,6 +389,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): Returns: An `Operation` that applies the specified gradients. """ + self._maybe_make_and_save_everything() # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. @@ -450,12 +539,12 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). """ - cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables) + cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses, + variables) # compute the matrix-vector products with the transposed Fisher factor fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) - batch_size = math_ops.cast( self._batch_size, dtype=fft_precon_grads[0].dtype) @@ -639,7 +728,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] - # TODO(b/73448937): Move all update damping code to a separate class/function. def _update_damping(self, prev_batch, global_step): """Adapts damping parameter. Check KFAC (Section 6.5) for the details. diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000000000000000000000000000000000..bf12dbaa9adbaa4af1511034aef0b5ab59d53e26 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + *args: + **kwargs: + + """ + super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a round-robin device placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `None` then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + (cov_update_thunks, + inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope) + cov_update_ops = [thunk() for thunk in cov_update_thunks] + inv_update_ops = [thunk() for thunk in inv_update_thunks] + + scope = self.name if scope is None else scope + with variable_scope.variable_scope(scope): + cov_update_op = control_flow_ops.group(cov_update_ops, + name="cov_update_op") + inv_update_op = control_flow_ops.group(inv_update_ops, + name="inv_update_op") + + return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, + cov_update_thunks, inv_update_thunks) + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement strat. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. + (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index 88e6fb20e8f97528aea2a92752d79344c27bbf24..b6f42815e79fa5eb9c6a2aa9f99ac3ec5a70ad0a 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -24,11 +24,13 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -430,6 +432,127 @@ def batch_execute(global_step, thunks, batch_size, name=None): return result +def extract_convolution_patches(inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): + """Extracts inputs to each output coordinate in tf.nn.convolution. + + This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), + where the number of spatial dimensions may be something other than 2. + + Assumes, + - First dimension of inputs is batch_size + - Convolution filter is applied to all input channels. + + Args: + inputs: Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). + filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). + padding: string. Padding method. One of "VALID", "SAME". + strides: None or list of ints. Strides along spatial dimensions. + dilation_rate: None or list of ints. Dilation along spatial dimensions. + name: None or str. Name of Op. + data_format: None or str. Format of data. + + Returns: + Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: If data_format does not put channel last. + ValueError: If inputs and filter disagree on in_channels. + """ + if not is_data_format_channel_last(data_format): + raise ValueError("Channel must be last dimension.") + with ops.name_scope(name, "extract_convolution_patches", + [inputs, filter_shape, padding, strides, dilation_rate]): + batch_size = inputs.shape.as_list()[0] + in_channels = inputs.shape.as_list()[-1] + + # filter_shape = spatial_filter_shape + [in_channels, out_channels] + spatial_filter_shape = filter_shape[:-2] + if in_channels != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + + # Map each input feature to a location in the output. + out_channels = np.prod(spatial_filter_shape) * in_channels + filters = linalg_ops.eye(out_channels) + filters = array_ops.reshape( + filters, + list(spatial_filter_shape) + [in_channels, out_channels]) + + result = nn_ops.convolution( + inputs, + filters, + padding=padding, + strides=strides, + dilation_rate=dilation_rate) + spatial_output_shape = result.shape.as_list()[1:-1] + result = array_ops.reshape(result, + [batch_size or -1] + spatial_output_shape + + list(spatial_filter_shape) + [in_channels]) + + return result + + +def extract_pointwise_conv2d_patches(inputs, + filter_shape, + name=None, + data_format=None): + """Extract patches for a 1x1 conv2d. + + Args: + inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. + filter_shape: List of 4 ints. Shape of filter to apply with conv2d() + name: None or str. Name for Op. + data_format: None or str. Format for data. See 'data_format' in + tf.nn.conv2d() for details. + + Returns: + Tensor of shape [batch_size, ..spatial_input_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: if inputs is not 4-D. + ValueError: if filter_shape is not [1, 1, ?, ?] + ValueError: if data_format is not channels-last. + """ + if inputs.shape.ndims != 4: + raise ValueError("inputs must have 4 dims.") + if len(filter_shape) != 4: + raise ValueError("filter_shape must have 4 dims.") + if filter_shape[0] != 1 or filter_shape[1] != 1: + raise ValueError("filter_shape must have shape 1 along spatial dimensions.") + if not is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels last.") + with ops.name_scope(name, "extract_pointwise_conv2d_patches", + [inputs, filter_shape]): + ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. + strides = [1, 1, 1, 1] # Operate on all pixels. + rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. + padding = "VALID" # Doesn't matter. + result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, + padding) + + batch_size, input_height, input_width, in_channels = inputs.shape.as_list() + filter_height, filter_width, in_channels, _ = filter_shape + return array_ops.reshape(result, [ + batch_size, input_height, input_width, filter_height, filter_width, + in_channels + ]) + + +def is_data_format_channel_last(data_format): + """True if data_format puts channel last.""" + if data_format is None: + return True + return data_format.endswith("C") + + def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. @@ -482,5 +605,93 @@ def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + +class PartitionedTensor(object): + """A Tensor partitioned across its 0-th dimension.""" + + def __init__(self, tensors): + """Initializes PartitionedTensor. + + Args: + tensors: List of Tensors. All Tensors must agree on shape (excepting + batch dimension) and dtype. + + Raises: + ValueError: If 'tensors' has length zero. + ValueError: if contents of 'tensors' don't agree on shape or dtype. + """ + if not tensors: + raise ValueError("tensors must be a list of 1+ Tensors.") + + dtype = tensors[0].dtype + if not all(tensor.dtype == dtype for tensor in tensors): + raise ValueError("all tensors must have dtype = %s." % dtype) + + shape = tensors[0].shape[1:] + if not all(tensor.shape[1:] == shape for tensor in tensors): + raise ValueError("All tensors must have shape = %s (excluding batch " + "dimension)." % shape) + + self.tensors = tensors + self._concats = {} # {device: Tensor} + + @property + def shape(self): + feature_shape = self.tensors[0].shape[1:] + batch_size = sum([tensor.shape[0] for tensor in self.tensors], + tensor_shape.Dimension(0)) + return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) + + def get_shape(self): + return self.shape + + @property + def dtype(self): + return self.tensors[0].dtype + + def __str__(self): + return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( + self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) + + def __hash__(self): + return hash(tuple(self.tensors)) + + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + + def __getitem__(self, key): + return self.as_tensor()[key] + + def as_tensor(self, dtype=None, name=None, as_ref=False): + with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): + assert not as_ref + assert dtype in [None, self.dtype] + result = array_ops.concat(self.tensors, axis=0) + + # Cache 'result' if we haven't already cached a value for this device. + if result.device not in self._concats: + self._concats[result.device] = result + return self._concats[result.device] + + @property + def device(self): + # PartitionedTensors in general do not live on a single device. If the + # device cannot be determined unambiguously this property will return None. + device = self.tensors[0].device + if all(tensor.device == device for tensor in self.tensors): + return device + return None + + +ops.register_tensor_conversion_function( + PartitionedTensor, + lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index 8e424a794691484fdea7d8481677aa641c433d4c..330d222dbf70fcfa02ffd47261c0513d9dd6e0e9 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -40,6 +40,9 @@ _allowed_symbols = [ "fwd_gradients", "ensure_sequence", "batch_execute", + "extract_convolution_patches", + "extract_pointwise_conv2d_patches", + "is_data_format_channel_last", "matmul_sparse_dense", "matmul_diag_sparse", ] diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index b62e3050cd7003f1ba72061b133ff9b5d6b616da..ffa208540dae975cb139ad6d76dcf392678ba0ee 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -470,7 +470,7 @@ def embedding_lookup_unique(params, ids, name=None): ids = ops.convert_to_tensor(ids) shape = array_ops.shape(ids) ids_flat = array_ops.reshape( - ids, math_ops.reduce_prod(shape, keep_dims=True)) + ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) embeds_flat = array_ops.gather(unique_embeddings, idx) diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py index 89c9d37bd09cb6c43eebb91f3a16600eae9cb490..f42112206d0db9d2e42bd4cff19f6a6533951d46 100644 --- a/tensorflow/contrib/layers/python/layers/encoders.py +++ b/tensorflow/contrib/layers/python/layers/encoders.py @@ -125,7 +125,7 @@ def embed_sequence(ids, `reuse` is `None` or `False`. """ if not (reuse or (vocab_size and embed_dim)): - raise ValueError('Must specify vocab size and embedding dimension when not' + raise ValueError('Must specify vocab size and embedding dimension when not ' 'reusing. Got vocab_size=%s and embed_dim=%s' % ( vocab_size, embed_dim)) with variable_scope.variable_scope( diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 80cbe68870808328b387e2044fe236af5a5e39f8..350bcb3bca11b4cad18ce863ab1496076477aa3c 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2747,7 +2747,7 @@ def softmax(logits, scope=None): logits_2d = array_ops.reshape(logits, [-1, num_logits]) predictions = nn.softmax(logits_2d) predictions = array_ops.reshape(predictions, array_ops.shape(logits)) - if context.in_graph_mode(): + if not context.executing_eagerly(): predictions.set_shape(logits.get_shape()) return predictions diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 123275e1fde047cd3772528641b2e3b09742fbdc..0b38c0c3fdd84cf432c334554eba3a9b0e44084c 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -29,6 +29,7 @@ from __future__ import print_function import functools import re +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops @@ -37,6 +38,7 @@ from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -46,6 +48,7 @@ from tensorflow.python.util import nest __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") +_USE_DEFAULT = "__rev_block_lib_default" def _acc_grads(*lists_of_grads): @@ -219,7 +222,13 @@ class RevBlock(base.Layer): def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): """Custom gradient fn for a block of reversible residual layers.""" + # Inputs have passed through an Identity. Recover the original Tensors to + # be able to match up side inputs. + assert [u"Identity"] == list(set([x.op.type for x in inputs])) + inputs = [x.op.inputs[0] for x in inputs] side_inputs = inputs[2:] + del inputs + f_side_idxs = [None] * len(self.f_side_input) g_side_idxs = [None] * len(self.g_side_input) assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) @@ -405,12 +414,36 @@ def rev_block(x1, return block.forward(x1, x2) -def recompute_grad(fn): +def enable_with_args(dec): + """A decorator for decorators to enable their usage with or without args.""" + + @functools.wraps(dec) + def new_dec(*args, **kwargs): + if len(args) == 1 and not kwargs and callable(args[0]): + # Used as decorator without args + fn = args[0] + return dec(fn) + else: + return lambda fn: dec(fn, *args, **kwargs) + + return new_dec + + +@enable_with_args +def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. + use_data_dep: `bool`, if `True` will use a dummy data dependency to force + the recompute to happen. If `False` will use a control dependency. By + default will be `True` if in an XLA context and `False` otherwise. XLA + ignores control dependencies and so this data dependency is necessary. + tupleize_grads: `bool`, if `True` will use control dependencies to ensure + that all gradients are produced before any are consumed by downstream ops. + If `use_data_dep` is also `True`, will use a data dependency instead of + a control dependency. Returns: A wrapped fn that is identical to fn when called, but its activations will @@ -420,13 +453,25 @@ def recompute_grad(fn): @functools.wraps(fn) def wrapped(*args): - return _recompute_grad(fn, args) + return _recompute_grad( + fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads) return wrapped -def _recompute_grad(fn, args): +def _is_on_tpu(): + ctxt = framework_ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingXLAContext(ctxt) is not None + + +def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + for arg in args: + if not isinstance(arg, framework_ops.Tensor): + raise ValueError("All inputs to function must be Tensors") + use_data_dep_ = use_data_dep + if use_data_dep_ == _USE_DEFAULT: + use_data_dep_ = _is_on_tpu() cached_vs = [] cached_arg_scope = [] @@ -436,6 +481,8 @@ def _recompute_grad(fn, args): del outputs # Recompute outputs with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(cached_arg_scope[0]): with variable_scope.variable_scope(cached_vs[0], reuse=True): outputs = fn(*inputs) @@ -444,6 +491,13 @@ def _recompute_grad(fn, args): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars @@ -532,7 +586,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): get_vars_fn = ( vs.global_variables if use_global_vars else vs.trainable_variables) len_before_vars = len(get_vars_fn()) - inputs = list(inputs) + inputs = [array_ops.identity(x) for x in inputs] outputs = fn(*inputs) train_vars = get_vars_fn()[len_before_vars:] @@ -581,3 +635,46 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): flat_inputs = nest.flatten(defun_inputs) id_out = identity(*flat_inputs) return id_out + + +def _force_data_dependency(first_compute, then_compute): + """Force all of `then_compute` to depend on all of `first_compute`. + + Uses a dummy data dependency, which is useful when running on TPUs because + XLA ignores control dependencies. Only supports float arguments. + + Args: + first_compute: `list`. These will be made to run before the + `Tensor`s `then_compute`. + then_compute: `list`. These will run after all the `Tensor`s in + `first_compute`. + + Returns: + `list`, same length as `then_compute`. + + Raises: + ValueError: if ranks are unknown or types are not floating. + """ + + def _first_element(x): + if x.get_shape().ndims is None: + raise ValueError("Rank of Tensor %s must be known" % x) + ndims = x.get_shape().ndims + return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), []) + + first_compute_sum = math_ops.add_n( + [_first_element(x) for x in first_compute if x is not None]) + dtype = first_compute_sum.dtype + if not dtype.is_floating: + raise ValueError("_force_data_dependency only supports floating dtypes.") + epsilon = np.finfo(dtype.as_numpy_dtype).tiny + zero = array_ops.stop_gradient(epsilon * first_compute_sum) + + return [ + array_ops.identity(x) + zero if x is not None else None + for x in then_compute + ] + + +def _tuple_with_data_dep(tensors): + return _force_data_dependency(tensors, tensors) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index cbcbcd75114a522b95631e4e7e95c1641b0a9987..d1ad4e8c98de3e5c5ac212d55cc93707ba9c01cc 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -154,7 +154,7 @@ class RevBlockTest(test.TestCase): y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) for g1, g2 in zip(gd_val, g_val): - self.assertAllClose(g1, g2) + self.assertAllClose(g1, g2, rtol=1e-5) def testRevBlock(self): self._testRevBlock() @@ -255,25 +255,54 @@ class RecomputeTest(test.TestCase): def fn_recompute(x): return fn(x) + @rev_block_lib.recompute_grad(use_data_dep=True) + def fn_use_data_dep(x): + return fn(x) + + @rev_block_lib.recompute_grad(tupleize_grads=True) + def fn_tupleize(x): + return fn(x) + + @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True) + def fn_both(x): + return fn(x) + x = random_ops.random_uniform((3, 1, 3)) - recompute_vars = None - with variable_scope.variable_scope("recompute") as vs: - out1 = math_ops.reduce_sum(fn_recompute(x)) - recompute_vars = vs.trainable_variables() - reg_vars = None - with variable_scope.variable_scope("regular") as vs: - out2 = math_ops.reduce_sum(fn(x)) - reg_vars = vs.trainable_variables() - - grad1 = gradients_impl.gradients(out1, recompute_vars) - grad2 = gradients_impl.gradients(out2, reg_vars) + + names_and_fns = [ + ("recompute", fn_recompute), + ("regular", fn), + ("use_data_dep", fn_use_data_dep), + ("tupleize", fn_tupleize), + ("tuple_and_data_dep", fn_both), + ] + outputs_and_vars = [] + for name, wrapped_fn in names_and_fns: + with variable_scope.variable_scope(name) as vs: + out = math_ops.reduce_sum(wrapped_fn(x)) + outputs_and_vars.append((out, vs.trainable_variables())) + + all_grads = [] + for out, scope_vars in outputs_and_vars: + all_grads.append(gradients_impl.gradients(out, scope_vars)) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - outs = sess.run([out1, out2, grad1, grad2]) - self.assertAllClose(outs[0], outs[1]) - for g1, g2 in zip(outs[2], outs[3]): - self.assertAllClose(g1, g2) + outputs = list(zip(*outputs_and_vars))[0] + outs, all_grads_val = sess.run([outputs, all_grads]) + + # All outputs are the same + current = outs[0] + for out in outs[1:]: + self.assertAllClose(current, out) + current = out + + # All gradients are the same + for grads in zip(all_grads_val): + current = grads[0] + for g in grads[1:]: + self.assertAllClose(current, g) + current = g class FnWithCustomGradTest(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index abf6e393bb0fbbce4e43f6d209e9b30517df36c3..16f80a876fac5e19bb8ce13074759c704c113947 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load("//tensorflow:tensorflow.bzl", "py_test") + package(default_visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", @@ -224,6 +226,7 @@ py_test( size = "small", srcs = ["python/learn/monitors_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip_gpu"], # b/74437598 deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -426,6 +429,10 @@ py_test( size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], srcs_version = "PY2AND3", + tags = [ + "noasan", # b/73741358 + "nomac", + ], deps = [ ":learn", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d516bffc5e0327a3400068b35de5503e5a925a54 --- /dev/null +++ b/tensorflow/contrib/learn/README.md @@ -0,0 +1,143 @@ +EVERYTHING IN THIS DIRECTORY IS DEPRECATED. + +Using functions or classes will result in warnings. + +Instructions for converting to current alternatives are included in the +warnings. A high-level overview is below. + +## Canned Estimators + +Many canned estimators (subclasses of `Estimator`) have equivalents in core: +`DNNClassifier`, `DNNRegressor`, `DNNEstimator`, `LinearClassifier`, +`LinearRegressor`, `DNNLinearCombinedClassifier` and +`DNNLinearCombinedRegressor`. They are exposed under `tf.estimator`. +`DNNEstimator`, `LinearEstimator` and `DNNLinearCombinedEstimator` +are exposed under `tf.contrib.estimator`. + +To migrate to the new api, users need to take the following steps: + +* Replace `tf.contrib.learn` with `tf.estimator`. +* If you subclass any of the estimators, stop doing that. You should be able to + write a factory method that returns a canned estimator instead. If this is not + possible (if you override methods from the canned estimator), consider writing + a custom estimator instead. See `tf.estimator.Estimator`. +* Set `loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE` to preserve loss + reduction as the average over batch. +* Some optimizer-related arguments are no longer passed in the estimator + constructor. Instead, we provide methods that perform the same job by wrapping + an optimizer. Specifically: + * `gradient_clip_norm`: Use `tf.contrib.estimator.clip_gradients_by_norm` + * `embedding_lr_multipliers`: Not supported. + Other arguments: + * `input_layer_min_slice_size`: Replaced by `input_layer_partitioner` + * `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. + * `feature_engineering_fn`: Not supported. You can call your + `feature_engineering_fn` inside your input_fn: + ```python + def new_input_fn(): + features, labels = old_input_fn() + return feature_engineering_fn(features, labels) + ``` +* Use `tf.reshape` to reshape labels in your `input_fn`. `tf.estimator` + classifiers and regressors expect labels as a 2D Tensor of shape + `[batch_size, 1]`, or `[batch_size, n_labels]`. In contrast, + `tf.contrib.learn` classifiers and regressors supported labels with shape + `[batch_size]`. +* If you pass custom metrics from the `evaluate()` method call, use + `tf.contrib.estimator.add_metrics`. +* Replace your `serving_input_fn` with a `serving_input_receiver_fn`. + Note this should be entirely distinct from your training `input_fn`, so if you + previously had one `input_fn` with different "modes", you should now factor + that apart. Where the former returned either a simple `(features, labels)` + tuple or `InputFnOps`, you should now return a `ServingInputReceiver`. + If you were generating your `serving_input_fn` using the + `build_parsing_serving_input_fn` helper, you can simply drop in the + replacement `build_parsing_serving_input_receiver_fn`. + +Some remaining estimators/classes: + +* `DynamicRnnEstimator`: Consider a custom `model_fn`. +* `KMeansClustering`: Use `tf.contrib.factorization.KMeansClustering`. +* `LogisticRegressor`: Not supported. Instead, use `binary_classification_head` + with a custom `model_fn`, or with `DNNEstimator`. +* `StateSavingRnnEstimator`: Consider a custom `model_fn`. +* SVM: Consider a custom `model_fn`. +* `LinearComposableModel` and `DNNComposableModel`: Not supported. + Consider `tf.contrib.estimator.DNNEstimator`, or write a custom model_fn. +* `MetricSpec`: Deprecated. For adding custom metrics to canned Estimators, use + `tf.contrib.estimator.add_metrics`. + +## Estimator +`tf.contrib.learn.Estimator` is migrated to `tf.estimator.Estimator`. + +To migrate, users need to take the following steps: + +* Replace `tf.contrib.learn.Estimator` with `tf.estimator.Estimator`. +* If you pass a `config` argument to `Estimator`, this must be + `tf.estimator.RunConfig`. You may need to edit your code accordingly. +* Edit your `model_fn` to return `tf.estimator.EstimatorSpec`. Refer to + `EstimatorSpec` for documentation of specific fields. +* If your `model_fn` uses the `mode` argument, use `tf.estimator.ModeKeys`. + +Some related classes: +* `Evaluable`, `Trainable`: Not supported, merged into `tf.estimator.Estimator`. +* ExportStrategy: Replaced by `tf.estimator.Exporter`. + +## Head/MultiHead +These classes are now supported under `tf.contrib.estimator`, e.g. +`tf.contrib.estimator.multi_class_head` and `tf.contrib.estimator.multi_head`. + +Some differences: + +* `multi_class_head`: If you use `tf.contrib.learn.multi_class_head` with + `n_classes=2`, switch to `tf.contrib.estimator.binary_classification_head`. +* `loss_only_head`: Not supported. +* `poisson_regression_head`: Not supported (yet). +* `binary_svm_head`: Not supported (yet). +* `no_op_train_fn`: Replace it with `tf.no_op`. + +Some arguments are renamed, please refer to documentation. In addition: + +* `loss_fn`: Supported for `multi_label_head`. If you need it for other heads, + please open an issue. +* `metric_class_ids`: Not supported (yet). +* `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. +* `label_name`: Not needed in `tf.estimator`. If you don’t use `multi_head`, + drop this argument. If you use `multi_head`, refer to + `tf.contrib.estimator.multi_head` documentation. + +## Experiment Class - Distributed Training Tooling + +Switch to `tf.estimator.train_and_evaluate`. Some differences: + +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, + `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` + directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement + for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to export + only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirementds (in + particular, the chief node and evaluator node). + +## Others Classes and Functions + +* `tf.contrib.learn.datasets` is deprecated. We are adding ready to use datasets + to tensorflow/models. Many smaller datasets are available from other sources, + such as scikits.learn. Some Python processing may have to be written, but this + is straightforward to implement using the standard modules. +* `tf.contrib.learn.preprocessing`: Deprecated. The python-only preprocessing + functions are not a good fit for TensorFlow. Please use `tf.data`, and + consider tensorflow/transform for more complex use cases. +* `tf.contrib.learn.models`: Not supported, use canned estimators instead. +* `tf.contrib.learn.monitors`: Implement `SessionRunHook` instead. Hook + implementations are in `tf.train`. +* `tf.contrib.learn.learn_io`: Use the methods in `tf.estimator.inputs`, such as + `tf.estimator.inputs.numpy_input_fn`. Some utility functions have no + equivalent, we encourage the use of `tf.data`. + diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 3698af027e38f1063ad829c26eb179734968f813..79bd73faaf1301a2fc4999b64f88d30542577980 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -13,8 +13,11 @@ # limitations under the License. # ============================================================================== -# TODO(ptucker,ipolosukhin): Improve descriptions. -"""High level API for learning. +"""High level API for learning (DEPRECATED). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. See the @{$python/contrib.learn} guide. diff --git a/tensorflow/contrib/learn/python/__init__.py b/tensorflow/contrib/learn/python/__init__.py index bbebd5ab9792cb937219cf937f08c4d4e6e44a92..df23aeb2c433c2b4392f706730f715246ce01cea 100644 --- a/tensorflow/contrib/learn/python/__init__.py +++ b/tensorflow/contrib/learn/python/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index cdc67c77d5fd1df61016835dc75ba44feb458cf9..76e0e8ac8f19026086959f3b197cfd1a81e65a3e 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py index 2284ec46e971731af74f17678fc0d1d3888419e2..fed1c44d1970bf07c808ace817aa9972d7776d88 100644 --- a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py +++ b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py @@ -12,20 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Some common SessionRunHook classes.""" +"""Some common SessionRunHook classes (deprected). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.util.deprecation import deprecated_alias # pylint: disable=invalid-name -LoggingTensorHook = basic_session_run_hooks.LoggingTensorHook -StopAtStepHook = basic_session_run_hooks.StopAtStepHook -CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook -StepCounterHook = basic_session_run_hooks.StepCounterHook -NanLossDuringTrainingError = basic_session_run_hooks.NanLossDuringTrainingError -NanTensorHook = basic_session_run_hooks.NanTensorHook -SummarySaverHook = basic_session_run_hooks.SummarySaverHook +LoggingTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.LoggingTensorHook', + 'tf.train.LoggingTensorHook', + basic_session_run_hooks.LoggingTensorHook) +StopAtStepHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StopAtStepHook', + 'tf.train.StopAtStepHook', + basic_session_run_hooks.StopAtStepHook) +CheckpointSaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.CheckpointSaverHook', + 'tf.train.CheckpointSaverHook', + basic_session_run_hooks.CheckpointSaverHook) +StepCounterHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StepCounterHook', + 'tf.train.StepCounterHook', + basic_session_run_hooks.StepCounterHook) +NanLossDuringTrainingError = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanLossDuringTrainingError', + 'tf.train.NanLossDuringTrainingError', + basic_session_run_hooks.NanLossDuringTrainingError) +NanTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanTensorHook', + 'tf.train.NanTensorHook', + basic_session_run_hooks.NanTensorHook) +SummarySaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.SummarySaverHook', + 'tf.train.SummarySaverHook', + basic_session_run_hooks.SummarySaverHook) # pylint: enable=invalid-name diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index 7240b0de149051afa045a8113f9e9b212840c311..3c34712ac859d32f549468345950a93d2ed2aa56 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Dataset utilities and synthetic/reference datasets.""" +"""Dataset utilities and synthetic/reference datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.datasets import mnist from tensorflow.contrib.learn.python.learn.datasets import synthetic from tensorflow.contrib.learn.python.learn.datasets import text_datasets +from tensorflow.python.util.deprecation import deprecated # Export load_iris and load_boston. load_iris = base.load_iris @@ -51,6 +57,7 @@ SYNTHETIC = { } +@deprecated(None, 'Please use tf.data.') def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. @@ -73,8 +80,9 @@ def load_dataset(name, size='small', test_with_fake_data=False): return DATASETS[name]() +@deprecated(None, 'Please use tf.data.') def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): - """Creates binary synthetic datasets + """Creates binary synthetic datasets. Args: name: str, name of the dataset to generate diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index ca720ae5ed26e74da12bd6c5a37231b41442f76f..3b5c9b97c08a388e1f35249967b6cab26861f100 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base utilities for loading datasets.""" + +"""Base utilities for loading datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +35,14 @@ import numpy as np from six.moves import urllib from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated + Dataset = collections.namedtuple('Dataset', ['data', 'target']) Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test']) +@deprecated(None, 'Use tf.data instead.') def load_csv_with_header(filename, target_dtype, features_dtype, @@ -53,6 +62,7 @@ def load_csv_with_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def load_csv_without_header(filename, target_dtype, features_dtype, @@ -70,6 +80,7 @@ def load_csv_without_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def shrink_csv(filename, ratio): """Create a smaller dataset of only 1/ratio of original data.""" filename_small = filename.replace('.', '_small.') @@ -84,6 +95,7 @@ def shrink_csv(filename, ratio): i += 1 +@deprecated(None, 'Use scikits.learn.datasets.') def load_iris(data_path=None): """Load Iris dataset. @@ -100,6 +112,7 @@ def load_iris(data_path=None): data_path, target_dtype=np.int, features_dtype=np.float) +@deprecated(None, 'Use scikits.learn.datasets.') def load_boston(data_path=None): """Load Boston housing dataset. @@ -116,7 +129,12 @@ def load_boston(data_path=None): data_path, target_dtype=np.float, features_dtype=np.float) -def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): +@deprecated(None, 'Use the retry module or similar alternatives.') +def retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): """Simple decorator for wrapping retriable functions. Args: @@ -152,7 +170,7 @@ def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): for delay in delays(): try: return fn(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except) + except Exception as e: # pylint: disable=broad-except if is_retriable is None: continue @@ -176,11 +194,13 @@ def _is_retriable(e): return isinstance(e, IOError) and e.errno in _RETRIABLE_ERRNOS +@deprecated(None, 'Please use urllib or similar directly.') @retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) def urlretrieve_with_retry(url, filename=None): return urllib.request.urlretrieve(url, filename) +@deprecated(None, 'Please write your own downloading logic.') def maybe_download(filename, work_directory, source_url): """Download the data from source url, unless it's already here. diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 37f9175015a239f763c7721cf36ab8063c0a3e32..abbb44c2f5b701829ce16f64eadd8ebc04c84e2c 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for downloading and reading MNIST data.""" +"""Functions for downloading and reading MNIST data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated # CVDF mirror of http://yann.lecun.com/exdb/mnist/ DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' @@ -37,6 +43,7 @@ def _read32(bytestream): return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_images(f): """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. @@ -65,6 +72,7 @@ def extract_images(f): return data +@deprecated(None, 'Please use tf.one_hot on tensors.') def dense_to_one_hot(labels_dense, num_classes): """Convert class labels from scalars to one-hot vectors.""" num_labels = labels_dense.shape[0] @@ -74,6 +82,7 @@ def dense_to_one_hot(labels_dense, num_classes): return labels_one_hot +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_labels(f, one_hot=False, num_classes=10): """Extract the labels into a 1D uint8 numpy array [index]. @@ -103,7 +112,15 @@ def extract_labels(f, one_hot=False, num_classes=10): class DataSet(object): + """Container class for a dataset (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def __init__(self, images, labels, @@ -210,6 +227,8 @@ class DataSet(object): return self._images[start:end], self._labels[start:end] +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def read_data_sets(train_dir, fake_data=False, one_hot=False, @@ -275,5 +294,7 @@ def read_data_sets(train_dir, return base.Datasets(train=train, validation=validation, test=test) +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def load_mnist(train_dir='MNIST-data'): return read_data_sets(train_dir) diff --git a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py index 6e0ba38941ce4650ede9f7210e284bde2ed8e6a9..a4848fa64a72f031ef35c0c3256e97a7326acd60 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Produce DBpedia datasets of a smaller size.""" +"""Produce DBpedia datasets of a smaller size (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py index 9a843168c27d9cae3f55efe4fe4c688d86c745f3..6a0e3350b3d1052249160a2a997a76de7a5040c3 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Synthetic dataset generators.""" +"""Synthetic dataset generators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,8 +26,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.learn.python.learn.datasets.base import Dataset +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def circles(n_samples=100, noise=None, seed=None, @@ -93,6 +100,7 @@ def circles(n_samples=100, return Dataset(data=X[indices], target=y[indices]) +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def spirals(n_samples=100, noise=None, seed=None, diff --git a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py index 2596a2ecaf1572506504831e8b08fab9b5dbc119..ce9466301728082f8e9d99c90989ba8fe623bcf0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Text datasets.""" +"""Text datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,10 +31,12 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated DBPEDIA_URL = 'https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz' +@deprecated(None, 'See contrib/learn/README.md') def maybe_download_dbpedia(data_dir): """Download if DBpedia data is not present.""" train_path = os.path.join(data_dir, 'dbpedia_csv/train.csv') @@ -41,6 +48,7 @@ def maybe_download_dbpedia(data_dir): tfile.extractall(data_dir) +@deprecated(None, 'See contrib/learn/README.md') def load_dbpedia(size='small', test_with_fake_data=False): """Get DBpedia datasets from CSV files.""" if not test_with_fake_data: diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 4981750c94c7ac31e23b7a3f71ca30e3c9573a20..3e64595f312bcc2a2e8dcba589fb993a249b684b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""An estimator is a rule for calculating an estimate of a given quantity. +"""An estimator is a rule for calculating an estimate of a given quantity (deprecated). + +These classes are deprecated and replaced with `tf.estimator`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. # Estimators diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 15277415a1ce83dc1d4a334e60fe1933ba244df0..1f0e4663d060a3850e2002b27f809fde1db47e48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""sklearn cross-support.""" +"""sklearn cross-support (deprecated).""" from __future__ import absolute_import from __future__ import division @@ -132,6 +132,8 @@ class _TransformerMixin(): class NotFittedError(ValueError, AttributeError): """Exception class to raise if estimator is used before fitting. + USE OF THIS EXCEPTION IS DEPRECATED. + This class inherits from both ValueError and AttributeError to help with exception handling and backward compatibility. diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py index a02c726c74946d93b8e1726473db746220b00795..1fa58271e2b886cd143683a759145fd750791473 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow composable models used as building blocks for estimators.""" +"""TensorFlow composable models used as building blocks for estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,6 +39,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated class _ComposableModel(object): @@ -46,6 +52,7 @@ class _ComposableModel(object): _ComposableModel and its subclasses are not part of the public tf.learn API. """ + @deprecated(None, "Please use model_fns in tf.estimator.") def __init__(self, num_label_columns, optimizer, @@ -141,6 +148,10 @@ class _ComposableModel(object): class LinearComposableModel(_ComposableModel): """A _ComposableModel that implements linear regression. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ @@ -252,6 +263,10 @@ class LinearComposableModel(_ComposableModel): class DNNComposableModel(_ComposableModel): """A _ComposableModel that implements a DNN. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/constants.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py index fc69e810244a182b864be856e6720f8584f7aa65..d2548946bc77dea7c452d61c7e2b6e12c3d6239a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/constants.py +++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py @@ -13,9 +13,11 @@ # limitations under the License. # ============================================================================== -"""Constants regarding Estimators. +"""Constants regarding Estimators (deprecated). -This file is obsoleted in the move of Estimator to core. +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ from __future__ import absolute_import from __future__ import division @@ -25,6 +27,8 @@ from __future__ import print_function class ProblemType(object): """Enum-like values for the type of problem that the model solves. + THIS CLASS IS DEPRECATED. + These values are used when exporting the model to produce the appropriate signature function for serving. diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug.py b/tensorflow/contrib/learn/python/learn/estimators/debug.py index 9d5f6c2bf969d7c85d251bf1b06a0307a41b2297..24b067b7e38b12df3d1d0c49f626344217218571 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Debug estimators. +"""Debug estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Debug estimators are bias-only estimators that can be used for debugging and as simple baselines. @@ -118,6 +122,10 @@ def debug_model_fn(features, labels, mode, params, config=None): class DebugClassifier(estimator.Estimator): """A classifier for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -237,6 +245,10 @@ class DebugClassifier(estimator.Estimator): class DebugRegressor(estimator.Estimator): """A regressor for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index c17b41c0f767e19d9c3635a8f60347a49b297cfb..eabebb7e881558471c343c0573cc9a8f4a425312 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Deep Neural Network estimators.""" +"""Deep Neural Network estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -212,6 +217,10 @@ def _dnn_model_fn(features, labels, mode, params, config=None): class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -521,6 +530,10 @@ class DNNClassifier(estimator.Estimator): class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -796,6 +809,10 @@ class DNNRegressor(estimator.Estimator): class DNNEstimator(estimator.Estimator): """A Estimator for TensorFlow DNN models with user specified _Head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 726612235050def6e7addb503cc6646a25de0e42..3d85533d92d17095bae9a69f229171e1bf61ba10 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow estimators for Linear and DNN joined training models.""" +"""TensorFlow estimators for Linear and DNN joined training models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -372,6 +377,10 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): class DNNLinearCombinedEstimator(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -490,6 +499,10 @@ class DNNLinearCombinedEstimator(estimator.Estimator): class DNNLinearCombinedClassifier(estimator.Estimator): """A classifier for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -832,6 +845,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): class DNNLinearCombinedRegressor(estimator.Estimator): """A regressor for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 69440e823ef1ed2d739f28bc14587891f2de80bb..a703dc66e922d48ceb64edc2a979061b8e45db49 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for Dynamic RNNs.""" +"""Estimator for Dynamic RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -540,6 +545,12 @@ def _get_dynamic_rnn_model_fn( class DynamicRnnEstimator(estimator.Estimator): + """Dynamically unrolled RNN (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 4b63e08ab3372849309ee5d28d754de82e9632f4..7a026a15e4aeea0dde4ed9f7de053a757a0abb58 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base Estimator class.""" +"""Base Estimator class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -138,6 +143,7 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): return df.input_builder, df.get_feed_dict_fn() +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input_fn(input_fn): """Creates `FeatureColumn` objects for inputs defined by `input_fn`. @@ -158,6 +164,7 @@ def infer_real_valued_columns_from_input_fn(input_fn): return layers.infer_real_valued_columns(features) +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input(x): """Creates `FeatureColumn` objects for inputs defined by input `x`. @@ -389,6 +396,10 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable): """Abstract BaseEstimator class to train and evaluate TensorFlow models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Users should not instantiate or subclass this class. Instead, use an `Estimator`. """ @@ -399,6 +410,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): Remove this once launcher takes over config functionality _Config = run_config.RunConfig # pylint: disable=invalid-name + @deprecated(None, 'Please replace uses of any Estimator from tf.contrib.learn' + ' with an Estimator from tf.estimator.*') def __init__(self, model_dir=None, config=None): """Initializes a BaseEstimator instance. @@ -457,6 +470,20 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): make RunConfig immutable, and then return it without a copy. return copy.deepcopy(self._config) + @property + def model_fn(self): + """Returns the model_fn which is bound to self.params. + + Returns: + The model_fn with the following signature: + `def model_fn(features, labels, mode, metrics)` + """ + + def public_model_fn(features, labels, mode, config): + return self._call_model_fn(features, labels, mode, config=config) + + return public_model_fn + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), ('y', None), ('batch_size', None)) def fit(self, @@ -890,8 +917,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, if feed_fn: hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn)) if steps == 0: - logging.warning('evaluation steps are 0. If `input_fn` does not raise' - 'OutOfRangeError`, the evaluation will never stop.' + logging.warning('evaluation steps are 0. If `input_fn` does not raise ' + '`OutOfRangeError`, the evaluation will never stop. ' 'Use steps=None if intended.') if steps: hooks.append( @@ -1074,6 +1101,10 @@ def _identity_feature_engineering_fn(features, labels): class Estimator(BaseEstimator): """Estimator class is the basic TensorFlow model trainer/evaluator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ def __init__(self, @@ -1162,7 +1193,7 @@ class Estimator(BaseEstimator): self._feature_engineering_fn = ( feature_engineering_fn or _identity_feature_engineering_fn) - def _call_model_fn(self, features, labels, mode, metrics=None): + def _call_model_fn(self, features, labels, mode, metrics=None, config=None): """Calls model function with support of 2, 3 or 4 arguments. Args: @@ -1170,6 +1201,7 @@ class Estimator(BaseEstimator): labels: labels dict. mode: ModeKeys metrics: Dict of metrics. + config: RunConfig. Returns: A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a @@ -1186,7 +1218,10 @@ class Estimator(BaseEstimator): if 'params' in model_fn_args: kwargs['params'] = self.params if 'config' in model_fn_args: - kwargs['config'] = self.config + if config: + kwargs['config'] = config + else: + kwargs['config'] = self.config if 'model_dir' in model_fn_args: kwargs['model_dir'] = self.model_dir model_fn_results = self._model_fn(features, labels, **kwargs) @@ -1458,8 +1493,14 @@ class Estimator(BaseEstimator): # For time of deprecation x,y from Estimator allow direct access. # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): - """Scikit learn wrapper for TensorFlow Learn Estimator.""" + """Scikit learn wrapper for TensorFlow Learn Estimator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please switch to the Estimator interface.') def __init__(self, estimator): self._estimator = estimator diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py index fd47710e3015de9ae6a453f98978b0ef8f88968c..e4c31396baf8271c49395a2b87b454dbc77177e2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utils for Estimator.""" +"""Utils for Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 9b124b2c19f16bbc9b2afeadb82a32006e1a0ae9..2b4b6eff39f4fc8a20a149edfc07d2f4f27a9bae 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Abstractions for the head(s) of a model. +"""Abstractions for the head(s) of a model (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -47,11 +52,16 @@ from tensorflow.python.summary import summary from tensorflow.python.training import training from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated class Head(object): """Interface for the head/top of a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Given logits (or output of a hidden layer), a Head knows how to compute predictions, loss, default metric and export signature. It is meant to, @@ -177,6 +187,7 @@ class Head(object): raise NotImplementedError("Calling an abstract method.") +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -216,6 +227,7 @@ def regression_head(label_name=None, link_fn=(link_fn if link_fn is not None else array_ops.identity)) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def poisson_regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -254,6 +266,7 @@ def poisson_regression_head(label_name=None, # TODO(zakaria): Consider adding a _RegressionHead for logistic_regression +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_class_head(n_classes, label_name=None, weight_column_name=None, @@ -335,6 +348,7 @@ def multi_class_head(n_classes, label_keys=label_keys) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def binary_svm_head( label_name=None, weight_column_name=None, @@ -370,6 +384,7 @@ def binary_svm_head( thresholds=thresholds) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_label_head(n_classes, label_name=None, weight_column_name=None, @@ -430,6 +445,7 @@ def multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def loss_only_head(loss_fn, head_name=None): """Creates a Head that contains only loss terms. @@ -447,6 +463,7 @@ def loss_only_head(loss_fn, head_name=None): return _LossOnlyHead(loss_fn, head_name=head_name) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. @@ -479,6 +496,7 @@ def multi_head(heads, loss_weights=None): return _MultiHead(heads, loss_merger=_weighted_loss_merger) +@deprecated(None, "Use 'lambda _: tf.no_op()'.") def no_op_train_fn(loss): del loss return control_flow_ops.no_op() diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 6d5da81b4c2087fb9c5307902e452a6220a17cd0..7c2d9bb0767cb979dae9c84b5342d129225677ed 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -362,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase): "auc_precision_recall": 0.166667, "auc_precision_recall/class0": 0, "auc_precision_recall/class1": 0., - "auc_precision_recall/class2": 0.49999, + "auc_precision_recall/class2": 1., "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -748,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. / 2, - "auc_precision_recall": 0.25, + "auc_precision_recall": 0.749999, "labels/actual_label_mean": label_mean, "labels/prediction_mean": .731059, # softmax "loss": expected_loss, diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 8f9d6fc318a357853bdb8e3264f6691b410006b1..66ebcfd1d81904b9afe5be6bd1a648fe325e1e0b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API. +"""Implementation of k-means clustering on top of `Estimator` API (deprecated). This module is deprecated. Please use @{tf.contrib.factorization.KMeansClustering} instead of @@ -153,7 +153,12 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE RANDOM_INIT = clustering_ops.RANDOM_INIT diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 37aa8b339622415d082933cdf66d2472a4119b48..70b70af98c51dcb991c19152607272673953ee2a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Linear Estimators.""" +"""Linear Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -238,8 +243,8 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" - with variable_scope.variable_op_scope( - features.values(), parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -305,6 +310,10 @@ class _SdcaUpdateWeightsHook(session_run_hook.SessionRunHook): class LinearClassifier(estimator.Estimator): """Linear classifier model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear model to classify instances into one of multiple possible classes. When number of possible classes is 2, this is binary classification. @@ -625,6 +634,10 @@ class LinearClassifier(estimator.Estimator): class LinearRegressor(estimator.Estimator): """Linear regressor model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear regression model to predict label value given observation of feature values. @@ -860,6 +873,10 @@ class LinearRegressor(estimator.Estimator): class LinearEstimator(estimator.Estimator): """Linear model with user specified head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a generalized linear model to predict label value given observation of feature values. diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py index fb339160d58e09d4ffd50090f2dbbcec08bebe47..3cbcc6e98de1c915c302617e4591c9baa33adeaf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Logistic regression (aka binary classifier) class. +"""Logistic regression (aka binary classifier) class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This defines some useful basic metrics for using logistic regression to classify a binary event (0 vs 1). @@ -75,6 +79,10 @@ def LogisticRegressor( # pylint: disable=invalid-name feature_engineering_fn=None): """Builds a logistic regression Estimator for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This method provides a basic Estimator with some additional metrics for custom binary classification models, including AUC, precision/recall and accuracy. diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py index 99388f116b345bd038f2985606c6203011597ea2..f264248e44d9aa48f26ee32e36746bd4c3145a8d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for metric keys.""" +"""Enum for metric keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function class MetricKey(object): - """Metric key strings.""" + """Metric key strings (deprecated).""" + LOSS = "loss" AUC = "auc" AUC_PR = "auc_precision_recall" diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 44e6c7c52dac524a22e9099e33e2aef82f8fe7ba..dcb161180c99ce71195c820217e8bdaf79d70901 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Classes and methods related to model_fn.""" +"""Classes and methods related to model_fn (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -37,10 +42,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import session_run_hook +from tensorflow.python.util.deprecation import deprecated class ModeKeys(object): - """Standard names for model modes. + """Standard names for model modes (deprecated). + + THIS CLASS IS DEPRECATED. The following standard keys are defined: @@ -65,8 +73,16 @@ class ModelFnOps( 'output_alternatives', 'training_chief_hooks', 'training_hooks', 'scaffold', 'mode' ])): - """Ops returned from a model_fn.""" + """Ops returned from a model_fn. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'When switching to tf.estimator.Estimator, use ' + 'tf.estimator.EstimatorSpec. You can use the `estimator_spec`' + ' method to create an equivalent one.') def __new__(cls, mode, predictions=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py index f8d87b8914307a86eb2f46123a28ff11eb925eda..6fd2fc9d592cef4e44a640e2f27cb28b367d44d5 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for model prediction keys. +"""Enum for model prediction keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This file is obsoleted in the move of Estimator to core. """ @@ -22,6 +26,8 @@ from __future__ import print_function class PredictionKey(object): + """THIS CLASS IS DEPRECATED.""" + CLASSES = "classes" PROBABILITIES = "probabilities" LOGITS = "logits" diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index 2752bc2d90ee0f51b2c40ccc4d24a4eb21cff38f..215022e5d9e5d3cd5d6a96583b325b19a1719568 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common operations for RNN Estimators.""" +"""Common operations for RNN Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index fd90fd1cc6277e7d80287aefdbab6134dac7c0d5..1d161093de01ef838d0c75ec9a39574c7529bd57 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Run Config.""" +"""Run Config (deprecated, use tf.estimator.RunConfig instead). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +34,12 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as core_run_config from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util.deprecation import deprecated # A list of the property names in RunConfig user allows to change. They will # not affect the execution framework, so when execution framework checks the -# `uid` of the RunConfig, it should be ingored. +# `uid` of the RunConfig, it should be ignored. _DEFAULT_UID_WHITE_LIST = [ 'tf_random_seed', 'save_summary_steps', @@ -47,6 +53,7 @@ _DEFAULT_UID_WHITE_LIST = [ class Environment(object): + """DEPRECATED CLASS.""" # For running general distributed training. CLOUD = 'cloud' # For running Google-internal distributed training. @@ -56,6 +63,7 @@ class Environment(object): class TaskType(object): + """DEPRECATED CLASS.""" MASTER = 'master' PS = 'ps' WORKER = 'worker' @@ -64,6 +72,8 @@ class TaskType(object): class ClusterConfig(object): """This class specifies the configurations for a distributed run. + THIS CLASS IS DEPRECATED. Use tf.estimator.RunConfig instead. + If you're using an `Estimator`, you should probably use the subclass RunConfig instead. """ @@ -211,10 +221,13 @@ class ClusterConfig(object): class RunConfig(ClusterConfig, core_run_config.RunConfig): """This class specifies the configurations for an `Estimator` run. - This class is the implementation of @{tf.estimator.RunConfig} interface. + This class is a deprecated implementation of @{tf.estimator.RunConfig} + interface. """ _USE_DEFAULT = 0 + @deprecated(None, 'When switching to tf.estimator.Estimator, use' + ' tf.estimator.RunConfig instead.') def __init__(self, master=None, num_cores=0, diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index 0cea35e219a4457417a161a3ac4ac4292fd24f53..de78c72c3ae3ef14f5f7c46b1d47f82e8266c7c6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for State Saving RNNs.""" +"""Estimator for State Saving RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -528,6 +533,12 @@ def _get_rnn_model_fn(cell_type, class StateSavingRnnEstimator(estimator.Estimator): + """RNN with static unrolling and state saving (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index 72920d73c0c92886e54f533ad7fe170fe27d9870..3459997baba16fc0d4045e50819ecdd0e7121657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support Vector Machine (SVM) Estimator.""" +"""Support Vector Machine (SVM) Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -36,6 +41,10 @@ def _as_iterable(preds, output): class SVM(estimator.Estimator): """Support Vector Machine (SVM) model for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Currently, only linear SVMs are supported. For the underlying optimization problem, the `SDCAOptimizer` is used. For performance and convergence tuning, the num_loss_partitions parameter passed to `SDCAOptimizer` (see `__init__()` diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py index a120bc6cc3975a3d4559d018c8aa74ff42a16d2d..71b5658dd174d2b47e33860844359f68e6768026 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py +++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorSignature class and utilities.""" +"""TensorSignature class and utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -33,6 +38,10 @@ class TensorSignature(collections.namedtuple( "TensorSignature", ["dtype", "shape", "is_sparse"])): """Signature of the `Tensor` object. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Useful to check compatibility of tensors. Example: diff --git a/tensorflow/contrib/learn/python/learn/estimators/test_data.py b/tensorflow/contrib/learn/python/learn/estimators/test_data.py index ed201bfc58f273e6587850032386c2686aea4148..e4b057b4f5a9e081c2d891bd9828ffc315e51e91 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/test_data.py +++ b/tensorflow/contrib/learn/python/learn/estimators/test_data.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test data utilities.""" +"""Test data utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 8f6cd39864b437f163dd7c1140dc88755ce98529..10881ca885599bc81386e15f814a2687d907f63b 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Evaluable` interface.""" +"""`Evaluable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,10 @@ import abc class Evaluable(object): """Interface for objects that are evaluatable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 331bc115499c8d6f4057bf1c0908bcea05f005a3..3744abd860e7f460133873eb534fd75887182f78 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experiment class collecting information needed for a single training run.""" +"""Experiment class collecting information for a single training run (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -25,7 +30,6 @@ import os import time from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import export_strategy @@ -118,6 +122,10 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): class Experiment(object): """Experiment is a class containing all information needed to train a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + After an experiment is created (by passing an Estimator and inputs for training and evaluation), an Experiment instance knows how to invoke training and eval loops in a sensible fashion for distributed training. @@ -125,16 +133,8 @@ class Experiment(object): # TODO(ispir): remove delay_workers_by_global_step and make global step based # waiting as only behavior. - @deprecated_args( - "2016-10-23", - "local_eval_frequency is deprecated as local_run will be renamed to " - "train_and_evaluate. Use min_eval_frequency and call train_and_evaluate " - "instead. Note, however, that the default for min_eval_frequency is 1, " - "meaning models will be evaluated every time a new checkpoint is " - "available. In contrast, the default for local_eval_frequency is None, " - "resulting in evaluation occurring only after training has completed. " - "min_eval_frequency is ignored when calling the deprecated local_run.", - "local_eval_frequency") + @deprecated(None, "Please switch to tf.estimator.train_and_evaluate. You will" + " also have to convert to a tf.estimator.Estimator.") def __init__(self, estimator, train_input_fn, @@ -358,7 +358,7 @@ class Experiment(object): self._start_server() elif config.cluster_spec and config.master: raise ValueError( - "For distributed runtime, Experiment class only works with" + "For distributed runtime, Experiment class only works with " "tf.contrib.learn.RunConfig for now, but provided {}".format( type(config))) diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index 55a8b824312b89e0ac66513242191f4201ac212a..075cab536ecb5279e7e6f23abb0b70c75043a7ec 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""ExportStrategy class represents different flavors of model export.""" +"""ExportStrategy class represents different flavors of model export (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function import collections from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated __all__ = ['ExportStrategy'] @@ -30,6 +36,10 @@ class ExportStrategy( ['name', 'export_fn', 'strip_default_attrs'])): """A class representing a type of model export. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Typically constructed by a utility function specific to the exporter, such as `saved_model_export_utils.make_export_strategy()`. @@ -56,6 +66,8 @@ class ExportStrategy( forward compatibility of the resulting `SavedModel`. """ + @deprecated(None, 'Please switch to tf.estimator.train_and_evaluate, and use ' + 'tf.estimator.Exporter.') def __new__(cls, name, export_fn, strip_default_attrs=None): return super(ExportStrategy, cls).__new__( cls, name, export_fn, strip_default_attrs) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 98365c05f663e5d2a06703457fc5663d7135f7d9..a997fab723a16dddf150aa9397863605e4e77933 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level operations on graphs.""" +"""High level operations on graphs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -68,6 +73,7 @@ def clear_summary_writers(): return summary_io.SummaryWriterCache.clear() +@deprecated(None, 'Use `SummaryWriterCache.get` directly.') def get_summary_writer(logdir): """Returns single SummaryWriter per logdir in current run. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py index 06c3782a471537cf3879450e6bd20899a35d96ac..8b133a4440d8cbc19abca64f972791fc16ade6f8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tools to allow different io formats.""" +"""Tools to allow different io formats (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py index 7d666391cea3c0a52a2cb7e324c00d5f480710d5..e0a1948d95a727675dac8ff3ce9f55c35d5f8d8d 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Methods to allow dask.DataFrame.""" +"""Methods to allow dask.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.util.deprecation import deprecated + try: # pylint: disable=g-import-not-at-top import dask.dataframe as dd @@ -60,6 +67,7 @@ def _construct_dask_df_with_divisions(df): return dd.Series(merge(dsk, df.dask), name, df.name, divisions) +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_data(data): """Extract data from dask.Series or dask.DataFrame for predictors. @@ -81,6 +89,7 @@ def extract_dask_data(data): return data +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_labels(labels): """Extract data from dask.Series or dask.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 96be8b1bc402479d5611965f27abb197363cb939..c45b1d186471125776d6536112aebb66bb5ad558 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementations of different data feeders to provide data for TF trainer.""" +"""Implementations of different data feeders to provide data for TF trainer (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues. @@ -31,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels @@ -101,6 +107,7 @@ def _is_iterable(x): return hasattr(x, 'next') or hasattr(x, '__next__') +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_train_data_feeder(x, y, n_classes, @@ -188,6 +195,7 @@ def _batch_data(x, batch_size=None): yield np.matrix(chunk) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_predict_data_feeder(x, batch_size=None): """Returns an iterable for feeding into predict step. @@ -219,6 +227,7 @@ def setup_predict_data_feeder(x, batch_size=None): return [x] +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_processor_data_feeder(x): """Sets up processor iterable. @@ -233,6 +242,7 @@ def setup_processor_data_feeder(x): return x +@deprecated(None, 'Please convert numpy dtypes explicitly.') def check_array(array, dtype): """Checks array on dtype and converts it if different. @@ -275,8 +285,14 @@ def _check_dtype(dtype): class DataFeeder(object): - """Data feeder is an example class to sample data for TF trainer.""" + """Data feeder is an example class to sample data for TF trainer. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, x, y, @@ -563,6 +579,10 @@ class DataFeeder(object): class StreamingDataFeeder(DataFeeder): """Data feeder for TF trainer that reads data from iterator. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Streaming data feeder allows to read data as it comes it from disk or somewhere else. It's custom to have this iterators rotate infinetly over the dataset, to allow control of how much to learn on the trainer side. @@ -771,11 +791,16 @@ class StreamingDataFeeder(DataFeeder): class DaskDataFeeder(object): """Data feeder for that reads data from dask.Series and dask.DataFrame. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Numpy arrays can be serialized to disk and it's possible to do random seeks into them. DaskDataFeeder will remove requirement to have full dataset in the memory and still do random seeks for sampling of batches. """ + @deprecated(None, 'Please feed input to tf.data to support dask.') def __init__(self, x, y, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py index 884faf8335e2a3ca1d27d2d93b4c817131648774..f8aaa0c9e3e5b589a6ad47678dba3dc38de7c471 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow generator of dict with numpy arrays.""" +"""Methods to allow generator of dict with numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,8 +28,10 @@ from types import FunctionType from types import GeneratorType from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.data.') def generator_input_fn(x, target_key=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 3a46c239688017f9204d2c6182a6f81cd325a417..9e816f54b6cf8dee84c6d62406ab3db700054d06 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to read data in the graph.""" +"""Methods to read data in the graph (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,11 +39,13 @@ from tensorflow.python.platform import gfile from tensorflow.python.summary import summary from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner +from tensorflow.python.util.deprecation import deprecated # Default name for key in the feature dict. KEY_FEATURE_NAME = '__key__' +@deprecated(None, 'Use tf.data.') def read_batch_examples(file_pattern, batch_size, reader, @@ -106,6 +113,7 @@ def read_batch_examples(file_pattern, return examples +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples(file_pattern, batch_size, reader, @@ -175,6 +183,7 @@ def read_keyed_batch_examples(file_pattern, seed=seed) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples_shared_queue(file_pattern, batch_size, reader, @@ -452,6 +461,7 @@ def _read_keyed_batch_examples_helper(file_pattern, return queued_examples_with_keys +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features(file_pattern, batch_size, features, @@ -540,6 +550,7 @@ def read_keyed_batch_features(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features_shared_queue(file_pattern, batch_size, features, @@ -620,6 +631,7 @@ def read_keyed_batch_features_shared_queue(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def queue_parsed_features(parsed_features, keys=None, feature_queue_capacity=100, @@ -742,6 +754,7 @@ def queue_parsed_features(parsed_features, return dequeued_keys, dequeued_parsed_features +@deprecated(None, 'Use tf.data.') def read_batch_features(file_pattern, batch_size, features, @@ -821,6 +834,7 @@ def read_batch_features(file_pattern, return features +@deprecated(None, 'Use tf.data.') def read_batch_record_features(file_pattern, batch_size, features, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 692438807fbd7febb156d4db73b5d3deba6c987d..29552d24f1eaa0d85a99c8b09f69d007e7e4fe9f 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow dict of numpy arrays.""" +"""Methods to allow dict of numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_numpy_input_fn +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Use tf.estimator.inputs.numpy_input_fn.') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index ede7558eafa9237dc63aa95a62e599c5e9755822..b4ef055f5ae484ec704ad42efcf2c00c4a7a4f56 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -13,13 +13,19 @@ # limitations under the License. # ============================================================================== -"""Methods to allow pandas.DataFrame.""" +"""Methods to allow pandas.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn as core_pandas_input_fn +from tensorflow.python.util.deprecation import deprecated try: # pylint: disable=g-import-not-at-top @@ -47,6 +53,7 @@ PANDAS_DTYPES = { } +@deprecated(None, 'Please use tf.estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, @@ -66,6 +73,7 @@ def pandas_input_fn(x, target_column=target_column) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_data(data): """Extract data from pandas.DataFrame for predictors. @@ -96,6 +104,7 @@ def extract_pandas_data(data): 'float, or bool. Found: ' + ', '.join(error_report)) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_matrix(data): """Extracts numpy matrix from pandas DataFrame. @@ -111,6 +120,7 @@ def extract_pandas_matrix(data): return data.as_matrix() +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_labels(labels): """Extract data from pandas.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 2af723a0d64822e81fa0fbeb106ab812de6ab4e8..d719a3e488b9905ef7903e21d90dbaae0449735c 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Runs an Experiment.""" +"""Runs an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config as run_c from tensorflow.contrib.learn.python.learn.experiment import Experiment from tensorflow.contrib.training.python.training import hparam as hparam_lib from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # TODO(xiejw): Refactor the learn_runner to make code reusable. @@ -99,6 +105,7 @@ def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False): return wrapped_experiment_fn +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def run(experiment_fn, output_dir=None, schedule=None, run_config=None, hparams=None): """Make and run an experiment. @@ -218,6 +225,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, return _execute_schedule(experiment, schedule) +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def tune(experiment_fn, tuner): """Tune an experiment with hyper-parameters. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py index 7d9b1c7716f0ab1f2274ca53406175240b613027..ba2d067787c1dfd4e4820ecc916f1053e9f3cf60 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities to run and tune an Experiment. +"""Utilities to run and tune an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@run @@tune diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index 6440bc204b8e339ff51311dcc87b36f556b94092..97220365d5dddb82b602369f06bea021a86d584f 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The metric spec class to flexibly connect models and metrics.""" +"""The metric spec class to flexibly connect models and metrics (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ import six from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated def _assert_named_args(sentinel): @@ -223,6 +229,10 @@ def _adapt_metric_fn( class MetricSpec(object): """MetricSpec connects a model to metric functions. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + The MetricSpec class contains all information necessary to connect the output of a `model_fn` to the metrics (usually, streaming metrics) that are used in evaluation. @@ -284,6 +294,7 @@ class MetricSpec(object): """ + @deprecated(None, 'Use tf.estimator.EstimatorSpec.eval_metric_ops.') def __init__(self, metric_fn, prediction_key=None, diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py index 4283240d018c949bb35aeb12032d2ee8b75884a5..bd4bbf9f8c9ad7e8a0fc06d8c0dc24672536c158 100644 --- a/tensorflow/contrib/learn/python/learn/models.py +++ b/tensorflow/contrib/learn/python/learn/models.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Various high level TF models.""" +"""Various high level TF models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -28,8 +33,10 @@ from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using a tf.estimator.LinearRegressor') def linear_regression_zero_init(x, y): """Linear regression subgraph with zero-value initial weights and bias. @@ -43,6 +50,7 @@ def linear_regression_zero_init(x, y): return linear_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.LinearClassifier') def logistic_regression_zero_init(x, y): """Logistic regression subgraph with zero-value initial weights and bias. @@ -56,6 +64,7 @@ def logistic_regression_zero_init(x, y): return logistic_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.') def linear_regression(x, y, init_mean=None, init_stddev=1.0): """Creates linear regression TensorFlow subgraph. @@ -107,6 +116,7 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0): return losses_ops.mean_squared_error_regressor(x, y, weights, bias) +@deprecated(None, 'Consider using a class from tf.estimator.') def logistic_regression(x, y, class_weight=None, @@ -203,6 +213,7 @@ def _reverse_seq(input_seq, lengths): return result +@deprecated(None, 'Please consider `tf.nn.bidirectional_dynamic_rnn`.') def bidirectional_rnn(cell_fw, cell_bw, inputs, @@ -283,6 +294,7 @@ def bidirectional_rnn(cell_fw, # End of TensorFlow 0.7 +@deprecated(None, 'Please consider tensorflow/tensor2tensor.') def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn, bidirectional, target_predictor_fn, sequence_length, initial_state, attn_length, attn_size, attn_vec_size): diff --git a/tensorflow/contrib/learn/python/learn/monitored_session.py b/tensorflow/contrib/learn/python/learn/monitored_session.py index 22602e9f69d972505d83a66a6f9183b5e4d15c44..ac0433f1775feeed2ec3cf49291da01500bef01b 100644 --- a/tensorflow/contrib/learn/python/learn/monitored_session.py +++ b/tensorflow/contrib/learn/python/learn/monitored_session.py @@ -13,7 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A wrapper of Session API which runs hooks.""" +"""A wrapper of Session API which runs hooks (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 9457a73ecfb41782c888e3bba0b140db83d4d464..77f7c73d5412d40b338eaff4cf04d99fd0892723 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monitors instrument the training process. +"""Monitors instrument the training process (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@get_default_monitors @@BaseMonitor @@ -59,6 +63,10 @@ from tensorflow.python.util import tf_inspect class BaseMonitor(object): """Base class for Monitors. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Defines basic interfaces of Monitors. Monitors can either be run on all workers or, more commonly, restricted to run exclusively on the elected chief worker. @@ -229,6 +237,10 @@ def _extract_output(outputs, request): class EveryN(BaseMonitor): """Base class for monitors that execute callbacks every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This class adds three new callbacks: - every_n_step_begin - every_n_step_end @@ -418,6 +430,10 @@ class StopAtStep(BaseMonitor): class PrintTensor(EveryN): """Prints given tensors every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This is an `EveryN` monitor and has consistent semantic for `every_n` and `first_n`. @@ -455,9 +471,12 @@ class PrintTensor(EveryN): class LoggingTrainable(EveryN): """Writes trainable variable values into log every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Write the tensors in trainable variables `every_n` steps, starting with the `first_n`th step. - """ def __init__(self, scope=None, every_n=100, first_n=1): @@ -493,7 +512,12 @@ class LoggingTrainable(EveryN): class SummarySaver(EveryN): - """Saves summaries every N steps.""" + """Saves summaries every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, summary_op, @@ -554,6 +578,10 @@ class SummarySaver(EveryN): class ValidationMonitor(EveryN): """Runs evaluation of a given estimator, at most every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note that the evaluation is done based on the saved checkpoint, which will usually be older than the current step. @@ -756,6 +784,10 @@ class ValidationMonitor(EveryN): class CaptureVariable(EveryN): """Captures a variable's values into a collection. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This monitor is useful for unit testing. You should exercise caution when using this monitor in production, since it never discards values. @@ -794,6 +826,7 @@ class CaptureVariable(EveryN): self._var_values[step] = _extract_output(outputs, self._var_name) +@deprecation.deprecated(None, "Use tf.train.MonitoredTrainingSession.") def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, @@ -828,6 +861,10 @@ def get_default_monitors(loss_op=None, class GraphDump(BaseMonitor): """Dumps almost all tensors in the graph at every step. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note, this is very expensive, prefer `PrintTensor` in production. """ @@ -917,7 +954,12 @@ class GraphDump(BaseMonitor): class ExportMonitor(EveryN): - """Monitor that exports Estimator every N steps.""" + """Monitor that exports Estimator every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ @deprecation.deprecated("2017-03-25", "ExportMonitor is deprecated. Please pass an " @@ -1040,7 +1082,12 @@ class ExportMonitor(EveryN): class CheckpointSaver(BaseMonitor): - """Saves checkpoints every N steps or N seconds.""" + """Saves checkpoints every N steps or N seconds. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, checkpoint_dir, @@ -1125,7 +1172,12 @@ class CheckpointSaver(BaseMonitor): class StepCounter(EveryN): - """Steps per second monitor.""" + """Steps per second monitor. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None): super(StepCounter, self).__init__(every_n_steps=every_n_steps) @@ -1165,6 +1217,10 @@ class NanLossDuringTrainingError(RuntimeError): class NanLoss(EveryN): """NaN Loss monitor. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Monitors loss and stops training if loss is NaN. Can either fail with exception or just stop training. """ diff --git a/tensorflow/contrib/learn/python/learn/ops/__init__.py b/tensorflow/contrib/learn/python/learn/ops/__init__.py index 33962e34cc685ce2c830a7bbfd1b5c626bcd8b31..efb1f47cf5bb2dcd0fb37b7b85cd8f170d56e4d1 100644 --- a/tensorflow/contrib/learn/python/learn/ops/__init__.py +++ b/tensorflow/contrib/learn/python/learn/ops/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Various TensorFlow Ops.""" +"""Various TensorFlow Ops (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py index fa3b7323e343371e986b763d30a8a44620894549..8f9811cf251ae0af1e0055a56e1358c2771b1367 100644 --- a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops to work with embeddings. +"""TensorFlow Ops to work with embeddings (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Note: categorical variables are handled via embeddings in many cases. For example, in case of words. @@ -57,7 +61,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'): ids = ops.convert_to_tensor(ids) shape = array_ops_.shape(ids) ids_flat = array_ops_.reshape( - ids, math_ops.reduce_prod(shape, keep_dims=True)) + ids, math_ops.reduce_prod(shape, keepdims=True)) embeds_flat = nn.embedding_lookup(params, ids_flat, name) embed_shape = array_ops_.concat([shape, [-1]], 0) embeds = array_ops_.reshape(embeds_flat, embed_shape) diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py index b040ab3bb6c516158589a8e30d56fff1f7728951..92976d1539c7ddc226b81f903beee82b798ec8db 100644 --- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for loss computation.""" +"""TensorFlow Ops for loss computation (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py index 45727faab4362abeab18f77861353eb53976023a..aa37cb4a76e2a6157bf077d327248353bd516472 100644 --- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for Sequence to Sequence models.""" +"""TensorFlow Ops for Sequence to Sequence models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,8 +31,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): """Returns predictions and loss for sequence of predictions. @@ -57,6 +64,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): return array_ops.stack(predictions, axis=1), loss +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): """Processes inputs for Sequence to Sequence models. @@ -87,6 +95,7 @@ def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): return in_x, in_y, out_y +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): """RNN Decoder that creates training and sampling sub-graphs. @@ -123,6 +132,7 @@ def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): return outputs, states, sampling_outputs, sampling_states +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_seq2seq(encoder_inputs, decoder_inputs, encoder_cell, diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py index 7bcc177d4ea0ab57f092d68888a72de2b2fd5edc..e8c6e1acf80f0791421bee59aff30e67bccb44b2 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Preprocessing tools useful for building models.""" +"""Preprocessing tools useful for building models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py index 154739d497ec1029026eaca1e93b37cd225f1050..faba3b2025e8abb51d1989c3fafbd5e711d6559b 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements preprocessing transformers for categorical variables.""" +"""Implements preprocessing transformers for categorical variables (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,8 @@ from __future__ import print_function import math import numpy as np +from tensorflow.python.util.deprecation import deprecated + # pylint: disable=g-bad-import-order from . import categorical_vocabulary from ..learn_io.data_feeder import setup_processor_data_feeder @@ -31,10 +38,16 @@ from ..learn_io.data_feeder import setup_processor_data_feeder class CategoricalProcessor(object): """Maps documents to sequences of word ids. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + As a common convention, Nan values are handled as unknown tokens. Both float('nan') and np.nan are accepted. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data for sequence ' + 'processing.') def __init__(self, min_frequency=0, share=False, vocabularies=None): """Initializes a CategoricalProcessor instance. diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py index 5709955c49fba50ca4a299a443a2902bbd9c6b23..3ac370a6ab4423846e810900514445ad5269b680 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""Categorical vocabulary classes to map categories to indexes. +"""Categorical vocabulary classes to map categories to indexes (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Can be used for categorical variables, sparse variables and words. """ @@ -25,14 +29,21 @@ from __future__ import print_function import collections import six +from tensorflow.python.util.deprecation import deprecated + class CategoricalVocabulary(object): """Categorical variables vocabulary class. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Accumulates and provides mapping from classes to indexes. Can be easily used for words. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, unknown_token="", support_reverse=True): self._unknown_token = unknown_token self._mapping = {unknown_token: 0} diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/text.py b/tensorflow/contrib/learn/python/learn/preprocessing/text.py index 3af2074c2a46f0258c04111fff0235ba8309625e..f2b6776be7789a9433bfe41eb9354b74347059ec 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/text.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/text.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements a number of text preprocessing utilities.""" +"""Implements a number of text preprocessing utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -24,6 +29,7 @@ import numpy as np import six from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated from .categorical_vocabulary import CategoricalVocabulary # pylint: disable=g-bad-import-order @@ -38,6 +44,7 @@ TOKENIZER_RE = re.compile(r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+", re.UNICODE) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def tokenizer(iterator): """Tokenizer generator. @@ -51,9 +58,16 @@ def tokenizer(iterator): yield TOKENIZER_RE.findall(value) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') class ByteProcessor(object): - """Maps documents into sequence of ids for bytes.""" + """Maps documents into sequence of ids for bytes. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length): self.max_document_length = max_document_length @@ -108,8 +122,14 @@ class ByteProcessor(object): class VocabularyProcessor(object): - """Maps documents to sequences of word ids.""" + """Maps documents to sequences of word ids. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length, min_frequency=0, diff --git a/tensorflow/contrib/learn/python/learn/session_run_hook.py b/tensorflow/contrib/learn/python/learn/session_run_hook.py index a8ba2be97206f2b974d256ad2c62c21a4e3e55d8..87edc9b720bdb3edcd5f2dcd1662d14da53c51cf 100644 --- a/tensorflow/contrib/learn/python/learn/session_run_hook.py +++ b/tensorflow/contrib/learn/python/learn/session_run_hook.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This file is deprecated. Use tensorflow.python.training.session_run_hook.""" +"""This file is deprecated. Use `tensorflow.python.training.session_run_hook`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py index 919d415c302b8ec17202aad34ff0bee69bfee2c7..d663cf5fb79c428b0e70d66b0f1305f0559a05c9 100644 --- a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py +++ b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wrapper for a Session-like object that handles threads and recovery. +"""Wrapper for a Session-like object that handles threads and recovery (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. Based on an original design of Illia Polosukhin. """ diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 429b6040be21d8cbe1f2bba58090366552fdfbe7..a1a3f20dcd8cb5ff7baa559ac41d5e5c40780511 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Trainable` interface.""" +"""`Trainable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,8 @@ import abc class Trainable(object): """Interface for objects that are trainable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py index 48978d0ac34cec2b18e6794dcf3b260bc3b683c4..66d8dc6fd43b383919a16515bc96be492a253bf6 100644 --- a/tensorflow/contrib/learn/python/learn/utils/__init__.py +++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Learn Utils.""" +"""TensorFlow Learn Utils (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index cb34cb1d26b6812c7f3f39e9f965615de5a8ef07..3eacac7a3d3dcff4d39025fdee88e16e385b1b84 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -13,14 +13,18 @@ # limitations under the License. # ============================================================================== -"""Export utilities.""" +"""Export utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -32,6 +36,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver +from tensorflow.python.training import training_util @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py index 226915987a4934626066b12810f579ae675107b2..916aecbea88b10bbef316ffb89d4c4d89667cb29 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -r"""System for specifying garbage collection (GC) of path based data. +r"""System for specifying garbage collection (GC) of path based data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This framework allows for GC of data specified by path names, for example files on disk. gc.Path objects each represent a single item stored at a path and may @@ -73,10 +77,12 @@ import os from tensorflow.python.platform import gfile from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated Path = collections.namedtuple('Path', 'path export_version') +@deprecated(None, 'Please implement your own file management or use Saver.') def largest_export_versions(n): """Creates a filter that keeps the largest n export versions. @@ -97,6 +103,7 @@ def largest_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def one_of_every_n_export_versions(n): """Creates a filter that keeps one of every n export versions. @@ -128,6 +135,7 @@ def one_of_every_n_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def mod_export_version(n): """Creates a filter that keeps every export that is a multiple of n. @@ -146,6 +154,7 @@ def mod_export_version(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def union(lf, rf): """Creates a filter that keeps the union of two filters. @@ -163,6 +172,7 @@ def union(lf, rf): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def negation(f): """Negate a filter. @@ -179,6 +189,7 @@ def negation(f): return keep +@deprecated(None, 'Please implement your own file name management.') def get_paths(base_dir, parser): """Gets a list of Paths in a given directory. diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py index b2521933e524e7ec24d73d4b5171f33e507dd88c..b92eb9fea8b7ccea56c781df74dcfa1cc5508e48 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for creating input_fns. +"""Utilities for creating input_fns (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Contents of this file are moved to tensorflow/python/estimator/export.py. InputFnOps is renamed to ServingInputReceiver. @@ -32,13 +36,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.util.deprecation import deprecated class InputFnOps(collections.namedtuple('InputFnOps', ['features', 'labels', 'default_inputs'])): - """A return type for an input_fn. + """A return type for an input_fn (deprecated). + + THIS CLASS IS DEPRECATED. Please use tf.estimator.export.ServingInputReceiver + instead. This return type is currently only supported for serving input_fn. Training and eval input_fn should return a `(features, labels)` tuple. @@ -56,6 +64,8 @@ class InputFnOps(collections.namedtuple('InputFnOps', """ +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_parsing_serving_input_receiver_fn.') def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): """Build an input_fn appropriate for serving, expecting fed tf.Examples. @@ -84,6 +94,8 @@ def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): return input_fn +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_raw_serving_input_receiver_fn.') def build_default_serving_input_fn(features, default_batch_size=None): """Build an input_fn appropriate for serving, expecting feature Tensors. diff --git a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py index 6a63fb545a56e6040b0b0c3bbb6a17cd96925895..6dbaa15f8391b0044be8e30ca191753beb88db93 100644 --- a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py +++ b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A simple script for inspect checkpoint files.""" +"""A simple script for inspect checkpoint files (deprecated).""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 1593380007b2799fb1d17e92408ab19a7b47fe1e..c7cdb4131215c388412407a008113de13bdd0934 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities supporting export to SavedModel. +"""Utilities supporting export to SavedModel (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Some contents of this file are moved to tensorflow/python/estimator/export.py: @@ -52,8 +56,9 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator from tensorflow.python.training import saver - from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated + # A key for use in the input_alternatives dict indicating the default input. # This is the input that will be expected when a serving request does not @@ -77,6 +82,7 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_standardized_signature_def(input_tensors, output_tensors, problem_type): """Build a SignatureDef using problem type and input and output Tensors. @@ -156,6 +162,7 @@ def _is_regression_problem(problem_type, input_tensors, output_tensors): len(input_tensors) == 1 and len(output_tensors) == 1) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_input_alternatives(input_ops): """Obtain all input alternatives using the input_fn output and heuristics.""" input_alternatives = {} @@ -181,6 +188,7 @@ def get_input_alternatives(input_ops): return input_alternatives, features +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): """Obtain all output alternatives using the model_fn output and heuristics. @@ -246,6 +254,7 @@ def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): sorted(output_alternatives.keys()))) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_all_signature_defs(input_alternatives, output_alternatives, actual_default_output_alternative_key): """Build `SignatureDef`s from all pairs of input and output alternatives.""" @@ -279,6 +288,7 @@ def build_all_signature_defs(input_alternatives, output_alternatives, MAX_DIRECTORY_CREATION_ATTEMPTS = 10 +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_timestamped_export_dir(export_dir_base): """Builds a path to a new subdirectory within the base directory. @@ -317,6 +327,7 @@ def get_timestamped_export_dir(export_dir_base): '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_temp_export_dir(timestamped_export_dir): """Builds a directory name based on the argument but starting with 'temp-'. @@ -344,6 +355,7 @@ def _export_version_parser(path): return path._replace(export_version=int(filename)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_most_recent_export(export_dir_base): """Locate the most recent SavedModel export in a directory of many exports. @@ -363,6 +375,7 @@ def get_most_recent_export(export_dir_base): return next(iter(results or []), None) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def garbage_collect_exports(export_dir_base, exports_to_keep): """Deletes older exports, retaining only a given number of the most recent. @@ -387,6 +400,7 @@ def garbage_collect_exports(export_dir_base, exports_to_keep): logging.warn('Can not delete %s recursively: %s', p.path, e) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, @@ -400,7 +414,7 @@ def make_export_strategy(serving_input_fn, `InputFnOps`. default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from ${tf.estimator.Estimator} + Must be `None` if the estimator inherits from @{tf.estimator.Estimator} or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -438,7 +452,7 @@ def make_export_strategy(serving_input_fn, The string path to the exported directory. Raises: - ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + ValueError: If `estimator` is a @{tf.estimator.Estimator} instance and `default_output_alternative_key` was specified. """ if isinstance(estimator, core_estimator.Estimator): @@ -469,6 +483,8 @@ def make_export_strategy(serving_input_fn, return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs) +@deprecated(None, + 'Use tf.estimator.export.build_parsing_serving_input_receiver_fn') def make_parsing_export_strategy(feature_columns, default_output_alternative_key=None, assets_extra=None, @@ -487,7 +503,7 @@ def make_parsing_export_strategy(feature_columns, that must be provided at serving time (excluding labels!). default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from ${tf.estimator.Estimator} + Must be `None` if the estimator inherits from @{tf.estimator.Estimator} or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -555,8 +571,14 @@ def _default_compare_fn(curr_best_eval_result, cand_eval_result): class BestModelSelector(object): - """A helper that keeps track of export selection candidates.""" + """A helper that keeps track of export selection candidates. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def __init__(self, event_file_pattern=None, compare_fn=None): """Constructor of this class. @@ -622,6 +644,7 @@ class BestModelSelector(object): return best_eval_result +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_best_model_export_strategy( serving_input_fn, exports_to_keep=1, @@ -707,6 +730,7 @@ def make_best_model_export_strategy( # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def extend_export_strategy(base_export_strategy, post_export_fn, post_export_name=None): @@ -741,7 +765,7 @@ def extend_export_strategy(base_export_strategy, The string path to the SavedModel indicated by post_export_fn. Raises: - ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + ValueError: If `estimator` is a @{tf.estimator.Estimator} instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. RuntimeError: If unable to create temporary or final export directory. diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 208e7bc69be76680868c766bc99429eea5870c80..359255374d2ea2d35fc4b8a8d72fccc280137979 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -43,6 +43,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "linear_operator_block_diag_test", + size = "medium", + srcs = ["python/kernel_tests/linear_operator_block_diag_test.py"], + additional_deps = [ + ":linalg_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], + shard_count = 4, + tags = ["noasan"], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 4720692c3384ba1bede1f486c1b1e0e69d10a63a..14cc3b2b4971de1a31960ee33c2f304154b1f411 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -17,6 +17,7 @@ See the @{$python/contrib.linalg} guide. @@LinearOperator +@@LinearOperatorBlockDiag @@LinearOperatorDiag @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@ -34,6 +35,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * +from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import * from tensorflow.python.ops.linalg.linear_operator import * from tensorflow.python.ops.linalg.linear_operator_composition import * from tensorflow.python.ops.linalg.linear_operator_diag import * @@ -45,4 +47,5 @@ from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.python.util.all_util import remove_undocumented + remove_undocumented(__name__) diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1a047d6a2b6029080fad3f240aa00f50504f07 --- /dev/null +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py @@ -0,0 +1,253 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.linalg.python.ops import linear_operator_block_diag as block_diag +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib +from tensorflow.python.ops.linalg import linear_operator_test_util +from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.platform import test + +linalg = linalg_lib +random_seed.set_random_seed(23) +rng = np.random.RandomState(0) + + +def _block_diag_dense(expected_shape, blocks): + """Convert a list of blocks, into a dense block diagonal matrix.""" + rows = [] + num_cols = 0 + for block in blocks: + # Get the batch shape for the block. + batch_row_shape = array_ops.shape(block)[:-1] + + zeros_to_pad_before_shape = array_ops.concat( + [batch_row_shape, [num_cols]], axis=-1) + zeros_to_pad_before = array_ops.zeros( + shape=zeros_to_pad_before_shape, dtype=block.dtype) + num_cols += array_ops.shape(block)[-1] + zeros_to_pad_after_shape = array_ops.concat( + [batch_row_shape, [expected_shape[-2] - num_cols]], axis=-1) + zeros_to_pad_after = array_ops.zeros( + zeros_to_pad_after_shape, dtype=block.dtype) + + rows.append(array_ops.concat( + [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1)) + + return array_ops.concat(rows, axis=-2) + + +class SquareLinearOperatorBlockDiagTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + def setUp(self): + # Increase from 1e-6 to 1e-4 + self._atol[dtypes.float32] = 1e-4 + self._atol[dtypes.complex64] = 1e-4 + self._rtol[dtypes.float32] = 1e-4 + self._rtol[dtypes.complex64] = 1e-4 + + @property + def _operator_build_infos(self): + build_info = linear_operator_test_util.OperatorBuildInfo + return [ + build_info((0, 0)), + build_info((1, 1)), + build_info((1, 3, 3)), + build_info((5, 5), blocks=[(2, 2), (3, 3)]), + ] + + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) + expected_blocks = ( + build_info.__dict__["blocks"] if "blocks" in build_info.__dict__ + else [shape]) + matrices = [ + linear_operator_test_util.random_positive_definite_matrix( + block_shape, dtype, force_well_conditioned=True) + for block_shape in expected_blocks + ] + + if use_placeholder: + matrices_ph = [ + array_ops.placeholder(dtype=dtype) for _ in expected_blocks + ] + # Evaluate here because (i) you cannot feed a tensor, and (ii) + # values are random and we want the same value used for both mat and + # feed_dict. + matrices = self.evaluate(matrices) + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + m_ph, is_square=True) for m_ph in matrices_ph], + is_square=True) + feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} + else: + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + m, is_square=True) for m in matrices]) + feed_dict = None + # Should be auto-set. + self.assertTrue(operator.is_square) + + # Broadcast the shapes. + expected_shape = list(build_info.shape) + + matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices) + + block_diag_dense = _block_diag_dense(expected_shape, matrices) + + if not use_placeholder: + block_diag_dense.set_shape( + expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]]) + + return operator, block_diag_dense, feed_dict + + def test_is_x_flags(self): + # Matrix with two positive eigenvalues, 1, and 1. + # The matrix values do not effect auto-setting of the flags. + matrix = [[1., 0.], [1., 1.]] + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix(matrix)], + is_positive_definite=True, + is_non_singular=True, + is_self_adjoint=False) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertFalse(operator.is_self_adjoint) + + def test_is_non_singular_auto_set(self): + # Matrix with two positive eigenvalues, 11 and 8. + # The matrix values do not effect auto-setting of the flags. + matrix = [[11., 0.], [1., 8.]] + operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True) + operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True) + + operator = block_diag.LinearOperatorBlockDiag( + [operator_1, operator_2], + is_positive_definite=False, # No reason it HAS to be False... + is_non_singular=None) + self.assertFalse(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + + with self.assertRaisesRegexp(ValueError, "always non-singular"): + block_diag.LinearOperatorBlockDiag( + [operator_1, operator_2], is_non_singular=False) + + def test_name(self): + matrix = [[11., 0.], [1., 8.]] + operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left") + operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right") + + operator = block_diag.LinearOperatorBlockDiag([operator_1, operator_2]) + + self.assertEqual("left_ds_right", operator.name) + + def test_different_dtypes_raises(self): + operators = [ + linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)), + linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32)) + ] + with self.assertRaisesRegexp(TypeError, "same dtype"): + block_diag.LinearOperatorBlockDiag(operators) + + def test_non_square_operator_raises(self): + operators = [ + linalg.LinearOperatorFullMatrix(rng.rand(3, 4), is_square=False), + linalg.LinearOperatorFullMatrix(rng.rand(3, 3)) + ] + with self.assertRaisesRegexp(ValueError, "square matrices"): + block_diag.LinearOperatorBlockDiag(operators) + + def test_empty_operators_raises(self): + with self.assertRaisesRegexp(ValueError, "non-empty"): + block_diag.LinearOperatorBlockDiag([]) + + +# This test is for blocks with different batch dimensions. +# LinearOperatorFullMatrix doesn't broadcast matmul/solve. +class SquareDiagLinearOperatorBlockDiagTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + def setUp(self): + # Increase from 1e-6 to 1e-4 + self._atol[dtypes.float32] = 1e-4 + self._atol[dtypes.complex64] = 1e-4 + self._rtol[dtypes.float32] = 1e-4 + self._rtol[dtypes.complex64] = 1e-4 + + @property + def _operator_build_infos(self): + build_info = linear_operator_test_util.OperatorBuildInfo + return [ + build_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]), + build_info((2, 1, 6, 6), blocks=[(2, 1, 2, 2), (1, 1, 4, 4)]), + ] + + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) + expected_blocks = ( + build_info.__dict__["blocks"] if "blocks" in build_info.__dict__ + else [shape]) + diag_matrices = [ + linear_operator_test_util.random_uniform( + shape=block_shape[:-1], minval=1., maxval=20., dtype=dtype) + for block_shape in expected_blocks + ] + + if use_placeholder: + diag_matrices_ph = [ + array_ops.placeholder(dtype=dtype) for _ in expected_blocks + ] + diag_matrices = self.evaluate(diag_matrices) + # Evaluate here because (i) you cannot feed a tensor, and (ii) + # values are random and we want the same value used for both mat and + # feed_dict. + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorDiag(m_ph) for m_ph in diag_matrices_ph]) + feed_dict = {m_ph: m for (m_ph, m) in zip( + diag_matrices_ph, diag_matrices)} + else: + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorDiag(m) for m in diag_matrices]) + feed_dict = None + # Should be auto-set. + self.assertTrue(operator.is_square) + + # Broadcast the shapes. + expected_shape = list(build_info.shape) + + matrices = linear_operator_util.broadcast_matrix_batch_dims( + [array_ops.matrix_diag(diag_block) for diag_block in diag_matrices]) + + block_diag_dense = _block_diag_dense(expected_shape, matrices) + if not use_placeholder: + block_diag_dense.set_shape( + expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]]) + + return operator, block_diag_dense, feed_dict + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py new file mode 100644 index 0000000000000000000000000000000000000000..80649bd52da76452e0427f341ff686c26d70a70f --- /dev/null +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py @@ -0,0 +1,371 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Create a Block Diagonal operator from one or more `LinearOperators`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_util + + +class LinearOperatorBlockDiag(linear_operator.LinearOperator): + """Combines one or more `LinearOperators` in to a Block Diagonal matrix. + + This operator combines one or more linear operators `[op1,...,opJ]`, + building a new `LinearOperator`, whose underlying matrix representation is + square and has each operator `opi` on the main diagonal, and zero's elsewhere. + + #### Shape compatibility + + If `opj` acts like a [batch] square matrix `Aj`, then `op_combined` acts like + the [batch] square matrix formed by having each matrix `Aj` on the main + diagonal. + + + Each `opj` is required to represent a square matrix, and hence will have + shape `batch_shape_j + [M_j, M_j]`. + + If `opj` has shape `batch_shape_j + [M_j, M_j]`, then the combined operator + has shape `broadcast_batch_shape + [sum M_j, sum M_j]`, where + `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`, + `j = 1,...,J`, assuming the intermediate batch shapes broadcast. + Even if the combined shape is well defined, the combined operator's + methods may fail due to lack of broadcasting ability in the defining + operators' methods. + + ```python + # Create a 4 x 4 linear operator combined of two 2 x 2 operators. + operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]]) + operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]]) + operator = LinearOperatorBlockDiag([operator_1, operator_2]) + + operator.to_dense() + ==> [[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]] + + operator.shape + ==> [4, 4] + + operator.log_abs_determinant() + ==> scalar Tensor + + x1 = ... # Shape [2, 2] Tensor + x2 = ... # Shape [2, 2] Tensor + x = tf.concat([x1, x2], 0) # Shape [2, 4] Tensor + operator.matmul(x) + ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)]) + + # Create a [2, 3] batch of 4 x 4 linear operators. + matrix_44 = tf.random_normal(shape=[2, 3, 4, 4]) + operator_44 = LinearOperatorFullMatrix(matrix) + + # Create a [1, 3] batch of 5 x 5 linear operators. + matrix_55 = tf.random_normal(shape=[1, 3, 5, 5]) + operator_55 = LinearOperatorFullMatrix(matrix_55) + + # Combine to create a [2, 3] batch of 9 x 9 operators. + operator_99 = LinearOperatorBlockDiag([operator_44, operator_55]) + + # Create a shape [2, 3, 9] vector. + x = tf.random_normal(shape=[2, 3, 9]) + operator_99.matmul(x) + ==> Shape [2, 3, 9] Tensor + ``` + + #### Performance + + The performance of `LinearOperatorBlockDiag` on any operation is equal to + the sum of the individual operators' operations. + + + #### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint, positive_definite, square`. + These have the following meaning: + + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. + """ + + def __init__(self, + operators, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None, + is_square=True, + name=None): + r"""Initialize a `LinearOperatorBlockDiag`. + + `LinearOperatorBlockDiag` is initialized with a list of operators + `[op_1,...,op_J]`. + + Args: + operators: Iterable of `LinearOperator` objects, each with + the same `dtype` and composable shape. + is_non_singular: Expect that this operator is non-singular. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices + is_square: Expect that this operator acts like square [batch] matrices. + This is true by default, and will raise a `ValueError` otherwise. + name: A name for this `LinearOperator`. Default is the individual + operators names joined with `_o_`. + + Raises: + TypeError: If all operators do not have the same `dtype`. + ValueError: If `operators` is empty or are non-square. + """ + # Validate operators. + check_ops.assert_proper_iterable(operators) + operators = list(operators) + if not operators: + raise ValueError( + "Expected a non-empty list of operators. Found: %s" % operators) + self._operators = operators + + # Validate dtype. + dtype = operators[0].dtype + for operator in operators: + if operator.dtype != dtype: + name_type = (str((o.name, o.dtype)) for o in operators) + raise TypeError( + "Expected all operators to have the same dtype. Found %s" + % " ".join(name_type)) + + # Auto-set and check hints. + if all(operator.is_non_singular for operator in operators): + if is_non_singular is False: + raise ValueError( + "The direct sum of non-singular operators is always non-singular.") + is_non_singular = True + + if all(operator.is_self_adjoint for operator in operators): + if is_self_adjoint is False: + raise ValueError( + "The direct sum of self-adjoint operators is always self-adjoint.") + is_self_adjoint = True + + if all(operator.is_positive_definite for operator in operators): + if is_positive_definite is False: + raise ValueError( + "The direct sum of positive definite operators is always " + "positive definite.") + is_positive_definite = True + + if not (is_square and all(operator.is_square for operator in operators)): + raise ValueError( + "Can only represent a block diagonal of square matrices.") + + # Initialization. + graph_parents = [] + for operator in operators: + graph_parents.extend(operator.graph_parents) + + if name is None: + # Using ds to mean direct sum. + name = "_ds_".join(operator.name for operator in operators) + with ops.name_scope(name, values=graph_parents): + super(LinearOperatorBlockDiag, self).__init__( + dtype=dtype, + graph_parents=graph_parents, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True, + name=name) + + @property + def operators(self): + return self._operators + + def _shape(self): + # Get final matrix shape. + domain_dimension = self.operators[0].domain_dimension + range_dimension = self.operators[0].range_dimension + for operator in self.operators[1:]: + domain_dimension += operator.domain_dimension + range_dimension += operator.range_dimension + + matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension]) + + # Get broadcast batch shape. + # broadcast_shape checks for compatibility. + batch_shape = self.operators[0].batch_shape + for operator in self.operators[1:]: + batch_shape = common_shapes.broadcast_shape( + batch_shape, operator.batch_shape) + + return batch_shape.concatenate(matrix_shape) + + def _shape_tensor(self): + # Avoid messy broadcasting if possible. + if self.shape.is_fully_defined(): + return ops.convert_to_tensor( + self.shape.as_list(), dtype=dtypes.int32, name="shape") + + domain_dimension = self.operators[0].domain_dimension_tensor() + range_dimension = self.operators[0].range_dimension_tensor() + for operator in self.operators[1:]: + domain_dimension += operator.domain_dimension_tensor() + range_dimension += operator.range_dimension_tensor() + + matrix_shape = array_ops.stack([domain_dimension, range_dimension]) + + # Dummy Tensor of zeros. Will never be materialized. + zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor()) + for operator in self.operators[1:]: + zeros += array_ops.zeros(shape=operator.batch_shape_tensor()) + batch_shape = array_ops.shape(zeros) + + return array_ops.concat((batch_shape, matrix_shape), 0) + + def _matmul(self, x, adjoint=False, adjoint_arg=False): + split_dim = -1 if adjoint_arg else -2 + # Split input by rows normally, and otherwise columns. + split_x = self._split_input_into_blocks(x, axis=split_dim) + + result_list = [] + for index, operator in enumerate(self.operators): + result_list += [operator.matmul( + split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] + result_list = linear_operator_util.broadcast_matrix_batch_dims( + result_list) + return array_ops.concat(result_list, axis=-2) + + def _determinant(self): + result = self.operators[0].determinant() + for operator in self.operators[1:]: + result *= operator.determinant() + return result + + def _log_abs_determinant(self): + result = self.operators[0].log_abs_determinant() + for operator in self.operators[1:]: + result += operator.log_abs_determinant() + return result + + def _solve(self, rhs, adjoint=False, adjoint_arg=False): + split_dim = -1 if adjoint_arg else -2 + # Split input by rows normally, and otherwise columns. + split_rhs = self._split_input_into_blocks(rhs, axis=split_dim) + + solution_list = [] + for index, operator in enumerate(self.operators): + solution_list += [operator.solve( + split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] + + solution_list = linear_operator_util.broadcast_matrix_batch_dims( + solution_list) + return array_ops.concat(solution_list, axis=-2) + + def _diag_part(self): + diag_list = [] + for operator in self.operators: + # Extend the axis for broadcasting. + diag_list += [operator.diag_part()[..., array_ops.newaxis]] + diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list) + diagonal = array_ops.concat(diag_list, axis=-2) + return array_ops.squeeze(diagonal, axis=-1) + + def _trace(self): + result = self.operators[0].trace() + for operator in self.operators[1:]: + result += operator.trace() + return result + + def _to_dense(self): + num_cols = 0 + rows = [] + broadcasted_blocks = [operator.to_dense() for operator in self.operators] + broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims( + broadcasted_blocks) + for block in broadcasted_blocks: + batch_row_shape = array_ops.shape(block)[:-1] + + zeros_to_pad_before_shape = array_ops.concat( + [batch_row_shape, [num_cols]], axis=-1) + zeros_to_pad_before = array_ops.zeros( + shape=zeros_to_pad_before_shape, dtype=block.dtype) + num_cols += array_ops.shape(block)[-1] + zeros_to_pad_after_shape = array_ops.concat( + [batch_row_shape, + [self.domain_dimension_tensor() - num_cols]], axis=-1) + zeros_to_pad_after = array_ops.zeros( + shape=zeros_to_pad_after_shape, dtype=block.dtype) + + rows.append(array_ops.concat( + [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1)) + + mat = array_ops.concat(rows, axis=-2) + mat.set_shape(self.shape) + return mat + + def _assert_non_singular(self): + return control_flow_ops.group([ + operator.assert_non_singular() for operator in self.operators]) + + def _assert_self_adjoint(self): + return control_flow_ops.group([ + operator.assert_self_adjoint() for operator in self.operators]) + + def _assert_positive_definite(self): + return control_flow_ops.group([ + operator.assert_positive_definite() for operator in self.operators]) + + def _split_input_into_blocks(self, x, axis=-1): + """Split `x` into blocks matching `operators`'s `domain_dimension`. + + Specifically, if we have a block diagonal matrix, with block sizes + `[M_j, M_j] j = 1..J`, this method splits `x` on `axis` into `J` + tensors, whose shape at `axis` is `M_j`. + + Args: + x: `Tensor`. `x` is split into `J` tensors. + axis: Python `Integer` representing the axis to split `x` on. + + Returns: + A list of `Tensor`s. + """ + block_sizes = [] + if self.shape.is_fully_defined(): + for operator in self.operators: + block_sizes += [operator.domain_dimension.value] + else: + for operator in self.operators: + block_sizes += [operator.domain_dimension_tensor()] + + return array_ops.split(x, block_sizes, axis=axis) diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 70f777f08bd5b8157e601f19019075d3e7543811..cfe62fac43b35d863eb559b95057ae62a41bed49 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -270,14 +270,14 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def Minimize(): + def minimize(): with self._single_threaded_test_session(): for _ in range(_MAX_ITERATIONS): - train_op.run() + train_op.run() # pylint: disable=cell-var-from-loop threads = [] for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=Minimize)) + threads.append(threading.Thread(target=minimize)) threads[-1].start() for t in threads: @@ -395,7 +395,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllClose([0, 1, 1, 1], predicted_labels.eval()) self.assertAllClose( - 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + 0.0, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) def testFractionalExampleLabel(self): # Setup test data with 1 positive, and 1 mostly-negative example. @@ -407,7 +407,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): make_example_proto({ 'age': [1], 'gender': [1] - }, 1), + }, 0.9), ] example_weights = [1.0, 1.0] for num_shards in _SHARD_NUMBERS: diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 05794a42c5f2d0eece6adab36fb5610078cece31..d4e54c82f988e0adcd16aad29702ee9f8b16aea3 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -140,8 +140,8 @@ def sdca_model_fn(features, labels, mode, params, config=None): parent_scope = "linear" - with variable_scope.variable_op_scope(features.values(), - parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 44c4a7e2ca8d019ca602c7f2b492cd1e70b17561..18efa64507c95ac7b8d37bd9a8b62c9335b7b5d0 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -132,10 +132,11 @@ cc_library( ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", + ":util", + "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/core:lib_platform", ], ) @@ -169,6 +170,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/testing:util", @@ -232,6 +234,27 @@ cc_test( ], ) +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":context", + ], +) + +cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + deps = [ + ":context", + ":util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the serialization of a model with optional tensors. # Model tests diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 7f316292724ea0baaf034d4e914773ad97a957d4..b4504f246a0f806d35d8c3d659717a86d2f2a4f5 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -27,10 +27,10 @@ LIBDIR := $(MAKEFILE_DIR)/gen/lib/ GENDIR := $(MAKEFILE_DIR)/gen/obj/ # Settings for the host compiler. -CXX := $(CC_PREFIX) gcc +CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG -CC := $(CC_PREFIX) gcc -CFLAGS := +CC := $(CC_PREFIX)gcc +CFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -57,10 +57,11 @@ LIBS := \ # If we're on Linux, also link in the dl library. ifeq ($(HOST_OS),LINUX) - LIBS += -ldl -lpthread + LIBS += -ldl endif include $(MAKEFILE_DIR)/ios_makefile.inc +include $(MAKEFILE_DIR)/rpi_makefile.inc # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 00e93d2c4f3ab27057b855fba6fccf2ec8d7a1c1..c15ae3f233ed6a697e2df7a539e0ba131d4dd1d9 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -91,7 +91,7 @@ Currently, we only support building the Android demo app within a Python 2 environment (due to a Bazel bug). ### More about the demo -The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. +The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (299 * 299 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. # iOS Demo App @@ -99,7 +99,7 @@ Similar to the Android demo app, there's an iOS camera app that uses exactly the This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: -1. Run `third_party/tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. +1. Run `tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. 1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. 1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. 1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. @@ -126,6 +126,9 @@ The above pre-trained models have been trained on the ImageNet data set, which c The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference. +# Getting started with RaspberryPi + +Using RaspberryPi can be accomplished by following the [Makefile instructions](g3doc/rpi.md). That will give a you a static library (.a) that you can build your app against. Python bindings will be coming soon as well as a demo app. ### Train a custom model A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. @@ -165,7 +168,7 @@ bazel-bin/tensorflow/python/tools/freeze_graph\ --input_graph=/tmp/mobilenet_v1_224.pb \ --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ - --output_node_names=MobileNet/Predictions/Reshape_1 + --output_node_names=MobilenetV1/Predictions/Reshape_1 ``` The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index 58bc164619c2c053b9492e9a0e5de2da30e199af..f84b3dad9550e789237c8e45971002c7d336b9d3 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -33,7 +33,7 @@ class AllocationInfo; // each tensor needs to be allocated and deallocated, and preallocates all the // necessary memory (the PlanAllocations phase). It then assigns portions of // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may -// share some of the bufer if a tensor B is to be allocated after another tensor +// share some of the buffer if a tensor B is to be allocated after another tensor // A has been deallocated. // // If dynamic tensors are used the planning steps can be repeated during model diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 19829e4991651111e13fc1805f97daef8bc016a7..2813d1c347163e67c70983d3dd49773f4a4b4544 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -104,7 +104,7 @@ def tflite_jni_binary(name, """Builds a jni binary for TFLite.""" linkopts = linkopts + [ "-Wl,--version-script", # Export only jni functions & classes. - linkscript, + "$(location {})".format(linkscript), ] native.cc_binary( name=name, diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh b/tensorflow/contrib/lite/build_rpi_lib.sh similarity index 69% rename from tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh rename to tensorflow/contrib/lite/build_rpi_lib.sh index 6553ba5e3093c26d3c95f40216cd3922a1fb9e4e..3824b16412ed26a6cab79df3242da6017c3322b0 100755 --- a/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh +++ b/tensorflow/contrib/lite/build_rpi_lib.sh @@ -1,5 +1,5 @@ -#!/bin/bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -GCS_NUMBER=$(cat /dev/urandom | tr -dc 'A-F0-9' | fold -w 8 | head -n 1) -GCS_PATH="$1"/"$GCS_NUMBER".tfrecord -echo "gcs_path=$GCS_PATH" > "$_SETUP_OUTPUT" -touch "$_SETUP_DONE" +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." + +CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7 diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 88cdf1d46312f1e610825f23f3d8d357b0762bac..d7993e60cc77839b823e17ce11f8a57d3e0972db 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -24,14 +24,14 @@ extern "C" { #endif // __cplusplus // The enum for builtin operators. -// Note: CUSTOM and DELEGATE are 2 special ops which are not real biultin -// ops. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. typedef enum { kTfLiteBuiltinAdd = 0, kTfLiteBuiltinAveragePool2d = 1, kTfLiteBuiltinConcatenation = 2, kTfLiteBuiltinConv2d = 3, kTfLiteBuiltinDepthwiseConv2d = 4, + kTfLiteBuiltinDequantize = 6, kTfLiteBuiltinEmbeddingLookup = 7, kTfLiteBuiltinFullyConnected = 9, kTfLiteBuiltinHashtableLookup = 10, @@ -77,6 +77,8 @@ typedef enum { kTfLiteBuiltinLogSoftmax = 50, kTfLiteBuiltinDelegate = 51, kTfLiteBuiltinBidirectionalSequenceLstm = 52, + kTfLiteBuiltinCast = 53, + kTfLiteBuiltinPrelu = 54, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index c09e838c5c2e50e0f4a38eaf66e55246fd9a6f7f..5c6f5e72a47180cd98be46f60cfa8eaf28197806 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -17,9 +17,14 @@ limitations under the License. #include #include +int TfLiteIntArrayGetSizeInBytes(int size) { + static TfLiteIntArray dummy; + return sizeof(dummy) + sizeof(dummy.data[0]) * size; +} + TfLiteIntArray* TfLiteIntArrayCreate(int size) { TfLiteIntArray* ret = - (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size)); ret->size = size; return ret; } @@ -55,12 +60,16 @@ TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } -void TfLiteTensorFree(TfLiteTensor* t) { +void TfLiteTensorDataFree(TfLiteTensor* t) { if (t->allocation_type == kTfLiteDynamic && t->data.raw) { free(t->data.raw); } - if (t->dims) TfLiteIntArrayFree(t->dims); t->data.raw = NULL; +} + +void TfLiteTensorFree(TfLiteTensor* t) { + TfLiteTensorDataFree(t); + if (t->dims) TfLiteIntArrayFree(t->dims); t->dims = NULL; } diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index ed7f4515fa4437d61a37be93616c28a046295c5a..45184b05ecefb504c75815ae900f3b605359a443 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -29,6 +29,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#include #include #include @@ -40,6 +41,7 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; +typedef struct _TfLiteDelegate TfLiteDelegate; #define kOptionalTensor (-1) @@ -57,6 +59,10 @@ typedef struct { #endif } TfLiteIntArray; +// Given the size (number of elements) in a TfLiteIntArray, calculate its size +// in bytes. +int TfLiteIntArrayGetSizeInBytes(int size); + // Create a array of a given `size` (uninitialized entries). // This returns a pointer, that you must free using TfLiteIntArrayFree(). TfLiteIntArray* TfLiteIntArrayCreate(int size); @@ -162,6 +168,11 @@ typedef enum { kTfLiteDynamic, } TfLiteAllocationType; +// The delegates should use zero or positive integers to represent handles. +// -1 is reserved from unallocated status. +typedef int TfLiteBufferHandle; +const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; + // An tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). typedef struct { @@ -194,8 +205,27 @@ typedef struct { // Null-terminated name of this tensor. const char* name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; } TfLiteTensor; +// Free data memory of tensor `t`; +void TfLiteTensorDataFree(TfLiteTensor* t); + // Free memory of tensor `t`; void TfLiteTensorFree(TfLiteTensor* t); @@ -234,6 +264,11 @@ typedef struct { // WARNING: This is an experimental interface that is subject to change. const void* custom_initial_data; int custom_initial_data_size; + + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; } TfLiteNode; typedef struct TfLiteContext { @@ -287,11 +322,16 @@ typedef struct TfLiteContext { // does not take ownership of `nodes_to_replace`. TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( struct TfLiteContext*, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace); + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; // TODO(ahentz): we should create a more general mechanism for this sort of // library-global objects. void* gemm_context; + void* eigen_context; } TfLiteContext; typedef struct _TfLiteRegistration { @@ -338,19 +378,47 @@ typedef struct _TfLiteRegistration { } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. -typedef struct { +typedef struct _TfLiteDelegate { // Data that delegate needs to identify itself. This data is owned by the // delegate. The delegate is owned in the user code, so the delegate is // responsible for doing this when it is destroyed. void* data_; + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the // delegate a view of the current graph through TfLiteContext*. It typically // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() // to ask the TensorFlow lite runtime to create macro-nodes to represent // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(TfLiteContext* context, void* data); + TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + + // Copy the data from delegate buffer handle to raw memory. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, int size); + + // Copy the data from raw memory to delegate buffer handle. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, int size); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle* handle); } TfLiteDelegate; +// WARNING: This is an experimental interface that is subject to change. +typedef struct { + TfLiteDelegate* delegate; + TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; +} TfLiteDelegateParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index da193d2586e9123341b9a41be049ee2a4382017a..3c5f805f12f6a1fb7185c140604f692ac282a143 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -30,7 +30,7 @@ namespace tflite { // va_list args; // foo.Report("test %d", args); // where args is va_list // -// Sublclass ErrorReporter to provide another reporting destination. +// Subclass ErrorReporter to provide another reporting destination. // For example, if you have a GUI program, you might redirect to a buffer // that drives a GUI error log box. class ErrorReporter { diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index a359b8d4b481dbc15cc86db14eabda5433722b8b..e0358a444d6dffc377bf13ee72ba5477359d6e07 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -22,6 +22,15 @@ Then install brew install automake brew install libtool ``` +If you get an error where either automake or libtool install but do not link correctly, you'll first need to: +```bash +sudo chown -R $(whoami) /usr/local/* +``` +Then follow the instructions to perform the linking: +```bash +brew link automake +brew link libtool +``` Then you need to run a shell script to download the dependencies you need: diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index 5b393140d61544e6d6e40d4b6ee1872b22cc84b2..48f43d4fc460a3a5307c5ee1f5e096a409a46af5 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,4 +1,4 @@ -#List of Hosted Models +# List of Hosted Models * [Inception V3 2015](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_2015_2017_11_10.zip) * [Inception V3 Slim 2016](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md new file mode 100644 index 0000000000000000000000000000000000000000..7a3a231626d0e1c71e474ff4ff16789ebe2901db --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -0,0 +1,50 @@ +# TensorFlow Lite for Raspberry Pi + +## Cross compiling +### Installing toolchian +This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/). + +To cross compiling TensorFlow Lite. First you should install the toolchain and libs. +```bash +sudo apt-get update +sudo apt-get install crossbuild-essential-armhf +``` +> If you are using docker, you may not use `sudo` + +### Building +Clone this Tensorflow repository, Run this script at the root of the repository to download all the dependencies: +> The Tensorflow repository is in `/tensorflow` if you are using `tensorflow/tensorflow:nightly-devel` docker image, just try it. +```bash +./tensorflow/contrib/lite/download_dependencies.sh +``` +Note than you only need to to this once. + +You should then be able to compile: +```bash +./tensorflow/contrib/lite/build_rpi_lib.sh +``` + +This should compile a static library in: +`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`. + +## Native compiling +This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1). + +Log in to you RPI, install the toolchain. +```bash +sudo apt-get instal build-essential +``` + +First, clone this TensorFlow repository. Run this at the root of the repository: +```bash +./tensorflow/contrib/lite/download_dependencies.sh +``` +Note than you only need to to this once. + +You should then be able to compile: +```bash +./tensorflow/contrib/lite/build_rpi_lib.sh +``` + +This should compile a static library in: +`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`. diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index b1bbb7c67013acfb575cc1e9f9390ba191cbd08e..61ea5231e352f5e014f9200eccae69548574c034 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -30,13 +30,18 @@ quantized training is necessary before conversion. ## Data Format and Broadcasting At the moment TensorFlow Lite supports only TensorFlow's "NHWC" format, and -broadcasting in operations like tf.add and tf.mul is generally not supported. +broadcasting is only support in a limited number of ops (tf.add, tf.mul, tf.sub, +and tf.div). ## Compatible Operations The following TensorFlow operations are usually mapped to their TensorFlow Lite counterparts: +* [tf.batch_to_space_nd](https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd) - + *as long as the input tensor is 4D (1 batch + 2 spatial + 1 other) and the + crops attribute is not used* +* [tf.exp](https://www.tensorflow.org/api_docs/python/tf/exp) * [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long as the second argument is constant and transposition is not used* * [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool) @@ -47,12 +52,30 @@ counterparts: * [tf.nn.l2_normalize](https://www.tensorflow.org/api_docs/python/tf/nn/l2_normalize) - *as long as normalization is done along the last dimension* * [tf.nn.local_response_normalization](https://www.tensorflow.org/api_docs/python/tf/nn/local_response_normalization) +* [tf.nn.log_softmax](https://www.tensorflow.org/api_docs/python/tf/nn/log_softmax) - + *as long as axis is not provided* * [tf.nn.max_pool](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool) * [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) - *as long as tensors are 2D and axis is the last dimension* +* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k) +* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as + mode and constant_values are not used* +* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) - + *as long as the reduction_indices attribute is not used* * [tf.reshape](https://www.tensorflow.org/api_docs/python/tf/reshape) * [tf.sigmoid](https://www.tensorflow.org/api_docs/python/tf/sigmoid) +* [tf.space_to_batch_nd](https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd) - + *as long as the input tensor is 4D (1 batch + 2 spatial + 1 other)* * [tf.space_to_depth](https://www.tensorflow.org/api_docs/python/tf/space_to_depth) +* [tf.split](https://www.tensorflow.org/api_docs/python/tf/split) - *as long + as num is not provided and num_or_size_split contains number of splits as a + 0D tensor* +* [tf.squeeze](https://www.tensorflow.org/api_docs/python/tf/squeeze) - *as + long as axis is not provided* +* [tf.strided_slice](https://www.tensorflow.org/api_docs/python/tf/strided_slice) - + *as long as ellipsis_mask and new_axis_mask are not used* +* [tf.transpose](https://www.tensorflow.org/versions/master/api_docs/python/tf/transpose) - + *as long as conjugate is not used* ## Straightforward Conversions, Constant-Folding and Fusing @@ -91,7 +114,6 @@ Here is a list of TensorFlow operations that are usually removed from the graph: * [tf.shape](https://www.tensorflow.org/api_docs/python/tf/shape) * [tf.sqrt](https://www.tensorflow.org/api_docs/python/tf/sqrt) * [tf.square](https://www.tensorflow.org/api_docs/python/tf/square) -* [tf.squeeze](https://www.tensorflow.org/api_docs/python/tf/squeeze) * [tf.subtract](https://www.tensorflow.org/api_docs/python/tf/subtract) * [tf.tile](https://www.tensorflow.org/api_docs/python/tf/tile) * [tf.nn.batch_norm_with_global_normalization](https://www.tensorflow.org/api_docs/python/tf/nn/batch_norm_with_global_normalization) @@ -109,17 +131,11 @@ fused. TensorFlow operation not listed above are likely unsupported. Notably, the following common ops are not supported at the moment: -* [tf.batch_to_space_nd](https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd) * [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space) * [tf.floor](https://www.tensorflow.org/api_docs/python/tf/floor) * [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) * [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear) -* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) -* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) * [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice) -* [tf.space_to_batch_nd](https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd) -* [tf.split](https://www.tensorflow.org/api_docs/python/tf/split) -* [tf.strided_slice](https://www.tensorflow.org/api_docs/python/tf/strided_slice) * [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh) ## TensorFlow Lite Operations @@ -160,6 +176,20 @@ Options { } ``` +**BATCH_TO_SPACE_ND** + +``` +Inputs { + 0: 4D tensor + 1: 1D tensor + 2: 2D tensor +} +Outputs { + 0: tensor rearranged using block_shape. See tf.batch_to_space_nd for + details. +} +``` + **CONCATENATION** ``` @@ -213,6 +243,17 @@ Options { } ``` +**EXP** + +``` +Inputs { + 0: tensor +} +Outputs { + 0: result of computing element-wise exponential of the input tensor +} +``` + **FULLY_CONNECTED** ``` @@ -289,6 +330,17 @@ Outputs { } ``` +**LOG_SOFTMAX** + +``` +Inputs { + 0: tensor +} +Outputs { + 0: tensor equivalent to logits - log(reduce_sum(exp(logits), -1)) +} +``` + **MAX_POOL_2D** ``` @@ -322,6 +374,34 @@ Options { } ``` +**PAD** + +``` +Inputs { + 0: tensor + 1: tensor +} +Outputs { + 0: tensor where additional values are added before and after the contents of + each dimension +} +``` + +**MEAN (tf.reduce_mean)** + +``` +Inputs { + 0: tensor + 1: tensor +} +Outputs { + 0: tensor containing the mean of the elements +} +Options { + keep_dims: whether to retain reduced dimensions +} +``` + **RELU** ``` @@ -399,6 +479,93 @@ Options { } ``` +**SPACE_TO_BATCH_ND** + +``` +Inputs { + 0: 4D tensor + 1: 1D tensor + 2: 2D tensor +} +Outputs { + 0: a tensor rearranged using block_shape. See tf.space_to_batch_nd for + details. +} +``` + +**SPLIT** + +``` +Inputs { + 0: 0D tensor (axis) + 1: tensor (input) +} +Outputs { + 0-N: subtensors built from the input tensors +} +Options { + num_splits: Specifies number of outputs +} +``` + +**SQUEEZE** + +``` +Inputs { + 0: tensor +} +Outputs { + 0: tensor without any dimensions of size 1 +} +Options { + squeeze_dims +} +``` + +**STRIDED_SLICE** + +``` +Inputs { + 0: tensor + 1: 1D tensor + 2: 1D tensor + 3: 1D tensor +} +Outputs { + 0: slice of the input tensor of the given size +} +Options { + begin_mask: mask for begin indicies + end_mask: mask for end indices + shrink_axis_mask: mask that indicates which dimensions to remove +} +``` + +**TOP_K** + +``` +Inputs { + 0: tensor + 1: OD tensor +} +Outputs { + 0: k largest element along each last dimensional slice + 1: indicies of values within the last dimension of the input ensor +} +``` + +**TRANSPOSE** + +``` +Inputs { + 0: tensor + 1: tensor +} +Outputs { + 0: tensor permuted according to perm +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 370e4955271679975072d458e0ad9837a69d9556..4575fe884dc07963df5f0a26c5fe6680d92e409c 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -22,19 +22,35 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { namespace { -// std::vector preallocation tuning. -constexpr const int kSlotsToReserve = 128; +// Stub method which returns kTfLiteError when the function is forbidden. +// We're registrating this function to several different function to save +// compiled binary size. Please note the restrictions: +// * The type of first parameter have to be `TfLiteContext*`. +// * All paramteters must be trivailly destructible. (E.g. No C++ class) +TfLiteStatus ForbiddenContextFunction(TfLiteContext* context, ...) { + context->ReportError(context, + "The function is forbidden if not calling in delegate."); + return kTfLiteError; +} -} // namespace +// Set the ForbiddenContextFunction to a compatible function pointer. +template +void SetForbiddenContextFunction(FunctionType* func) { + *func = reinterpret_cast(ForbiddenContextFunction); +} -namespace tflite { +} // namespace // A trivial implementation of GraphInfo around the Interpreter. // NOTE: this interpreter info represents the subset of the @@ -77,16 +93,18 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.AddTensors = AddTensors; context_.tensors = nullptr; context_.tensors_size = 0; + context_.eigen_context = nullptr; context_.gemm_context = nullptr; + context_.recommended_num_threads = -1; // Invalid to call these these except from TfLiteDelegate - context_.GetNodeAndRegistration = nullptr; - context_.ReplaceSubgraphsWithDelegateKernels = nullptr; - context_.GetExecutionPlan = nullptr; + SetForbiddenContextFunction(&context_.GetNodeAndRegistration); + SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); + SetForbiddenContextFunction(&context_.GetExecutionPlan); // Reserve some space for the tensors to avoid excessive resizing. - tensors_.reserve(kSlotsToReserve); - nodes_and_registration_.reserve(kSlotsToReserve); + tensors_.reserve(kTensorsReservedCapacity); + nodes_and_registration_.reserve(kTensorsReservedCapacity); next_execution_plan_index_to_prepare_ = 0; UseNNAPI(false); } @@ -103,19 +121,99 @@ Interpreter::~Interpreter() { } for (int i = 0; i < context_.tensors_size; i++) { - TfLiteTensorFree(&context_.tensors[i]); + TfLiteTensor* tensor = &context_.tensors[i]; + if (tensor->buffer_handle != kTfLiteNullBufferHandle) { + tensor->delegate->FreeBufferHandle(tensor->delegate, + &tensor->buffer_handle); + } + TfLiteTensorFree(tensor); } } TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( TfLiteContext* context, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace) { + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) { return static_cast(context->impl_) - ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace); + ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace, + delegate); } +namespace { + +// Copy a std::vector to an existing TfLiteIntArray. +// This is a low-level data manipulation function, and it's caller's +// responsibility to ensure TfLiteIntArray has enough size. +void CopyVectorToTfLiteIntArray(const std::vector& vec, + TfLiteIntArray* arr) { + arr->size = vec.size(); + memcpy(arr->data, vec.data(), sizeof(int) * arr->size); +} + +// This function allocates a continuous memory space that contains a +// TfLiteDelegateParams followed by a several TfLiteIntArray. +// When calling `free` at TfLiteDelegateParams*, all the allocated space +// will be freed together. +// +// +-----------------------------------+ +// | TfLiteDelegateParams | +// | TfLiteDelegate* delegate; | +// | TfLiteIntArray* nodes_to_replace; |--\ +// | TfLiteIntArray* input_tensors; |--+--\ +// | TfLiteIntArray* output_tensors; |--+--+--\ +// +-----------------------------------+ | | | +// | TfLiteIntArray (variable size) |<-/ | | +// +-----------------------------------+ | | +// | TfLiteIntArray (variable size) |<----/ | +// +-----------------------------------+ | +// | TfLiteIntArray (variable size) |<-------/ +// +-----------------------------------+ +TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate, + const Subgraph& subgraph) { + // Step 1: Calculate the allocation size. + int allocation_size = sizeof(TfLiteDelegateParams); + + int nodes_to_replace_size = + TfLiteIntArrayGetSizeInBytes(subgraph.nodes.size()); + allocation_size += nodes_to_replace_size; + + int input_tensors_size = + TfLiteIntArrayGetSizeInBytes(subgraph.input_tensors.size()); + allocation_size += input_tensors_size; + + int output_tensors_size = + TfLiteIntArrayGetSizeInBytes(subgraph.output_tensors.size()); + allocation_size += output_tensors_size; + + // Step 2: Allocate the memory. + // Use `char*` for conveniently step through the allocated space by bytes. + char* allocation = reinterpret_cast(malloc(allocation_size)); + + // Step 3: Fill all data structures structures. + TfLiteDelegateParams* params = + reinterpret_cast(allocation); + params->delegate = delegate; + allocation += sizeof(TfLiteDelegateParams); + + params->nodes_to_replace = reinterpret_cast(allocation); + CopyVectorToTfLiteIntArray(subgraph.nodes, params->nodes_to_replace); + allocation += nodes_to_replace_size; + + params->input_tensors = reinterpret_cast(allocation); + CopyVectorToTfLiteIntArray(subgraph.input_tensors, params->input_tensors); + allocation += input_tensors_size; + + params->output_tensors = reinterpret_cast(allocation); + CopyVectorToTfLiteIntArray(subgraph.output_tensors, params->output_tensors); + allocation += output_tensors_size; + + return params; +} + +} // namespace + TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( - TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) { + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegate* delegate) { // Annotate the registration as DELEGATE op. registration.builtin_code = BuiltinOperator_DELEGATE; @@ -127,30 +225,37 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( execution_plan_.clear(); for (auto& subgraph : subgraphs) { - // Turn subgraph.nodes into a TfLiteIntArray compatible data structure. - // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way - // in the first place - subgraph.nodes.insert(subgraph.nodes.begin(), - static_cast(subgraph.nodes.size())); // Subgraphs calimed by the delegate should have a "macro" op created, the // other subgraphs (kTfNonPartition) just have their nodes added back to // the execution plan. switch (subgraph.type) { case Subgraph::kTfNonPartition: - for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end(); + for (auto it = subgraph.nodes.begin(); it != subgraph.nodes.end(); ++it) { execution_plan_.push_back(*it); } break; case Subgraph::kTfPartition: { - void* builtin_data = nullptr; int node_index; - // Create a node that represents computation of this subgraph. - AddNodeWithParameters( - subgraph.input_tensors, subgraph.output_tensors, - reinterpret_cast(subgraph.nodes.data()), - subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data, - ®istration, &node_index); + + TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); + AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, + nullptr, 0, params, ®istration, &node_index); + + // Initialize the output tensors's delegate-related fields. + for (int tensor_index : subgraph.output_tensors) { + TfLiteTensor* tensor = &tensors_[tensor_index]; + TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr); + TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle, + kTfLiteNullBufferHandle); + // buffer_handle will be filled in delegate's `Prepare` + // function. + tensor->delegate = delegate; + } + + // Associate the node with the delegate. + TfLiteNode* node = &nodes_and_registration_[node_index].first; + node->delegate = delegate; } break; case Subgraph::kTfUnexplored: return kTfLiteError; @@ -169,8 +274,8 @@ TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) { *execution_plan = plan_cache_.get(); static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]), "TfLiteIntArray and execution_plan do not contain same type."); - memcpy(plan_cache_->data, execution_plan_.data(), - sizeof(plan_cache_->data[0]) * execution_plan_.size()); + std::memcpy(plan_cache_->data, execution_plan_.data(), + sizeof(plan_cache_->data[0]) * execution_plan_.size()); return kTfLiteOk; } @@ -240,14 +345,6 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, return kTfLiteOk; } -namespace { -TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { - TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); - for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; - return lite; -} -} // namespace - TfLiteStatus Interpreter::AllocateTensors() { next_execution_plan_index_to_prepare_ = 0; if (memory_planner_) { @@ -260,7 +357,11 @@ TfLiteStatus Interpreter::AllocateTensors() { } TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); - invokable_ = true; + if (state_ == kStateUninvokable) { + state_ = kStateInvokable; + } + TF_LITE_ENSURE(&context_, state_ == kStateInvokable || + state_ == kStateInvokableAndImmutable); return kTfLiteOk; } @@ -268,7 +369,12 @@ TfLiteStatus Interpreter::AddNodeWithParameters( const std::vector& inputs, const std::vector& outputs, const char* init_data, size_t init_data_size, void* builtin_data, const TfLiteRegistration* registration, int* node_index) { - invokable_ = false; + if (state_ == kStateInvokableAndImmutable) { + ReportError(&context_, + "AddNodeWithParameters is disallowed when graph is immutable."); + return kTfLiteError; + } + state_ = kStateUninvokable; std::unique_ptr builtin_data_deleter(builtin_data, free); @@ -282,7 +388,6 @@ TfLiteStatus Interpreter::AddNodeWithParameters( int new_node_index = nodes_and_registration_.size(); if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); - auto& node_and_reg = nodes_and_registration_.back(); TfLiteNode& node = node_and_reg.first; if (node.inputs) TfLiteIntArrayFree(node.inputs); @@ -292,8 +397,8 @@ TfLiteStatus Interpreter::AddNodeWithParameters( // NOTE, here we are not using move semantics yet, since our internal // representation isn't std::vector, but in the future we would like to avoid // copies, so we want the interface to take r-value references now. - node.inputs = convertVectorToTfLiteIntArray(inputs); - node.outputs = convertVectorToTfLiteIntArray(outputs); + node.inputs = ConvertVectorToTfLiteIntArray(inputs); + node.outputs = ConvertVectorToTfLiteIntArray(outputs); node.temporaries = TfLiteIntArrayCreate(0); if (init_data) { node.user_data = OpInit(*registration, init_data, init_data_size); @@ -306,6 +411,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( node.builtin_data = builtin_data_deleter.release(); // TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size` // properly for nodes generated by ReplaceSubgraphsWithDelegateKernels. + if (registration->builtin_code == BuiltinOperator_CUSTOM) { // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer // `Operator` table is passed in. @@ -316,6 +422,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( node.custom_initial_data_size = 0; } + node.delegate = nullptr; node_and_reg.second = *registration; execution_plan_.push_back(new_node_index); return kTfLiteOk; @@ -323,13 +430,18 @@ TfLiteStatus Interpreter::AddNodeWithParameters( TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, const std::vector& dims) { + if (state_ == kStateInvokableAndImmutable) { + ReportError(&context_, + "ResizeInputTensor is disallowed when graph is immutable."); + return kTfLiteError; + } + state_ = kStateUninvokable; + // TODO(aselle): All bounds checks can be implemented as one-sided bounds // checks by casting to unsigned for efficiency. Profile before doing this. - TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); - invokable_ = false; - TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims); + TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); } @@ -353,6 +465,7 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; + EnsureTensorsVectorCapacity(); if (OpPrepare(registration, &node) == kTfLiteError) { return kTfLiteError; } @@ -392,7 +505,7 @@ TfLiteStatus Interpreter::Invoke() { ReportError(&context_, "Invoke called on model that is not consistent."); return kTfLiteError; } - if (!invokable_) { + if (state_ == kStateUninvokable) { ReportError(&context_, "Invoke called on model that is not ready."); return kTfLiteError; } @@ -430,10 +543,29 @@ TfLiteStatus Interpreter::Invoke() { TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; + + // TODO(ycling): This is an extra loop through inputs to check if the data + // need to be copied from Delegate buffer to raw memory, which is often not + // needed. We may want to cache this in prepare to know if this needs to be + // done for a node or not. + for (int i = 0; i < node.inputs->size; ++i) { + int tensor_index = node.inputs->data[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor* tensor = &tensors_[tensor_index]; + if (tensor->delegate && tensor->delegate != node.delegate && + tensor->data_is_stale) { + EnsureTensorDataIsReadable(tensor_index); + } + } + + EnsureTensorsVectorCapacity(); if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } } + return status; } @@ -469,6 +601,7 @@ TfLiteStatus Interpreter::AddTensors(int tensors_to_add, tensors_.resize(tensors_.size() + tensors_to_add); for (int i = base_index; i < tensors_.size(); i++) { memset(&tensors_[i], 0, sizeof(tensors_[i])); + tensors_[i].buffer_handle = kTfLiteNullBufferHandle; } context_.tensors = tensors_.data(); context_.tensors_size = tensors_.size(); @@ -501,9 +634,16 @@ TfLiteStatus Interpreter::GetNodeAndRegistration( } TfLiteStatus Interpreter::SetTensorParametersReadOnly( - int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization, - const char* buffer, size_t bytes, const Allocation* allocation) { + int tensor_index, TfLiteType type, const char* name, const int rank, + const int* dims, TfLiteQuantizationParams quantization, const char* buffer, + size_t bytes, const Allocation* allocation) { + if (state_ == kStateInvokableAndImmutable) { + ReportError( + &context_, + "SetTensorParametersReadOnly is disallowed when graph is immutable."); + return kTfLiteError; + } + TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); // For most tensors we know exactly how much memory is necessary so we can @@ -511,14 +651,27 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // because their sizes change with the contents of the individual strings. if (type != kTfLiteString) { size_t required_bytes; - TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), - &required_bytes)); + TF_LITE_ENSURE_OK(&context_, + BytesRequired(type, dims, rank, &required_bytes)); TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); } - invokable_ = false; - TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), - quantization, const_cast(buffer), bytes, - kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]); + + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (type == tensor.type && + EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) { + // Fast path which does not invalidate the invokable property. + TfLiteTensorDataFree(&tensor); + tensor.data.raw = const_cast(buffer); + if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims); + tensor.params = quantization; + tensor.allocation_type = kTfLiteMmapRo; + tensor.allocation = allocation; + } else { + state_ = kStateUninvokable; + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), + quantization, const_cast(buffer), bytes, + kTfLiteMmapRo, allocation, &tensor); + } return kTfLiteOk; } @@ -527,9 +680,14 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( - int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization) { - invokable_ = false; + int tensor_index, TfLiteType type, const char* name, const int rank, + const int* dims, TfLiteQuantizationParams quantization) { + if (state_ == kStateInvokableAndImmutable) { + ReportError( + &context_, + "SetTensorParametersReadWrite is disallowed when graph is immutable."); + return kTfLiteError; + } TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); size_t required_bytes = 0; @@ -538,10 +696,10 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( // many bytes we will need based on the dimensions. String tensors are // allocated dynamically and we can't know ahead of time how much space // they will require. - TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), - &required_bytes)); + TF_LITE_ENSURE_OK(&context_, + BytesRequired(type, dims, rank, &required_bytes)); } - TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, /*buffer=*/nullptr, required_bytes, type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, @@ -604,26 +762,95 @@ void Interpreter::UseNNAPI(bool enable) { } void Interpreter::SetNumThreads(int num_threads) { - // TODO(ahentz): this forces us to link against gemmlowp even when the ops - // don't use it. We should implement some dynamic mechanism for this sort of - // library-specific initialization. - tflite::gemm_support::SetMaxNumThreads(&context_, num_threads); -} + context_.recommended_num_threads = num_threads; + + // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to + // be required in order to compile the framework. + gemm_support::SetNumThreads(&context_, num_threads); + eigen_support::SetNumThreads(&context_, num_threads); +} + +TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, + bool allow_dynamic_tensors) { + if (!allow_dynamic_tensors) { + int last_execution_plan_index_prepared; + TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt( + 0, &last_execution_plan_index_prepared)); + + bool has_dynamic_tensors = true; + // Dynamic tensors exist if not all nodes can be prepared. + if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) { + // If all the nodes can be prepared, check if the last node has dynamic + // tensors. + int node_index = execution_plan_[last_execution_plan_index_prepared]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + if (!HasDynamicTensor(context_, node.outputs)) { + has_dynamic_tensors = false; + } + } + if (has_dynamic_tensors) { + ReportError(&context_, "Attempting to resize a fixed-size tensor."); + return kTfLiteError; + } + } -TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { // TODO(aselle): Consider if it is worth storing pointers to delegates. - // Setup additional context interface + // Setup additional context interface. context_.GetNodeAndRegistration = GetNodeAndRegistration; context_.ReplaceSubgraphsWithDelegateKernels = ReplaceSubgraphsWithDelegateKernels; context_.GetExecutionPlan = GetExecutionPlan; - TfLiteStatus status = delegate->Prepare(&context_, delegate->data_); + TfLiteStatus status = delegate->Prepare(&context_, delegate); + // Remove additional context info. - context_.GetNodeAndRegistration = nullptr; - context_.ReplaceSubgraphsWithDelegateKernels = nullptr; - context_.GetExecutionPlan = nullptr; + SetForbiddenContextFunction(&context_.GetNodeAndRegistration); + SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); + SetForbiddenContextFunction(&context_.GetExecutionPlan); + + TF_LITE_ENSURE_OK(&context_, status); + + if (!allow_dynamic_tensors) { + TF_LITE_ENSURE_OK(&context_, AllocateTensors()); + TF_LITE_ENSURE(&context_, state_ == kStateInvokable || + state_ == kStateInvokableAndImmutable); + // After using a delegate which doesn't support dynamic tensors, make the + // entire graph immutable. + state_ = kStateInvokableAndImmutable; + } + return status; } +TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, + TfLiteBufferHandle buffer_handle, + TfLiteDelegate* delegate) { + TF_LITE_ENSURE(&context_, tensor_index < tensors_size()); + TfLiteTensor* tensor = &tensors_[tensor_index]; + + TF_LITE_ENSURE(&context_, + tensor->delegate == nullptr || tensor->delegate == delegate); + tensor->delegate = delegate; + if (tensor->buffer_handle != kTfLiteNullBufferHandle) { + TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr); + tensor->delegate->FreeBufferHandle(tensor->delegate, + &tensor->buffer_handle); + } + tensor->buffer_handle = buffer_handle; + + return kTfLiteOk; +} + +TfLiteStatus Interpreter::GetBufferHandle(int tensor_index, + TfLiteBufferHandle* buffer_handle, + TfLiteDelegate** delegate) { + TF_LITE_ENSURE(&context_, tensor_index < tensors_size()); + TfLiteTensor* tensor = &tensors_[tensor_index]; + + *delegate = tensor->delegate; + *buffer_handle = tensor->buffer_handle; + + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index a9df2627e02486f71e9b5b7b0d1bfd89c7ec70c0..77db17878318276c6cf5067274a3af3be262c8e1 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/memory_planner.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" namespace tflite { @@ -134,18 +133,34 @@ class Interpreter { // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. - TfLiteStatus SetTensorParametersReadOnly( + inline TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, + const Allocation* allocation = nullptr) { + return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), + dims.data(), quantization, buffer, bytes, + allocation); + }; + + TfLiteStatus SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, const int rank, + const int* dims, TfLiteQuantizationParams quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr); // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. - TfLiteStatus SetTensorParametersReadWrite( + inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization); + const std::vector& dims, TfLiteQuantizationParams quantization) { + return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), + dims.data(), quantization); + } + TfLiteStatus SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, const int rank, + const int* dims, TfLiteQuantizationParams quantization); // Functions to access tensor data @@ -257,13 +272,57 @@ class Interpreter { // Allow a delegate to look at the graph and modify the graph to handle // parts of the graph themselves. After this is called, the graph may // contain new nodes that replace 1 more nodes. - TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); + // WARNING: This is an experimental API and subject to change. + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate, + bool allow_dynamic_tensors = false); + + // Ensure the data in `tensor.data` is readable. In case delegate is used, + // it might require to copy the data from delegate buffer to raw memory. + TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) { + TF_LITE_ENSURE(&context_, tensor_index < tensors_size()); + TfLiteTensor* tensor = &tensors_[tensor_index]; + if (tensor->data_is_stale) { + TF_LITE_ENSURE(&context_, tensor->delegate != nullptr); + TF_LITE_ENSURE(&context_, + tensor->buffer_handle != kTfLiteNullBufferHandle); + // This can be null if the delegate doesn't use its own buffer. + TF_LITE_ENSURE(&context_, + tensor->delegate->CopyFromBufferHandle != nullptr); + tensor->delegate->CopyFromBufferHandle(tensor->delegate, + tensor->buffer_handle, + tensor->data.raw, tensor->bytes); + tensor->data_is_stale = false; + } + return kTfLiteOk; + } - // WARNING: This is a deprecated interface and will be removed as soon as - // possible. Please do not use it. - // TODO(impjdi): Remove this interface after resolving dependencies. - void set_model(const Model* model) { model_ = const_cast(model); } - Model* model() const { return model_; } + // Set the delegate buffer handle to a tensor. It can be called in the + // following cases: + // 1. Set the buffer handle to a tensor that's not being written by a + // delegate. For example, feeding an OpenGL texture as the input of the + // inference graph. + // 2. Set the buffer handle to a tensor that uses the same delegate. + // For example, set an OpenGL texture as the output of inference, while + // the node which produces output is an OpenGL delegate node. + // WARNING: This is an experimental API and subject to change. + TfLiteStatus SetBufferHandle(int tensor_index, + TfLiteBufferHandle buffer_handle, + TfLiteDelegate* delegate); + + // Get the delegate buffer handle, and the delegate which can process the + // buffer handle. + // WARNING: This is an experimental API and subject to change. + TfLiteStatus GetBufferHandle(int tensor_index, + TfLiteBufferHandle* buffer_handle, + TfLiteDelegate** delegate); + + // The default capacity of `tensors_` vector. + static constexpr int kTensorsReservedCapacity = 128; + // The capacity headroom of `tensors_` vector before calling ops' + // `prepare` and `invoke` function. In these functions, it's guaranteed + // allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate + // pointers to existing tensors. + static constexpr int kTensorsCapacityHeadroom = 16; private: // Give 'op_reg' a chance to initialize itself using the contents of @@ -347,14 +406,15 @@ class Interpreter { // Entry point for C API ReplaceSubgraphsWithDelegateKernels static TfLiteStatus ReplaceSubgraphsWithDelegateKernels( TfLiteContext* context, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace); + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); // Update the execution graph to replace some of the nodes with stub // nodes. Specifically any node index that has `nodes[index]==1` will be // slated for replacement with a delegate kernel specified by registration. // WARNING: This is an experimental interface that is subject to change. TfLiteStatus ReplaceSubgraphsWithDelegateKernels( - TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace); + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegate* delegate); // WARNING: This is an experimental interface that is subject to change. // Gets the internal pointer to a TensorFlow lite node by node_index. @@ -377,6 +437,32 @@ class Interpreter { static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra + // capacity. Calling this function may invalidate existing pointers to + // tensors. After calling this function, adding `kTensorsCapacityHeadroom` + // more tensors won't invalidate the pointer to existing tensors. + void EnsureTensorsVectorCapacity() { + const int required_capacity = tensors_size() + kTensorsCapacityHeadroom; + if (required_capacity > tensors_.capacity()) { + tensors_.reserve(required_capacity); + context_.tensors = tensors_.data(); + } + } + + // The state of the Interpreter. + enum State { + // The interpreter isn't ready to be invoked. + // `AllocateTensor` need to be called to enter an invokable state. + kStateUninvokable = 0, + // The interpreter is ready to be invoked. + kStateInvokable, + // The interpreter is ready to be invoked, and graph can't be further + // modified. The interpreter will enter this state when calling + // `ModifyGraphWithDelegate` with `allow_dynamic_tensors=false`. + kStateInvokableAndImmutable, + }; + State state_ = kStateUninvokable; + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -392,10 +478,6 @@ class Interpreter { // the tensor array. bool consistent_ = true; - // Whether the model is safe to invoke (if any errors occurred this - // will be false). - bool invokable_ = false; - // Array of indices representing the tensors that are inputs to the // interpreter. std::vector inputs_; @@ -411,7 +493,7 @@ class Interpreter { // During Invoke(), Interpreter will allocate input tensors first, which are // known to be fixed size. Then it will allocate outputs from nodes as many // as possible. When there is a node that produces dynamic sized tensor. - // Intepreter will stop allocating tensors, set the value of next allocate + // Interpreter will stop allocating tensors, set the value of next allocate // node id, and execute the node to generate the output tensor before continue // to allocate successors. This process repeats until all nodes are executed. // NOTE: this relies on the order of nodes that is in topological order. @@ -432,11 +514,6 @@ class Interpreter { std::unique_ptr nnapi_delegate_; std::unique_ptr memory_planner_; - - // WARNING: This is a deprecated interface and will be removed as soon as - // possible. Please do not use it. - // TODO(impjdi): Remove this interface after resolving dependencies. - Model* model_ = nullptr; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 28c96e5dde6ffa62bb073db9716a00f91c6e0bdf..131e088079857af34478645b7f1559364d03a493 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -17,9 +17,11 @@ limitations under the License. #include #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/testing/util.h" + namespace tflite { namespace { @@ -40,7 +42,7 @@ TEST(BasicInterpreter, InvokeInvalidModel) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } -// Test size accesser functions. +// Test size accessor functions. TEST(BasicInterpreter, TestSizeFunctions) { Interpreter interpreter; int base_index; @@ -439,12 +441,12 @@ TEST(BasicInterpreter, ThreeStepAllocate) { // String-in String-out node. TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr}; reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; - TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; DynamicBuffer buf; - StringRef str_ref = GetString(a0, 0); + StringRef str_ref = GetString(input, 0); buf.AddString(str_ref); - buf.WriteToTensor(a1); + buf.WriteToTensor(output); return kTfLiteOk; }; @@ -561,6 +563,86 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { ASSERT_EQ(reporter.calls, 1); } +TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + TfLiteRegistration registration = { + .init = nullptr, .free = nullptr, .prepare = nullptr, .invoke = nullptr}; + // These functions are only supported inside Delegate's Prepare function. + // The test verifies that these functions returns `kTfLiteError`, but not + // `kTfLiteOk` or just crashes. + registration.prepare = [](TfLiteContext* context, TfLiteNode* node) { + { + TfLiteIntArray* execution_plan; + EXPECT_EQ(context->GetExecutionPlan(context, &execution_plan), + kTfLiteError); + } + { + TfLiteNode* node; + TfLiteRegistration* registration; + EXPECT_EQ( + context->GetNodeAndRegistration(context, 0, &node, ®istration), + kTfLiteError); + } + { + TfLiteRegistration delegate_registration = {nullptr, nullptr, nullptr, + nullptr}; + TfLiteIntArray nodes_to_replace; + nodes_to_replace.size = 0; + EXPECT_EQ(context->ReplaceSubgraphsWithDelegateKernels( + context, delegate_registration, &nodes_to_replace, nullptr), + kTfLiteError); + } + return kTfLiteError; + }; + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError); +} + +TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), + kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + registration.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* first_tensor = context->tensors; + + int new_tensor_index; + context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom, + &new_tensor_index); + EXPECT_EQ(first_tensor, context->tensors); + return kTfLiteOk; + }; + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); +} + +TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), + kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + registration.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* first_tensor = context->tensors; + + int new_tensor_index; + context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom + 1, + &new_tensor_index); + EXPECT_NE(first_tensor, context->tensors); + return kTfLiteOk; + }; + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); +} + // Test fixture that allows playing with execution plans. It creates a two // node graph that can be executed in either [0,1] order or [1,0] order. // The CopyOp records when it is invoked in the class member run_order_ @@ -698,13 +780,17 @@ TfLiteRegistration AddOpRegistration() { reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { // Set output size to input size - TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; - TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]]; - TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]]; - TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); - TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims); - TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size); - TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize)); + TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + + TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size); + for (int i = 0; i < input1->dims->size; ++i) { + TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]); + } + + TF_LITE_ENSURE_STATUS(context->ResizeTensor( + context, output, TfLiteIntArrayCopy(input1->dims))); return kTfLiteOk; }; @@ -723,26 +809,40 @@ TfLiteRegistration AddOpRegistration() { } class TestDelegate : public ::testing::Test { - public: - TestDelegate() { - interpreter_.AddTensors(5); - interpreter_.SetInputs({0, 1}); - interpreter_.SetOutputs({3, 4}); + protected: + void SetUp() override { + interpreter_.reset(new Interpreter); + interpreter_->AddTensors(5); + interpreter_->SetInputs({0, 1}); + interpreter_->SetOutputs({3, 4}); TfLiteQuantizationParams quant; - interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, - quant); - interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, - quant); - interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, - quant); - interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, - quant); + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, + quant); TfLiteRegistration reg = AddOpRegistration(); - interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); - interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); - interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + } + + void TearDown() override { + // Interpreter relies on delegate_ to free the resources properly. Thus + // the life cycle of delegate must be longer than interpreter. + interpreter_.reset(); + delegate_.reset(); } + TfLiteBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle; + + TfLiteBufferHandle AllocateBufferHandle() { return ++last_allocated_handle_; } + protected: class SimpleDelegate { public: @@ -751,8 +851,8 @@ class TestDelegate : public ::testing::Test { // value-copyable and compatible with TfLite. explicit SimpleDelegate(const std::vector& nodes) : nodes_(nodes) { delegate_.Prepare = [](TfLiteContext* context, - void* data) -> TfLiteStatus { - auto* simple = reinterpret_cast(data); + TfLiteDelegate* delegate) -> TfLiteStatus { + auto* simple = reinterpret_cast(delegate->data_); TfLiteIntArray* nodes_to_separate = TfLiteIntArrayCreate(simple->nodes_.size()); // Mark nodes that we want in TfLiteIntArray* structure. @@ -783,10 +883,26 @@ class TestDelegate : public ::testing::Test { } context->ReplaceSubgraphsWithDelegateKernels( - context, FakeFusedRegistration(), nodes_to_separate); + context, FakeFusedRegistration(), nodes_to_separate, delegate); TfLiteIntArrayFree(nodes_to_separate); return kTfLiteOk; }; + delegate_.CopyToBufferHandle = [](TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, int size) -> TfLiteStatus { + // TODO(ycling): Implement tests to test buffer copying logic. + return kTfLiteOk; + }; + delegate_.CopyFromBufferHandle = + [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, + void* data, int size) -> TfLiteStatus { + // TODO(ycling): Implement tests to test buffer copying logic. + return kTfLiteOk; + }; + delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) { + *handle = kTfLiteNullBufferHandle; + }; // Store type-punned data SimpleDelegate structure. delegate_.data_ = reinterpret_cast(this); } @@ -803,36 +919,196 @@ class TestDelegate : public ::testing::Test { std::vector nodes_; TfLiteDelegate delegate_; }; - Interpreter interpreter_; + std::unique_ptr interpreter_; + std::unique_ptr delegate_; }; TEST_F(TestDelegate, BasicDelegate) { - interpreter_.Invoke(); - SimpleDelegate simple({0, 1, 2}); - interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - ASSERT_EQ(interpreter_.execution_plan().size(), 1); - int node = interpreter_.execution_plan()[0]; - const auto* node_and_reg = interpreter_.node_and_registration(node); - ASSERT_EQ(node_and_reg->second.custom_name, + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + int node = interpreter_->execution_plan()[0]; + const auto* node_and_reg = interpreter_->node_and_registration(node); + EXPECT_EQ(node_and_reg->second.custom_name, SimpleDelegate::FakeFusedRegistration().custom_name); + + const TfLiteDelegateParams* params = + reinterpret_cast( + node_and_reg->first.builtin_data); + ASSERT_EQ(params->nodes_to_replace->size, 3); + EXPECT_EQ(params->nodes_to_replace->data[0], 0); + EXPECT_EQ(params->nodes_to_replace->data[1], 1); + EXPECT_EQ(params->nodes_to_replace->data[2], 2); + + ASSERT_EQ(params->input_tensors->size, 2); + EXPECT_EQ(params->input_tensors->data[0], 0); + EXPECT_EQ(params->input_tensors->data[1], 1); + + ASSERT_EQ(params->output_tensors->size, 2); + EXPECT_EQ(params->output_tensors->data[0], 3); + EXPECT_EQ(params->output_tensors->data[1], 4); } TEST_F(TestDelegate, ComplexDeligate) { - interpreter_.Invoke(); - SimpleDelegate simple({1, 2}); - interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + delegate_ = std::unique_ptr(new SimpleDelegate({1, 2})); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - ASSERT_EQ(interpreter_.execution_plan().size(), 2); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); // 0th should be a non-delegated original op - ASSERT_EQ(interpreter_.execution_plan()[0], 0); + ASSERT_EQ(interpreter_->execution_plan()[0], 0); // 1st should be a new macro op (3) which didn't exist) - ASSERT_EQ(interpreter_.execution_plan()[1], 3); - const auto* node_and_reg = interpreter_.node_and_registration(3); + ASSERT_EQ(interpreter_->execution_plan()[1], 3); + const auto* node_and_reg = interpreter_->node_and_registration(3); ASSERT_EQ(node_and_reg->second.custom_name, SimpleDelegate::FakeFusedRegistration().custom_name); } +TEST_F(TestDelegate, SetBufferHandleToInput) { + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 0; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + ASSERT_EQ(tensor->delegate, nullptr); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = + interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, handle); +} + +TEST_F(TestDelegate, SetBufferHandleToOutput) { + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = + interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, handle); +} + +TEST_F(TestDelegate, SetInvalidHandleToTensor) { + interpreter_->Invoke(); + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate, true); + + SimpleDelegate another_simple_delegate({0, 1, 2}); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = interpreter_->SetBufferHandle( + kOutputTensorIndex, handle, + another_simple_delegate.get_tf_lite_delegate()); + // Setting a buffer handle to a tensor with another delegate will fail. + ASSERT_EQ(status, kTfLiteError); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); +} + +TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) { + delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 2}), kTfLiteOk); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError); +} + +class TestDelegateWithDynamicTensors : public ::testing::Test { + protected: + void SetUp() override { + interpreter_.reset(new Interpreter); + + interpreter_->AddTensors(2); + interpreter_->SetInputs({0}); + interpreter_->SetOutputs({1}); + TfLiteQuantizationParams quant; + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = DynamicCopyOpRegistration(); + interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + delegate_.Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // In this test, the delegate replaces all the nodes if this function is + // called. + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + context->ReplaceSubgraphsWithDelegateKernels( + context, DelegateRegistration(), execution_plan, delegate); + return kTfLiteOk; + }; + } + + static TfLiteRegistration DynamicCopyOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + SetTensorToDynamic(output); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Not implemented since this isn't required in testing. + return kTfLiteOk; + }; + return reg; + } + + static TfLiteRegistration DelegateRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + return reg; + } + + std::unique_ptr interpreter_; + TfLiteDelegate delegate_; +}; + +TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) { + interpreter_->ModifyGraphWithDelegate(&delegate_, false); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + // The interpreter should not call delegate's `Prepare` when dynamic tensors + // exist. So the node ID isn't changed. + ASSERT_EQ(interpreter_->execution_plan()[0], 0); +} + +TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) { + interpreter_->ModifyGraphWithDelegate(&delegate_, true); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + // The node should be replaced because dynamic tensors are allowed. Therefore + // only node ID in the execution plan is changed from 0 to 1. + ASSERT_EQ(interpreter_->execution_plan()[0], 1); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc index fc6594c3a04ba6aabba99bb631f85737baf389f1..079320586ffd01fc77818a81e0c5962f1d28c1f1 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -31,9 +31,6 @@ ifeq ($(TARGET), IOS) ${IPHONEOS_SYSROOT} \ -arch $(IOS_ARCH) \ -O3 - ifeq ($(IOS_ARCH), x86_64) - CXXFLAGS += -msse4.1 - endif CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -fembed-bitcode \ -mno-thumb \ diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 35aacb70002d1d454f675484e4398bcdffc4acf1..f52d6ba6c5390e631d29e75f833aa4dd5bba1a68 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -29,7 +29,7 @@ android_library( visibility = ["//visibility:public"], deps = [ ":tflite_runtime", - "@javax_validation", + "@org_checkerframework_qual", ], ) @@ -42,7 +42,7 @@ android_library( ), visibility = ["//visibility:public"], deps = [ - "@javax_validation", + "@org_checkerframework_qual", ], ) @@ -58,7 +58,7 @@ java_library( deps = [ ":libtensorflowlite_jni.so", "//tensorflow/contrib/lite/java/src/main/native", - "@javax_validation", + "@org_checkerframework_qual", ], ) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index 654fa9d6d2799fc3cafa3e0e042cb2a5746bf2c5..5eb749aae6e224bec64b66832f116ebc3372c1ef 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -6,7 +6,7 @@ android_binary( name = "TfLiteCameraDemo", srcs = glob(["java/**/*.java"]), assets = [ - "@tflite_mobilenet//:labels.txt", + "//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", ], assets_dir = "", diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 9b9fdffab557060f0211a0ce361b002cc7d03956..300786c3ca01b12a46f7f9a6fe8fd720f97a79f4 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -299,7 +299,7 @@ public class Camera2BasicFragment extends Fragment // create either a new ImageClassifierQuantizedMobileNet or an ImageClassifierFloatInception classifier = new ImageClassifierQuantizedMobileNet(getActivity()); } catch (IOException e) { - Log.e(TAG, "Failed to initialize an image classifier."); + Log.e(TAG, "Failed to initialize an image classifier.", e); } startBackgroundThread(); } @@ -433,7 +433,7 @@ public class Camera2BasicFragment extends Fragment return; } } catch (CameraAccessException e) { - e.printStackTrace(); + Log.e(TAG, "Failed to access Camera", e); } catch (NullPointerException e) { // Currently an NPE is thrown when the Camera2API is used but not supported on the // device this code runs. @@ -478,7 +478,7 @@ public class Camera2BasicFragment extends Fragment } manager.openCamera(cameraId, stateCallback, backgroundHandler); } catch (CameraAccessException e) { - e.printStackTrace(); + Log.e(TAG, "Failed to open Camera", e); } catch (InterruptedException e) { throw new RuntimeException("Interrupted while trying to lock camera opening.", e); } @@ -545,7 +545,7 @@ public class Camera2BasicFragment extends Fragment runClassifier = false; } } catch (InterruptedException e) { - e.printStackTrace(); + Log.e(TAG, "Interrupted when stopping background thread", e); } } @@ -604,7 +604,7 @@ public class Camera2BasicFragment extends Fragment captureSession.setRepeatingRequest( previewRequest, captureCallback, backgroundHandler); } catch (CameraAccessException e) { - e.printStackTrace(); + Log.e(TAG, "Failed to set up config to capture Camera", e); } } @@ -615,7 +615,7 @@ public class Camera2BasicFragment extends Fragment }, null); } catch (CameraAccessException e) { - e.printStackTrace(); + Log.e(TAG, "Failed to preview Camera", e); } } diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java index 156c895146940adfe71f111be6e354e02b75ea48..e164ac75543ebab12e6b1c057c4ed487eb9accdf 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java @@ -16,7 +16,6 @@ limitations under the License. package com.example.android.tflitecamerademo; import android.app.Activity; - import java.io.IOException; /** diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java new file mode 100644 index 0000000000000000000000000000000000000000..d0102883e6b41f5c33a0061c5fd53b5f69b8ab54 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java @@ -0,0 +1,197 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import android.graphics.Bitmap; +import android.os.SystemClock; +import android.util.Log; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; + +/** + * Class that benchmarks image classifier models. + * + *

===================== General workflow ======================= + * + *

{@code
+ * benchmarker = new OvicBenchmarker();
+ * benchmarker.getReadyToTest(labelInputStream, model);
+ * while (!benchmarker.shouldStop()) {
+ *   Bitmap bitmap = ...
+ *   benchmarker.doTestIteration(bitmap);
+ * }
+ * }
+ */ +public class OvicBenchmarker { + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicBenchmarker"; + + /** Evaluation transformation parameters. */ + private static final float CENTRAL_FRACTION = 0.875f; + + /** Dimensions of inputs. */ + private static final int DIM_BATCH_SIZE = 1; + private static final int DIM_PIXEL_SIZE = 3; + private int imgHeight = 224; + private int imgWidth = 224; + + /* Preallocated buffers for storing image data in. */ + private int[] intValues = null; + + /** A ByteBuffer to hold image data, to be feed into classifier as inputs. */ + private ByteBuffer imgData = null; + + private OvicClassifier classifier; + + /** Total runtime in ms. */ + private double totalRuntime = 0.0; + /** Total allowed runtime in ms. */ + private double wallTime = 20000 * 30.0; + + private Boolean benchmarkStarted = null; + + /** + * Initializes an {@link OvicBenchmarker} + * + * @param wallTime: a double number specifying the total amount of time to benchmark. + */ + public OvicBenchmarker(double wallTime) { + benchmarkStarted = false; + totalRuntime = 0.0; + this.wallTime = wallTime; + } + + /** Check whether the benchmarker should stop. */ + public Boolean shouldStop() { + if (totalRuntime >= wallTime) { + Log.e( + TAG, + "Total runtime " + + Double.toString(totalRuntime) + + " exceeded walltime " + + Double.toString(wallTime)); + return true; + } + return false; + } + + /** Check whether the benchmarker is ready to start classifying images. */ + public Boolean readyToTest() { + return (classifier != null); + } + + /** + * Getting the benchmarker ready for classifying images. + * + * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be + * read from. + * @param model: a {@link MappedByteBuffer} model to benchmark. + */ + public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) { + try { + Log.i(TAG, "Creating classifier."); + classifier = new OvicClassifier(labelInputStream, model); + int [] inputDims = classifier.getInputDims(); + imgHeight = inputDims[1]; + imgWidth = inputDims[2]; + // Only accept QUANTIZED_UINT8 input. + imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE); + imgData.order(ByteOrder.nativeOrder()); + intValues = new int[imgHeight * imgWidth]; + } catch (Exception e) { + Log.e(TAG, e.getMessage()); + Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker."); + } + } + + /** Return how many classes are predicted per image. */ + public int getNumPredictions() { + return classifier.getNumPredictions(); + } + + /** + * Perform test on a single bitmap image. + * + * @param bitmap: a {@link Bitmap} image to classify. + */ + public OvicSingleImageResult doTestIteration(Bitmap bitmap) + throws IOException, InterruptedException { + if (shouldStop() || !readyToTest()) { + return null; + } + OvicSingleImageResult iterResult = null; + try { + Log.i(TAG, "Converting bitmap."); + convertBitmapToInput(bitmap); + Log.i(TAG, "Classifying image."); + iterResult = classifier.classifyByteBuffer(imgData); + } catch (RuntimeException e) { + Log.e(TAG, e.getMessage()); + Log.e(TAG, "Failed to classify image."); + } + if (iterResult == null || iterResult.latency == null) { + throw new RuntimeException("Classification result or timing is invalid."); + } + Log.d(TAG, "Native inference latency: " + iterResult.latency); + Log.i(TAG, iterResult.toString()); + + if (!benchmarkStarted) { // Skip the first image to discount warming-up time. + benchmarkStarted = true; + } else { + totalRuntime += (double) iterResult.latency; + } + return iterResult; + } + + /** + * Writes Image data into a {@link ByteBuffer}. + * + * @param bitmap: a {@link Bitmap} source image. + */ + private void convertBitmapToInput(Bitmap bitmap) throws RuntimeException { + if (imgData == null) { + throw new RuntimeException("Benchmarker is not yet ready to test."); + } + imgData.rewind(); + // Perform transformations corresponding to evaluation mode. + float width = (float) bitmap.getWidth(); + float height = (float) bitmap.getHeight(); + int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2); + int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2); + int newWidth = Math.round(width - stWidth * 2); + int newHeight = Math.round(height - stHeight * 2); + bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight); + bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true); + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + + // Convert the image to ByteBuffer. + int pixel = 0; + long startTime = SystemClock.uptimeMillis(); + + for (int i = 0; i < imgHeight; ++i) { + for (int j = 0; j < imgWidth; ++j) { + final int val = intValues[pixel++]; + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); + } +} diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java new file mode 100644 index 0000000000000000000000000000000000000000..b2dfd8f2e710324f6c11a3098b858ffee8b28b3c --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java @@ -0,0 +1,209 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.TestHelper; + +/** Benchmark ImageNet Classifier with Tensorflow Lite. */ +public class OvicClassifier { + + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicClassifier"; + + /** Number of results to show (i.e. the "K" in top-K predictions). */ + private static final int RESULTS_TO_SHOW = 5; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private Interpreter tflite; + + /** Labels corresponding to the output of the vision model. */ + private List labelList; + + /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ + private byte[][] inferenceOutputArray = null; + /** An array to hold final prediction probabilities. */ + private float[][] labelProbArray = null; + + /** Input resultion. */ + private int[] inputDims = null; + /** Whether the model runs as float or quantized. */ + private Boolean outputIsFloat = null; + + private PriorityQueue> sortedLabels = + new PriorityQueue<>( + RESULTS_TO_SHOW, + new Comparator>() { + @Override + public int compare(Map.Entry o1, Map.Entry o2) { + return (o1.getValue()).compareTo(o2.getValue()); + } + }); + + /** Initializes an {@code OvicClassifier}. */ + OvicClassifier(InputStream labelInputStream, MappedByteBuffer model) + throws IOException, RuntimeException { + if (model == null) { + throw new RuntimeException("Input model is empty."); + } + labelList = loadLabelList(labelInputStream); + // OVIC uses one thread for CPU inference. + tflite = new Interpreter(model, 1); + inputDims = TestHelper.getInputDims(tflite, 0); + if (inputDims.length != 4) { + throw new RuntimeException("The model's input dimensions must be 4 (BWHC)."); + } + if (inputDims[0] != 1) { + throw new RuntimeException("The model must have a batch size of 1, got " + + inputDims[0] + " instead."); + } + if (inputDims[3] != 3) { + throw new RuntimeException("The model must have three color channels, got " + + inputDims[3] + " instead."); + } + int minSide = Math.min(inputDims[1], inputDims[2]); + int maxSide = Math.max(inputDims[1], inputDims[2]); + if (minSide <= 0 || maxSide > 1000) { + throw new RuntimeException("The model's resolution must be between (0, 1000]."); + } + String outputDataType = TestHelper.getOutputDataType(tflite, 0); + if (outputDataType.equals("float")) { + outputIsFloat = true; + } else if (outputDataType.equals("byte")) { + outputIsFloat = false; + } else { + throw new RuntimeException("Cannot process output type: " + outputDataType); + } + inferenceOutputArray = new byte[1][labelList.size()]; + labelProbArray = new float[1][labelList.size()]; + } + + /** Classifies a {@link ByteBuffer} image. */ + // @throws RuntimeException if model is uninitialized. + OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) throws RuntimeException { + if (tflite == null) { + throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed."); + } + if (outputIsFloat == null) { + throw new RuntimeException(TAG + ": Classifier output type has not been resolved."); + } + if (outputIsFloat) { + tflite.run(imgData, labelProbArray); + } else { + tflite.run(imgData, inferenceOutputArray); + /** Convert results to float */ + for (int i = 0; i < inferenceOutputArray[0].length; i++) { + labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f; + } + } + OvicSingleImageResult iterResult = computeTopKLabels(); + iterResult.latency = getLastNativeInferenceLatencyMilliseconds(); + return iterResult; + } + + /** Return the probability array of all classes. */ + public float[][] getlabelProbArray() { + return labelProbArray; + } + + /** Return the number of top labels predicted by the classifier. */ + public int getNumPredictions() { + return RESULTS_TO_SHOW; + } + + /** Return the four dimensions of the input image. */ + public int[] getInputDims() { + return inputDims; + } + + /* + * Get native inference latency of last image classification run. + * @throws RuntimeException if model is uninitialized. + */ + public Long getLastNativeInferenceLatencyMilliseconds() { + if (tflite == null) { + throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed."); + } + Long latency = tflite.getLastNativeInferenceDurationNanoseconds(); + return (latency == null) ? null : (Long) (latency / 1000000); + } + + /** Closes tflite to release resources. */ + public void close() { + tflite.close(); + tflite = null; + } + + /** Reads label list from Assets. */ + private static List loadLabelList(InputStream labelInputStream) throws IOException { + List labelList = new ArrayList(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(labelInputStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + labelList.add(line); + } + } + return labelList; + } + + /** Computes top-K labels. */ + private OvicSingleImageResult computeTopKLabels() { + if (labelList == null) { + throw new RuntimeException("Label file has not been loaded."); + } + for (int i = 0; i < labelList.size(); ++i) { + sortedLabels.add(new AbstractMap.SimpleEntry<>(i, labelProbArray[0][i])); + if (sortedLabels.size() > RESULTS_TO_SHOW) { + sortedLabels.poll(); + } + } + OvicSingleImageResult singleImageResult = new OvicSingleImageResult(); + if (sortedLabels.size() != RESULTS_TO_SHOW) { + throw new RuntimeException( + "Number of returned labels does not match requirement: " + + sortedLabels.size() + + " returned, but " + + RESULTS_TO_SHOW + + " required."); + } + for (int i = 0; i < RESULTS_TO_SHOW; ++i) { + Map.Entry label = sortedLabels.poll(); + // ImageNet model prediction indices are 0-based. + singleImageResult.topKIndices.add(label.getKey()); + singleImageResult.topKClasses.add(labelList.get(label.getKey())); + singleImageResult.topKProbs.add(label.getValue()); + } + // Labels with lowest probability are returned first, hence need to reverse them. + Collections.reverse(singleImageResult.topKIndices); + Collections.reverse(singleImageResult.topKClasses); + Collections.reverse(singleImageResult.topKProbs); + return singleImageResult; + } +} diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java new file mode 100644 index 0000000000000000000000000000000000000000..4af9a65c2f45c57b979bf9629e34f52bb0853a44 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java @@ -0,0 +1,54 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import java.util.ArrayList; + +/** Result class for inference run on a single image. */ +public class OvicSingleImageResult { + + /** Top K classes and probabilities. */ + public ArrayList topKClasses; + public ArrayList topKProbs; + public ArrayList topKIndices; + + /** Latency (ms). */ + public Long latency; + + OvicSingleImageResult() { + topKClasses = new ArrayList<>(); + topKProbs = new ArrayList<>(); + topKIndices = new ArrayList<>(); + latency = -1L; + } + + @Override + public String toString() { + String textToShow = latency + "ms"; + for (int k = 0; k < topKProbs.size(); ++k) { + textToShow += + "\nPrediction [" + + k + + "] = Class " + + Integer.toString(topKIndices.get(k)) + + " (" + + topKClasses.get(k) + + ") : " + + Float.toString(topKProbs.get(k)); + } + return textToShow; + } + +} diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java new file mode 100644 index 0000000000000000000000000000000000000000..4fd23a99d25d715530cf36f398d949f7e70598de --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -0,0 +1,176 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Paths; +import javax.imageio.ImageIO; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.ovic.OvicClassifier}. */ +@RunWith(JUnit4.class) +public final class OvicClassifierTest { + + private OvicClassifier classifier; + private InputStream labelsInputStream = null; + private MappedByteBuffer quantizedModel = null; + private MappedByteBuffer floatModel = null; + private MappedByteBuffer lowResModel = null; + private ByteBuffer testImage = null; + private ByteBuffer lowResTestImage = null; + private OvicSingleImageResult testResult = null; + private static final String LABELS_PATH = "testdata/labels.txt"; + private static final String QUANTIZED_MODEL_PATH = "testdata/quantized_model.lite"; + private static final String LOW_RES_MODEL_PATH = "testdata/low_res_model.lite"; + private static final String FLOAT_MODEL_PATH = "testdata/float_model.lite"; + private static final String TEST_IMAGE_PATH = "testdata/test_image_224.jpg"; + private static final String TEST_LOW_RES_IMAGE_PATH = "testdata/test_image_128.jpg"; + private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform" + + @Before + public void setUp() { + try { + File labelsfile = new File(getTestDir(LABELS_PATH)); + labelsInputStream = new FileInputStream(labelsfile); + quantizedModel = loadModelFile(getTestDir(QUANTIZED_MODEL_PATH)); + floatModel = loadModelFile(getTestDir(FLOAT_MODEL_PATH)); + lowResModel = loadModelFile(getTestDir(LOW_RES_MODEL_PATH)); + File imageFile = new File(getTestDir(TEST_IMAGE_PATH)); + BufferedImage img = ImageIO.read(imageFile); + testImage = toByteBuffer(img); + // Low res image and models. + imageFile = new File(getTestDir(TEST_LOW_RES_IMAGE_PATH)); + img = ImageIO.read(imageFile); + lowResTestImage = toByteBuffer(img); + } catch (IOException e) { + System.out.print(e.getMessage()); + } + System.out.println("Successful setup"); + } + + private static String getTestDir(String testfile) throws IOException { + return Paths.get("third_party/tensorflow/contrib/lite/java/ovic/src/", testfile).toString(); + } + + @Test + public void ovicClassifier_quantizedModelCreateSuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, quantizedModel); + assertThat(classifier != null).isTrue(); + } + + @Test + public void ovicClassifier_floatModelCreateSuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, floatModel); + assertThat(classifier != null).isTrue(); + } + + @Test + public void ovicClassifier_quantizedModelClassifySuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, quantizedModel); + testResult = classifier.classifyByteBuffer(testImage); + assertCorrectTopK(testResult); + } + + @Test + public void ovicClassifier_floatModelClassifySuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, floatModel); + testResult = classifier.classifyByteBuffer(testImage); + assertCorrectTopK(testResult); + } + + @Test + public void ovicClassifier_lowResModelClassifySuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, lowResModel); + testResult = classifier.classifyByteBuffer(lowResTestImage); + assertCorrectTopK(testResult); + } + + @Test + public void ovicClassifier_latencyNotNull() throws Exception { + classifier = new OvicClassifier(labelsInputStream, floatModel); + testResult = classifier.classifyByteBuffer(testImage); + assertThat(testResult.latency != null).isTrue(); + } + + @Test + public void ovicClassifier_mismatchedInputResolutionFails() throws Exception { + classifier = new OvicClassifier(labelsInputStream, lowResModel); + int[] inputDims = classifier.getInputDims(); + assertThat((inputDims[1] == 128) && (inputDims[2] == 128)).isTrue(); + try { + testResult = classifier.classifyByteBuffer(testImage); + fail(); + } catch (RuntimeException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 49152 bytes, " + + "but found 150528 bytes."); + } + } + + private static ByteBuffer toByteBuffer(BufferedImage image) { + ByteBuffer imgData = ByteBuffer.allocateDirect( + image.getHeight() * image.getWidth() * 3); + imgData.order(ByteOrder.nativeOrder()); + for (int y = 0; y < image.getHeight(); y++) { + for (int x = 0; x < image.getWidth(); x++) { + int val = image.getRGB(x, y); + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + return imgData; + } + + private static void assertCorrectTopK(OvicSingleImageResult testResult) { + assertThat(testResult.topKClasses.size() > 0).isTrue(); + Boolean topKAccurate = false; + // Assert that the correct class is in the top K. + for (int i = 0; i < testResult.topKIndices.size(); i++) { + if (testResult.topKIndices.get(i) == TEST_IMAGE_GROUNDTRUTH) { + topKAccurate = true; + break; + } + } + System.out.println(testResult.toString()); + System.out.flush(); + assertThat(topKAccurate).isTrue(); + } + + private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException { + File modelfile = new File(modelFilePath); + FileInputStream inputStream = new FileInputStream(modelfile); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = 0L; + long declaredLength = fileChannel.size(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } +} diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt b/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/proguard.flags b/tensorflow/contrib/lite/java/proguard.flags new file mode 100644 index 0000000000000000000000000000000000000000..8ee3d7e7ae728b27789336ac56208acdf13ee424 --- /dev/null +++ b/tensorflow/contrib/lite/java/proguard.flags @@ -0,0 +1,3 @@ +-keepclassmembers class org.tensorflow.lite.NativeInterpreterWrapper { + private long inferenceDurationNanoseconds; +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java index d63c299589d2e8ce1051a52d29b533ed126bbcf7..fc16488a6459eb227fde712055d3e8ccfcce0070 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -71,6 +71,23 @@ enum DataType { throw new IllegalArgumentException("DataType " + this + " is not supported yet"); } + /** Gets string names of the data type. */ + String toStringName() { + switch (this) { + case FLOAT32: + return "float"; + case INT32: + return "int"; + case UINT8: + return "byte"; + case INT64: + return "long"; + case BYTEBUFFER: + return "ByteBuffer"; + } + throw new IllegalArgumentException("DataType " + this + " is not supported yet"); + } + // Cached to avoid copying it private static final DataType[] values = values(); } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index dd883d69d2065236ee29012b9bde99972aefbcf7..14f461f5f9ba8c0755d2a1968533a79cce10750a 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -19,7 +19,7 @@ import java.io.File; import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; -import javax.validation.constraints.NotNull; +import org.checkerframework.checker.nullness.qual.NonNull; /** * Driver class to drive model inference with TensorFlow Lite. @@ -60,7 +60,7 @@ public final class Interpreter implements AutoCloseable { * * @param modelFile: a File of a pre-trained TF Lite model. */ - public Interpreter(@NotNull File modelFile) { + public Interpreter(@NonNull File modelFile) { if (modelFile == null) { return; } @@ -73,20 +73,34 @@ public final class Interpreter implements AutoCloseable { *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code * Interpreter}. */ - public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) { + public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) { wrapper = new NativeInterpreterWrapper(mappedByteBuffer); } + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and + * specifies the number of threads used for inference. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads); + } + /** * Runs model inference if the model takes only one input, and provides only one output. * + *

Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please + * consider using {@link ByteBuffer} to feed input data for better performance. + * * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large * input data. When {@link ByteBuffer} is used, its content should remain unchanged until * model inference is done. * @param output a multidimensional array of output data. */ - public void run(@NotNull Object input, @NotNull Object output) { + public void run(@NonNull Object input, @NonNull Object output) { Object[] inputs = {input}; Map outputs = new HashMap<>(); outputs.put(0, output); @@ -96,6 +110,9 @@ public final class Interpreter implements AutoCloseable { /** * Runs model inference if the model takes multiple inputs, or returns multiple outputs. * + *

Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please + * consider using {@link ByteBuffer} to feed input data for better performance. + * * @param inputs an array of input data. The inputs should be in the same order as inputs of the * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred @@ -105,7 +122,7 @@ public final class Interpreter implements AutoCloseable { * needs to keep entries for the outputs to be used. */ public void runForMultipleInputsOutputs( - @NotNull Object[] inputs, @NotNull Map outputs) { + @NonNull Object[] inputs, @NonNull Map outputs) { if (wrapper == null) { throw new IllegalStateException("The Interpreter has already been closed."); } @@ -128,7 +145,7 @@ public final class Interpreter implements AutoCloseable { * *

IllegalArgumentException will be thrown if it fails to resize. */ - public void resizeInput(int idx, @NotNull int[] dims) { + public void resizeInput(int idx, @NonNull int[] dims) { if (wrapper == null) { throw new IllegalStateException("The Interpreter has already been closed."); } @@ -161,6 +178,27 @@ public final class Interpreter implements AutoCloseable { return wrapper.getOutputIndex(opName); } + /** + * Returns native inference timing. + *

IllegalArgumentException will be thrown if the model is not initialized by the + * {@link Interpreter}. + */ + public Long getLastNativeInferenceDurationNanoseconds() { + if (wrapper == null) { + throw new IllegalStateException("The interpreter has already been closed."); + } + return wrapper.getLastNativeInferenceDurationNanoseconds(); + } + + /** Turns on/off Android NNAPI for hardware acceleration when it is available. */ + public void setUseNNAPI(boolean useNNAPI) { + if (wrapper != null) { + wrapper.setUseNNAPI(useNNAPI); + } else { + throw new IllegalStateException("NativeInterpreterWrapper has already been closed."); + } + } + /** Release resources associated with the {@code Interpreter}. */ @Override public void close() { diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 5ee594dec492ad2fee22e603a6de311b3fed4cac..dbf8f8f7cc2815a46130e342d7e45d4e471696de 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -34,7 +34,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { NativeInterpreterWrapper(String modelPath) { errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModel(modelPath, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1); + isMemoryAllocated = true; } /** @@ -46,7 +47,21 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelByteBuffer = mappedByteBuffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); - interpreterHandle = createInterpreter(modelHandle, errorHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1); + isMemoryAllocated = true; + } + + /** + * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer} and specifies + * the number of inference threads. The MappedByteBuffer should not be modified after the + * construction of a {@code NativeInterpreterWrapper}. + */ + NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer, int numThreads) { + modelByteBuffer = mappedByteBuffer; + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); + isMemoryAllocated = true; } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @@ -59,6 +74,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelByteBuffer = null; inputsIndexes = null; outputsIndexes = null; + isMemoryAllocated = false; } /** Sets inputs, runs model inference and returns outputs. */ @@ -91,11 +107,21 @@ final class NativeInterpreterWrapper implements AutoCloseable { i, inputs.length)); } } + inferenceDurationNanoseconds = -1; long[] outputsHandles = - run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); + run( + interpreterHandle, + errorHandle, + sizes, + dataTypes, + numsOfBytes, + inputs, + this, + isMemoryAllocated); if (outputsHandles == null || outputsHandles.length == 0) { throw new IllegalStateException("Interpreter has no outputs."); } + isMemoryAllocated = true; Tensor[] outputs = new Tensor[outputsHandles.length]; for (int i = 0; i < outputsHandles.length; ++i) { outputs[i] = Tensor.fromHandle(outputsHandles[i]); @@ -109,14 +135,18 @@ final class NativeInterpreterWrapper implements AutoCloseable { Object[] sizes, int[] dtypes, int[] numsOfBytes, - Object[] values); + Object[] values, + NativeInterpreterWrapper wrapper, + boolean memoryAllocated); /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { - resizeInput(interpreterHandle, errorHandle, idx, dims); + if (resizeInput(interpreterHandle, errorHandle, idx, dims)) { + isMemoryAllocated = false; + } } - private static native void resizeInput( + private static native boolean resizeInput( long interpreterHandle, long errorHandle, int inputIdx, int[] dims); void setUseNNAPI(boolean useNNAPI) { @@ -236,6 +266,35 @@ final class NativeInterpreterWrapper implements AutoCloseable { } } + /** + * Gets the last inference duration in nanoseconds. It returns null if there is no previous + * inference run or the last inference run failed. + */ + Long getLastNativeInferenceDurationNanoseconds() { + return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds; + } + + /** + * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid. + */ + int[] getInputDims(int index) { + return getInputDims(interpreterHandle, index, -1); + } + + /** + * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the + * input. + */ + private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); + + /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */ + String getOutputDataType(int index) { + int type = getOutputDataType(interpreterHandle, index); + return DataType.fromNumber(type).toStringName(); + } + + private static native int getOutputDataType(long interpreterHandle, int outputIdx); + private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; @@ -246,12 +305,16 @@ final class NativeInterpreterWrapper implements AutoCloseable { private int inputSize; + private long inferenceDurationNanoseconds = -1; + private MappedByteBuffer modelByteBuffer; private Map inputsIndexes; private Map outputsIndexes; + private boolean isMemoryAllocated = false; + private static native String[] getInputNames(long interpreterHandle); private static native String[] getOutputNames(long interpreterHandle); @@ -264,12 +327,10 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); - private static native long createInterpreter(long modelHandle, long errorHandle); + private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads); private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); - private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); - static { TensorFlowLite.init(); } diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD index 15806d57c8ed7a45d2db9b80e2aab8e22349ee3e..3571182ca92e959d54935cfdc76679ab69a8cfa9 100644 --- a/tensorflow/contrib/lite/java/src/main/native/BUILD +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -11,6 +11,7 @@ licenses(["notice"]) # Apache 2.0 cc_library( name = "native_framework_only", srcs = [ + "duration_utils_jni.cc", "exception_jni.cc", "nativeinterpreterwrapper_jni.cc", "tensor_jni.cc", diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e08a04370592f6e3c92b5811fa7e163f808e03c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +namespace tflite { + +// Gets the elapsed wall-clock timespec. +timespec getCurrentTime() { + timespec time; + clock_gettime(CLOCK_MONOTONIC, &time); + return time; +} + +// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier +// than 'start'. +jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) { + jlong result = stop->tv_sec - start->tv_sec; + if (result < 0) return -1; + result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec); + if (result < 0) return -1; + return result; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index c346f9f92e360c0722ebac440d790da6441ceecf..844226203bb02f4017b2f04da34ac81ac2b7a191 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" - namespace { const int kByteBufferValue = 999; @@ -79,6 +78,21 @@ TfLiteType resolveDataType(jint data_type) { } } +int getDataType(TfLiteType data_type) { + switch (data_type) { + case kTfLiteFloat32: + return 1; + case kTfLiteInt32: + return 2; + case kTfLiteUInt8: + return 3; + case kTfLiteInt64: + return 4; + default: + return -1; + } +} + void printDims(char* buffer, int max_size, int* dims, int num_dims) { if (max_size <= 0) return; buffer[0] = '?'; @@ -149,6 +163,45 @@ TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, return kTfLiteOk; } +// Checks whether there is any difference between dimensions of a tensor and a +// given dimensions. Returns true if there is difference, else false. +bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { + int num_dims = static_cast(env->GetArrayLength(dims)); + jint* ptr = env->GetIntArrayElements(dims, nullptr); + if (ptr == nullptr) { + throwException(env, kIllegalArgumentException, + "Empty dimensions of input array."); + return true; + } + if (tensor->dims->size != num_dims) { + return true; + } + for (int i = 0; i < num_dims; ++i) { + if (ptr[i] != tensor->dims->data[i]) { + return true; + } + } + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + return false; +} + +bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jobjectArray sizes) { + if (interpreter->inputs().size() != input_size) { + return false; + } + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + TfLiteTensor* target = interpreter->tensor(input_idx); + if (areDimsDifferent(env, target, dims)) return false; + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return false; + } + return true; +} + TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, int input_size, jobjectArray sizes) { for (int i = 0; i < input_size; ++i) { @@ -270,6 +323,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( return reinterpret_cast(error_reporter); } +// Verifies whether the model is a flatbuffer file. +class JNIFlatBufferVerifier : public tflite::TfLiteVerifier { + public: + bool Verify(const char* data, int length, + tflite::ErrorReporter* reporter) override { + if (!VerifyModel(data, length)) { + reporter->Report("The model is not a valid Flatbuffer file"); + return false; + } + return true; + } +}; + JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { @@ -278,17 +344,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( if (error_reporter == nullptr) return 0; const char* path = env->GetStringUTFChars(model_file, nullptr); - { - tflite::FileCopyAllocation allocation(path, nullptr); - if (!VerifyModel(allocation.base(), allocation.bytes())) { - throwException(env, kIllegalArgumentException, - "Contents of %s is not a valid flatbuffer model", path); - env->ReleaseStringUTFChars(model_file, path); - return 0; - } - } + std::unique_ptr verifier; + verifier.reset(new JNIFlatBufferVerifier()); - auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); + auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile( + path, verifier.get(), error_reporter); if (!model) { throwException(env, kIllegalArgumentException, "Contents of %s does not encode a valid TensorFlowLite " @@ -330,7 +390,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( - JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) { + JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle, + jint num_threads) { tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); if (model == nullptr) return 0; BufferErrorReporter* error_reporter = @@ -338,12 +399,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( if (error_reporter == nullptr) return 0; auto resolver = ::tflite::CreateOpResolver(); std::unique_ptr interpreter; - TfLiteStatus status = - tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))( + &interpreter, static_cast(num_threads)); if (status != kTfLiteOk) { throwException(env, kIllegalArgumentException, "Cannot create interpreter: %s", error_reporter->CachedErrorMessage()); + return 0; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the interpreter", + error_reporter->CachedErrorMessage()); + return 0; } return reinterpret_cast(interpreter.release()); } @@ -353,7 +423,7 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values) { + jobjectArray values, jobject wrapper, jboolean memory_allocated) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); if (interpreter == nullptr) return nullptr; @@ -365,25 +435,29 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run( TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, nums_of_bytes, values, sizes); if (status != kTfLiteOk) return nullptr; - // resizes inputs - status = resizeInputs(env, interpreter, input_size, sizes); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, "Can not resize the input: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - // allocates memory - status = interpreter->AllocateTensors(); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Can not allocate memory for the given inputs: %s", - error_reporter->CachedErrorMessage()); - return nullptr; + if (!memory_allocated || + !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) { + // resizes inputs + status = resizeInputs(env, interpreter, input_size, sizes); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, "Can not resize the input: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the given inputs: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } } // sets inputs status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, values); if (status != kTfLiteOk) return nullptr; + timespec beforeInference = ::tflite::getCurrentTime(); // runs inference if (interpreter->Invoke() != kTfLiteOk) { throwException(env, kIllegalArgumentException, @@ -391,6 +465,17 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run( error_reporter->CachedErrorMessage()); return nullptr; } + timespec afterInference = ::tflite::getCurrentTime(); + jclass wrapper_clazz = env->GetObjectClass(wrapper); + jfieldID fid = + env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J"); + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } else if (fid != nullptr) { + env->SetLongField( + wrapper, fid, + ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference)); + } // returns outputs const std::vector& results = interpreter->outputs(); if (results.empty()) { @@ -414,7 +499,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); if (interpreter == nullptr) return nullptr; const int idx = static_cast(input_idx); - if (input_idx >= interpreter->inputs().size()) { + if (input_idx < 0 || input_idx >= interpreter->inputs().size()) { throwException(env, kIllegalArgumentException, "Out of range: Failed to get %d-th input out of %d inputs", input_idx, interpreter->inputs().size()); @@ -422,45 +507,72 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( } TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); int size = target->dims->size; - int expected_num_bytes = elementByteSize(target->type); - for (int i = 0; i < size; ++i) { - expected_num_bytes *= target->dims->data[i]; - } - if (num_bytes != expected_num_bytes) { - throwException(env, kIllegalArgumentException, - "Failed to get input dimensions. %d-th input should have" - " %d bytes, but found %d bytes.", - idx, expected_num_bytes, num_bytes); - return nullptr; + if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid. + int expected_num_bytes = elementByteSize(target->type); + for (int i = 0; i < size; ++i) { + expected_num_bytes *= target->dims->data[i]; + } + if (num_bytes != expected_num_bytes) { + throwException(env, kIllegalArgumentException, + "Failed to get input dimensions. %d-th input should have" + " %d bytes, but found %d bytes.", + idx, expected_num_bytes, num_bytes); + return nullptr; + } } jintArray outputs = env->NewIntArray(size); env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); return outputs; } -JNIEXPORT void JNICALL +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return -1; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Out of range: Failed to get %d-th output out of %d outputs", + output_idx, interpreter->outputs().size()); + return -1; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + int type = getDataType(target->type); + return static_cast(type); +} + +JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jint input_idx, jintArray dims) { BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); - if (error_reporter == nullptr) return; + if (error_reporter == nullptr) return JNI_FALSE; tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); - if (interpreter == nullptr) return; + if (interpreter == nullptr) return JNI_FALSE; const int idx = static_cast(input_idx); if (idx < 0 || idx >= interpreter->inputs().size()) { throwException(env, kIllegalArgumentException, "Can not resize %d-th input for a model having %d inputs.", idx, interpreter->inputs().size()); + return JNI_FALSE; } - TfLiteStatus status = interpreter->ResizeInputTensor( - interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); - if (status != kTfLiteOk) { - throwException(env, kIllegalArgumentException, - "Failed to resize %d-th input: %s", idx, - error_reporter->CachedErrorMessage()); + // check whether it is resizing with the same dimensions. + TfLiteTensor* target = interpreter->tensor(input_idx); + bool is_changed = areDimsDifferent(env, target, dims); + if (is_changed) { + TfLiteStatus status = interpreter->ResizeInputTensor( + interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to resize %d-th input: %s", idx, + error_reporter->CachedErrorMessage()); + return JNI_FALSE; + } } + return is_changed ? JNI_TRUE : JNI_FALSE; } JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index c52a7e4e439936344be26d5761fb5747db64794a..0e28a77feea41d72be126d6e60fffbe7ce374a76 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -28,6 +29,9 @@ limitations under the License. namespace tflite { // This is to be provided at link-time by a library. extern std::unique_ptr CreateOpResolver(); +extern timespec getCurrentTime(); +extern jlong timespec_diff_nanoseconds(struct timespec* start, + struct timespec* stop); } // namespace tflite #ifdef __cplusplus @@ -95,30 +99,33 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (JJ)J + * Signature: (JJI)J */ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( - JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle); + JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle, + jint num_threads); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J + * Signature: + * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J */ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values); + jobjectArray values, jobject wrapper, jboolean memory_allocated); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: * Signature: (JII)[I * - * It gets input dimensions if num_bytes matches number of bytes required by - * the input, else returns null and throws IllegalArgumentException. + * Gets input dimensions. If num_bytes is non-negative, it will check whether + * num_bytes matches num of bytes required by the input, and return null and + * throw IllegalArgumentException if not. */ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( @@ -127,11 +134,23 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (JJI[I) + * Signature: (JI)I * - * It resizes dimensions of a input. + * Gets output dimensions. */ -JNIEXPORT void JNICALL +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJI[I)Z + * + * It returns true if resizing input tensor to different dimensions, else return + * false. + */ +JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, jint input_idx, jintArray dims); diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 424b3de6c97672e310c54230a7ac1204f46d9ac8..61d6c35ec86beebf78dd81e17e145863516802fa 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -218,4 +218,52 @@ public final class InterpreterTest { int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); assertThat(index).isEqualTo(0); } + + @Test + public void testTurnOffNNAPI() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + MappedByteBuffer mappedByteBuffer = + fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); + Interpreter interpreter = new Interpreter(mappedByteBuffer); + interpreter.setUseNNAPI(true); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.setUseNNAPI(false); + interpreter.run(fourD, parsedOutputs); + outputOneD = parsedOutputs[0][0][0]; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testTurnOnNNAPI() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + MappedByteBuffer mappedByteBuffer = + fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); + Interpreter interpreter = new Interpreter(mappedByteBuffer); + interpreter.setUseNNAPI(true); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 90323555d88419d837a76bca7de6d9998e388fca..dbe45e5a05b8227b441de7ca6747f61d010ae210 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -47,6 +47,9 @@ public final class NativeInterpreterWrapperTest { private static final String MODEL_WITH_CUSTOM_OP_PATH = "tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite"; + private static final String NONEXISTING_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin"; + @Test public void testConstructor() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -60,7 +63,18 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); fail(); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model"); + assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file"); + } + } + + @Test + public void testConstructorWithNonexistingModel() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file"); + assertThat(e).hasMessageThat().contains("Could not open"); } } @@ -94,6 +108,30 @@ public final class NativeInterpreterWrapperTest { wrapper.close(); } + @Test + public void testRunWithInputsOfSameDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + parsedOutputs = new float[2][8][8][3]; + outputs[0].copyTo(parsedOutputs); + outputOneD = parsedOutputs[0][0][0]; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + @Test public void testRunWithInt() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); @@ -417,4 +455,87 @@ public final class NativeInterpreterWrapperTest { assertThat(shape[1]).isEqualTo(3); assertThat(shape[2]).isEqualTo(1); } + + @Test + public void testGetInferenceLatency() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L); + wrapper.close(); + } + + @Test + public void testGetInferenceLatencyWithNewWrapper() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); + wrapper.close(); + } + + @Test + public void testGetLatencyAfterFailedInference() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + } + assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); + wrapper.close(); + } + + @Test + public void testGetInputDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + int[] expectedDims = {1, 8, 8, 3}; + assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims); + wrapper.close(); + } + + @Test + public void testGetInputDimsOutOfRange() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + try { + wrapper.getInputDims(-1); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Out of range"); + } + try { + wrapper.getInputDims(1); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Out of range"); + } + wrapper.close(); + } + + @Test + public void testGetOutputDataType() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper.getOutputDataType(0)).contains("float"); + wrapper.close(); + wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); + assertThat(wrapper.getOutputDataType(0)).contains("long"); + wrapper.close(); + wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); + assertThat(wrapper.getOutputDataType(0)).contains("int"); + wrapper.close(); + wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + assertThat(wrapper.getOutputDataType(0)).contains("byte"); + wrapper.close(); + } } diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java index 8660cabf709e6531a5667a16e5cf43a93c7135bd..3aef0c3bb6cc4748de0e55d31f0215a77320ae69 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -32,4 +32,55 @@ public class TestHelper { throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); } } + + /** + * Gets the last inference duration in nanoseconds. It returns null if there is no previous + * inference run or the last inference run failed. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + */ + public static Long getLastNativeInferenceDurationNanoseconds(Interpreter interpreter) { + if (interpreter != null && interpreter.wrapper != null) { + return interpreter.wrapper.getLastNativeInferenceDurationNanoseconds(); + } else { + throw new IllegalArgumentException("Interpreter has not initialized; Failed to get latency."); + } + } + + /** + * Gets the dimensions of an input. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param index an integer index of the input. If it is invalid, an {@code + * IllegalArgumentException} will be thrown. + */ + public static int[] getInputDims(Interpreter interpreter, int index) { + if (interpreter != null && interpreter.wrapper != null) { + return interpreter.wrapper.getInputDims(index); + } else { + throw new IllegalArgumentException( + "Interpreter has not initialized;" + " Failed to get input dimensions."); + } + } + + /** + * Gets the string name of the data type of an output. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param index an integer index of the output. If it is invalid, an {@code + * IllegalArgumentException} will be thrown. + * @return string name of the data type. Possible values include "float", "int", "byte", and + * "long". + */ + public static String getOutputDataType(Interpreter interpreter, int index) { + if (interpreter != null && interpreter.wrapper != null) { + return interpreter.wrapper.getOutputDataType(index); + } else { + throw new IllegalArgumentException( + "Interpreter has not initialized;" + " Failed to get output data type."); + } + } } diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 956bd35fe67b3a487f5eb545a827908e12127455..48021aea47573b1b24bae78a9532200dc222020e 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -5,15 +5,17 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") tf_cc_test( name = "optional_tensor_test", size = "small", srcs = ["optional_tensor_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -33,11 +35,27 @@ cc_library( "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/testing:util", - "//tensorflow/core:lib", + "//tensorflow/core:tflite_portable_logging", "@com_google_googletest//:gtest", ], ) +cc_library( + name = "eigen_support", + srcs = [ + "eigen_support.cc", + ], + hdrs = [ + "eigen_support.h", + ], + copts = tflite_copts(), + deps = [ + ":op_macros", + "//tensorflow/contrib/lite:context", + "//third_party/eigen3", + ], +) + cc_library( name = "gemm_support", srcs = [ @@ -90,6 +108,10 @@ tf_cc_test( name = "kernel_util_test", size = "small", srcs = ["kernel_util_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":kernel_util", "//tensorflow/contrib/lite/testing:util", @@ -97,18 +119,32 @@ tf_cc_test( ], ) +tf_cc_test( + name = "test_util_test", + size = "small", + srcs = ["test_util_test.cc"], + deps = [ + ":test_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "builtin_ops", srcs = [ "activations.cc", "add.cc", + "audio_spectrogram.cc", "basic_rnn.cc", "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", + "cast.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", + "dequantize.cc", "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", @@ -120,7 +156,9 @@ cc_library( "local_response_norm.cc", "lsh_projection.cc", "lstm.cc", + "maximum.cc", "mean.cc", + "mfcc.cc", "mul.cc", "pad.cc", "pooling.cc", @@ -154,21 +192,49 @@ cc_library( }), deps = [ ":activation_functor", + ":eigen_support", ":kernel_util", ":op_macros", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:audio_utils", "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", "//tensorflow/contrib/lite/kernels/internal:reference", "//tensorflow/contrib/lite/kernels/internal:reference_base", - "//tensorflow/contrib/lite/kernels/internal:round", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "@farmhash_archive//:farmhash", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "audio_spectrogram_test", + size = "small", + srcs = ["audio_spectrogram_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "mfcc_test", + size = "small", + srcs = ["mfcc_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", ], ) @@ -176,6 +242,10 @@ tf_cc_test( name = "activations_test", size = "small", srcs = ["activations_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -188,6 +258,42 @@ tf_cc_test( name = "add_test", size = "small", srcs = ["add_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "div_test", + size = "small", + srcs = ["div_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "sub_test", + size = "small", + srcs = ["sub_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -200,6 +306,10 @@ tf_cc_test( name = "transpose_test", size = "small", srcs = ["transpose_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -214,6 +324,10 @@ tf_cc_test( name = "space_to_batch_nd_test", size = "small", srcs = ["space_to_batch_nd_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -226,6 +340,22 @@ tf_cc_test( name = "batch_to_space_nd_test", size = "small", srcs = ["batch_to_space_nd_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "cast_test", + size = "small", + srcs = ["cast_test.cc"], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -238,6 +368,10 @@ tf_cc_test( name = "concatenation_test", size = "small", srcs = ["concatenation_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -250,6 +384,10 @@ tf_cc_test( name = "conv_test", size = "small", srcs = ["conv_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -263,10 +401,27 @@ tf_cc_test( name = "depthwise_conv_test", size = "small", srcs = ["depthwise_conv_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "dequantize_test", + size = "small", + srcs = ["dequantize_test.cc"], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], ) @@ -275,6 +430,10 @@ tf_cc_test( name = "basic_rnn_test", size = "small", srcs = ["basic_rnn_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -287,6 +446,10 @@ tf_cc_test( name = "bidirectional_sequence_lstm_test", size = "small", srcs = ["bidirectional_sequence_lstm_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -299,6 +462,10 @@ tf_cc_test( name = "unidirectional_sequence_lstm_test", size = "small", srcs = ["unidirectional_sequence_lstm_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -311,6 +478,9 @@ tf_cc_test( name = "bidirectional_sequence_rnn_test", size = "small", srcs = ["bidirectional_sequence_rnn_test.cc"], + tags = [ + "tflite_not_portable", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -323,6 +493,10 @@ tf_cc_test( name = "unidirectional_sequence_rnn_test", size = "small", srcs = ["unidirectional_sequence_rnn_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -335,6 +509,10 @@ tf_cc_test( name = "l2norm_test", size = "small", srcs = ["l2norm_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -347,6 +525,22 @@ tf_cc_test( name = "exp_test", size = "small", srcs = ["exp_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "maximum_test", + size = "small", + srcs = ["maximum_test.cc"], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -359,6 +553,10 @@ tf_cc_test( name = "mean_test", size = "small", srcs = ["mean_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -371,6 +569,10 @@ tf_cc_test( name = "mul_test", size = "small", srcs = ["mul_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -383,6 +585,10 @@ tf_cc_test( name = "pad_test", size = "small", srcs = ["pad_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -395,6 +601,10 @@ tf_cc_test( name = "reshape_test", size = "small", srcs = ["reshape_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -407,6 +617,10 @@ tf_cc_test( name = "gather_test", size = "small", srcs = ["gather_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -420,6 +634,10 @@ tf_cc_test( name = "topk_v2_test", size = "small", srcs = ["topk_v2_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -433,6 +651,10 @@ tf_cc_test( name = "resize_bilinear_test", size = "small", srcs = ["resize_bilinear_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -445,6 +667,10 @@ tf_cc_test( name = "svdf_test", size = "small", srcs = ["svdf_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -457,6 +683,10 @@ tf_cc_test( name = "embedding_lookup_test", size = "small", srcs = ["embedding_lookup_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -469,6 +699,10 @@ tf_cc_test( name = "embedding_lookup_sparse_test", size = "small", srcs = ["embedding_lookup_sparse_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -481,6 +715,10 @@ tf_cc_test( name = "fully_connected_test", size = "small", srcs = ["fully_connected_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -493,6 +731,10 @@ tf_cc_test( name = "local_response_norm_test", size = "small", srcs = ["local_response_norm_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -505,6 +747,10 @@ tf_cc_test( name = "pooling_test", size = "small", srcs = ["pooling_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -517,6 +763,10 @@ tf_cc_test( name = "softmax_test", size = "small", srcs = ["softmax_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -530,6 +780,10 @@ tf_cc_test( name = "log_softmax_test", size = "small", srcs = ["log_softmax_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -543,6 +797,10 @@ tf_cc_test( name = "lsh_projection_test", size = "small", srcs = ["lsh_projection_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -555,6 +813,10 @@ tf_cc_test( name = "hashtable_lookup_test", size = "small", srcs = ["hashtable_lookup_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -568,6 +830,10 @@ tf_cc_test( name = "lstm_test", size = "small", srcs = ["lstm_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -580,6 +846,10 @@ tf_cc_test( name = "skip_gram_test", size = "small", srcs = ["skip_gram_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -593,6 +863,10 @@ tf_cc_test( name = "space_to_depth_test", size = "small", srcs = ["space_to_depth_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -605,6 +879,10 @@ tf_cc_test( name = "split_test", size = "small", srcs = ["split_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -617,6 +895,10 @@ tf_cc_test( name = "squeeze_test", size = "small", srcs = ["squeeze_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -629,6 +911,10 @@ tf_cc_test( name = "strided_slice_test", size = "small", srcs = ["strided_slice_test.cc"], + tags = [ + "tflite_not_portable_ios_arm64", + "tflite_not_portable_ios_x86_64", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -648,3 +934,5 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 6acded3091cb820ba641bac2498799d295d7dc7f..39a54c93962b33f3a787b3387d9a133119d0e80a 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -63,6 +63,33 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } +TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + static constexpr int kInputIntegerBits = 4; + + const double input_real_multiplier = + input->params.scale * + static_cast(1 << (31 - kInputIntegerBits)); + + QuantizeMultiplierGreaterThanOne(input_real_multiplier, + &data->input_multiplier, + &data->input_left_shift); + data->input_range_radius = + CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { OpData* data = reinterpret_cast(node->user_data); @@ -123,6 +150,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } +TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* alpha = GetInput(context, node, 1); + + output->type = input->type; + + // Currently only Float32 is supported + // TODO(ycling): Support other data types. + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32); + + // Currently, only support 4D `input` and 3D `alpha` with shape + // (1, 1, channels). + // TODO(impjdi): Support other cases where `alpha` is broadcastable + // to `input`. + TF_LITE_ENSURE_EQ(context, input->dims->size, 4); + TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); + TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); + TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); + TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]); + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); @@ -180,6 +235,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -191,6 +247,14 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; } break; + case kTfLiteUInt8: { + optimized_ops::Tanh(GetTensorData(input), GetTensorDims(input), + input->params.zero_point, data->input_range_radius, + data->input_multiplier, data->input_left_shift, + GetTensorData(output), + GetTensorDims(output)); + return kTfLiteOk; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -352,6 +416,35 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* alpha = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + + if (input->type != kTfLiteFloat32) { + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } + TF_LITE_ENSURE_EQ(context, input->dims->size, 4); + const int batches = input->dims->data[0]; + const int height = input->dims->data[1]; + const int width = input->dims->data[2]; + const int channels = input->dims->data[3]; + + TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); + TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); + TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); + TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels); + + const int n = batches * height * width * channels; + for (int i = 0; i < n; ++i) { + const float x = input->data.f[i]; + output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x; + } + + return kTfLiteOk; +} + } // namespace activations TfLiteRegistration* Register_RELU() { @@ -376,8 +469,8 @@ TfLiteRegistration* Register_RELU6() { } TfLiteRegistration* Register_TANH() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - activations::GenericPrepare, + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::TanhPrepare, activations::TanhEval}; return &r; } @@ -403,6 +496,13 @@ TfLiteRegistration* Register_LOG_SOFTMAX() { return &r; } +TfLiteRegistration* Register_PRELU() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::PreluPrepare, + activations::PreluEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 302e52b96db0206f77eb4c8fcffd565b1db0cd3e..50a84edd475c8051a563cf8ed9fc03099829b786 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -52,6 +52,14 @@ class BaseActivationsOpModel : public SingleOpModel { BuildInterpreter({GetShape(input_)}); } + BaseActivationsOpModel(BuiltinOperator type, const TensorData &input, + const TensorData &output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + protected: int input_; int output_; @@ -143,6 +151,27 @@ TEST(FloatActivationsOpTest, Tanh) { }))); } +TEST(QuantizedActivationsOpTest, Tanh) { + QuantizedActivationsOpModel m( + BuiltinOperator_TANH, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8}, + /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1}); + m.SetInput({ + 0, -6, 2, 4, // + -4, -2, 8, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.0, -0.999987, 0.964027, 0.999329, // + -0.996078, -0.96402, 0.99999, 0.76159, // + }, + 4 * (1. / 256)))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226})); +} + TEST(FloatActivationsOpTest, Sigmoid) { FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); @@ -354,6 +383,49 @@ TEST(FloatActivationsOpTest, LogSoftmax) { }))); } +class PReluOpModel : public SingleOpModel { + public: + PReluOpModel(const TensorData& input, const TensorData& alpha) { + input_ = AddInput(input); + alpha_ = AddInput(alpha); + output_ = AddOutput(input); + SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_), GetShape(alpha_)}); + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAlpha(std::initializer_list data) { + PopulateTensor(alpha_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int alpha_; + int output_; +}; + +TEST(FloatActivationsOpTest, PRelu) { + PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}}, + {TensorType_FLOAT32, {1, 1, 3}}); + + m.SetInput({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 1.0f, 1.0f, 1.0f, // Row 1, Column 2 + -1.0f, -1.0f, -1.0f, // Row 2, Column 1 + -2.0f, -2.0f, -2.0f, // Row 1, Column 2 + }); + m.SetAlpha({0.0f, 1.0f, 2.0f}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 1.0f, 1.0f, 1.0f, // Row 1, Column 2 + 0.0f, -1.0f, -2.0f, // Row 2, Column 1 + 0.0f, -2.0f, -4.0f, // Row 1, Column 2 + })); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc new file mode 100644 index 0000000000000000000000000000000000000000..602f3888c10b3790dc0328c817bdd83276544b25 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +#include "flatbuffers/flexbuffers.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace audio_spectrogram { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +enum KernelType { + kReference, +}; + +typedef struct { + int window_size; + int stride; + bool magnitude_squared; + int output_height; + internal::Spectrogram* spectrogram; +} TfLiteAudioSpectrogramParams; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new TfLiteAudioSpectrogramParams; + + const uint8_t* buffer_t = reinterpret_cast(buffer); + + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + data->window_size = m["window_size"].AsInt64(); + data->stride = m["stride"].AsInt64(); + data->magnitude_squared = m["magnitude_squared"].AsBool(); + + data->spectrogram = new internal::Spectrogram; + + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + auto* params = reinterpret_cast(buffer); + delete params->spectrogram; + delete params; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); + + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size, + params->stride)); + const int64_t sample_count = input->dims->data[0]; + const int64_t length_minus_window = (sample_count - params->window_size); + if (length_minus_window < 0) { + params->output_height = 0; + } else { + params->output_height = 1 + (length_minus_window / params->stride); + } + TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); + output_size->data[0] = input->dims->data[1]; + output_size->data[1] = params->output_height; + output_size->data[2] = params->spectrogram->output_frequency_channels(); + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size, + params->stride)); + + const float* input_data = GetTensorData(input); + + const int64_t sample_count = input->dims->data[0]; + const int64_t channel_count = input->dims->data[1]; + + const int64_t output_width = params->spectrogram->output_frequency_channels(); + + float* output_flat = GetTensorData(output); + + std::vector input_for_channel(sample_count); + for (int64_t channel = 0; channel < channel_count; ++channel) { + float* output_slice = + output_flat + (channel * params->output_height * output_width); + for (int i = 0; i < sample_count; ++i) { + input_for_channel[i] = input_data[i * channel_count + channel]; + } + std::vector> spectrogram_output; + TF_LITE_ENSURE(context, + params->spectrogram->ComputeSquaredMagnitudeSpectrogram( + input_for_channel, &spectrogram_output)); + TF_LITE_ENSURE_EQ(context, spectrogram_output.size(), + params->output_height); + TF_LITE_ENSURE(context, spectrogram_output.empty() || + (spectrogram_output[0].size() == output_width)); + for (int row_index = 0; row_index < params->output_height; ++row_index) { + const std::vector& spectrogram_row = spectrogram_output[row_index]; + TF_LITE_ENSURE_EQ(context, spectrogram_row.size(), output_width); + float* output_row = output_slice + (row_index * output_width); + if (params->magnitude_squared) { + for (int i = 0; i < output_width; ++i) { + output_row[i] = spectrogram_row[i]; + } + } else { + for (int i = 0; i < output_width; ++i) { + output_row[i] = sqrtf(spectrogram_row[i]); + } + } + } + } + return kTfLiteOk; +} + +} // namespace audio_spectrogram + +TfLiteRegistration* Register_AUDIO_SPECTROGRAM() { + static TfLiteRegistration r = { + audio_spectrogram::Init, audio_spectrogram::Free, + audio_spectrogram::Prepare, + audio_spectrogram::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d460fdfc610ef9a867acd492ca0558fb6eab8c3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc @@ -0,0 +1,122 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class BaseAudioSpectrogramOpModel : public SingleOpModel { + public: + BaseAudioSpectrogramOpModel(const TensorData& input1, + const TensorData& output, int window_size, + int stride, bool magnitude_squared) { + input1_ = AddInput(input1); + output_ = AddOutput(output); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("window_size", window_size); + fbb.Int("stride", stride); + fbb.Bool("magnitude_squared", magnitude_squared); + }); + fbb.Finish(); + SetCustomOp("AudioSpectrogram", fbb.GetBuffer(), + Register_AUDIO_SPECTROGRAM); + BuildInterpreter({GetShape(input1_)}); + } + + int input1() { return input1_; } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input1_; + int output_; +}; + +TEST(BaseAudioSpectrogramOpModel, NonSquaredTest) { + BaseAudioSpectrogramOpModel m({TensorType_FLOAT32, {8, 1}}, + {TensorType_FLOAT32, {}}, 8, 1, false); + m.PopulateTensor(m.input1(), + {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f}); + + m.Invoke(); + + std::vector output_shape = m.GetOutputShape(); + EXPECT_EQ(3, output_shape.size()); + EXPECT_THAT(output_shape, ElementsAre(1, 1, 5)); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {0.0f, 1.0f, 2.0f, 1.0f, 0.0f}, 1e-3))); +} + +TEST(SpectrogramOpTest, SquaredTest) { + BaseAudioSpectrogramOpModel m({TensorType_FLOAT32, {8, 1}}, + {TensorType_FLOAT32, {}}, 8, 1, true); + m.PopulateTensor(m.input1(), + {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f}); + + m.Invoke(); + + std::vector output_shape = m.GetOutputShape(); + EXPECT_EQ(3, output_shape.size()); + EXPECT_THAT(output_shape, ElementsAre(1, 1, 5)); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {0.f, 1.f, 4.f, 1.f, 0.f}, 1e-3))); +} + +TEST(SpectrogramOpTest, StrideTest) { + BaseAudioSpectrogramOpModel m({TensorType_FLOAT32, {10, 1}}, + {TensorType_FLOAT32, {}}, 8, 2, true); + m.PopulateTensor(m.input1(), {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, + 1.0f, 0.0f, 1.0f, 0.0f}); + + m.Invoke(); + + std::vector output_shape = m.GetOutputShape(); + EXPECT_THAT(output_shape, ElementsAre(1, 2, 5)); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {0, 1, 4, 1, 0, 1, 2, 1, 2, 1}, 1e-3))); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 8d70df5e21fab110be238a6f72abe9aac8a75622..a64ac42bc43336db928d2682e290f5263f3db0f4 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -443,166 +444,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// Performs an LSTM batch inference step for input specified by input_ptr_batch. -// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and -// biases (*_bias_ptr), and buffers (*_scratch), along with additional -// parameters: -// - params: various LSTM params including activation, clipping, etc., -// - use_cifg: use coupled input forget gates, -// - use_peephole: whether to use peephole connection or not, -// - n_batch: size of batch, -// - n_cell: number of cells (or units), -// - n_input: the input size, -// - n_output: the output size. -// -// The pointers to the hidden state and the output are updated as a result. -// -// The pointers with the suffix "_batch" point to data aligned in batch_major -// order, and each step processes batch_size many inputs from input_ptr_batch, -// and updates batch_size many outputs and hidden states. -void LstmBatchStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - bool use_cifg, bool use_peephole, int n_batch, int n_cell, int n_input, - int n_output, float* output_state_ptr, float* cell_state_ptr, - float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, - float* output_gate_scratch, float* output_ptr_time) { - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, forget_gate_scratch, - /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, output_gate_scratch, - /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_time); - } else { - tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, - output_ptr_time, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_time, n_batch * n_output, - params->proj_clip, output_ptr_time); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_time); - } - tensor_utils::CopyVector(output_ptr_time, n_batch * n_output, - output_state_ptr); -} - // The LSTM Op engine. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -756,7 +597,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output; - LstmBatchStep( + kernel_utils::LstmStep( input_ptr_batch, fw_input_to_input_weights_ptr, fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f, fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr, @@ -766,11 +607,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr, fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f, fw_cell_bias->data.f, fw_output_gate_bias->data.f, - fw_projection_weights_ptr, fw_projection_bias_ptr, params, fw_use_cifg, - fw_use_peephole, n_batch, n_fw_cell, n_input, n_fw_output, - fw_output_state->data.f, fw_cell_state->data.f, fw_input_gate_scratch, - fw_forget_gate_scratch, fw_cell_scratch, fw_output_gate_scratch, - output_ptr_time); + fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch, + n_fw_cell, n_input, n_fw_output, fw_output_state->data.f, + fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch, + fw_cell_scratch, fw_output_gate_scratch, output_ptr_time); } // n_cell and n_output will be the same size when there is no projection. @@ -828,7 +668,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output; - LstmBatchStep( + kernel_utils::LstmStep( input_ptr_batch, bw_input_to_input_weights_ptr, bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f, bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr, @@ -838,11 +678,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr, bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f, bw_cell_bias->data.f, bw_output_gate_bias->data.f, - bw_projection_weights_ptr, bw_projection_bias_ptr, params, bw_use_cifg, - bw_use_peephole, n_batch, n_bw_cell, n_input, n_bw_output, - bw_output_state->data.f, bw_cell_state->data.f, bw_input_gate_scratch, - bw_forget_gate_scratch, bw_cell_scratch, bw_output_gate_scratch, - output_ptr_time); + bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch, + n_bw_cell, n_input, n_bw_output, bw_output_state->data.f, + bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch, + bw_cell_scratch, bw_output_gate_scratch, output_ptr_time); } // Backward step. diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc new file mode 100644 index 0000000000000000000000000000000000000000..19942de7bc0c083f192a4b337b224b778d991140 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace cast { +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +template +void copyCast(const FromT* in, ToT* out, int num_elements) { + std::transform(in, in + num_elements, out, + [](FromT a) { return static_cast(a); }); +} + +template +TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, + int num_elements) { + switch (out->type) { + case kTfLiteInt64: + copyCast(in, out->data.i64, num_elements); + break; + case kTfLiteInt32: + copyCast(in, out->data.i32, num_elements); + break; + case kTfLiteUInt8: + copyCast(in, out->data.uint8, num_elements); + break; + case kTfLiteFloat32: + copyCast(in, out->data.f, num_elements); + break; + default: + // Unsupported type. + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const int num_elements = NumElements(input); + TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output)); + switch (input->type) { + case kTfLiteInt64: + return copyToTensor(input->data.i64, output, num_elements); + case kTfLiteInt32: + return copyToTensor(input->data.i32, output, num_elements); + case kTfLiteUInt8: + return copyToTensor(input->data.uint8, output, num_elements); + case kTfLiteFloat32: + return copyToTensor(input->data.f, output, num_elements); + default: + // Unsupported type. + return kTfLiteError; + } + return kTfLiteOk; +} +} // namespace cast + +TfLiteRegistration* Register_CAST() { + static TfLiteRegistration r = {nullptr, nullptr, cast::Prepare, cast::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e56482a371550b6275a6380e2beebe3cef958ff --- /dev/null +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class CastOpModel : public SingleOpModel { + public: + CastOpModel(const TensorData& input, const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions, + CreateCastOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +TEST(CastOpModel, CastIntToFloat) { + CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.PopulateTensor(m.input(), {100, 200, 300, 400, 500, 600}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f})); +} + +TEST(CastOpModel, CastFloatToInt) { + CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}}); + m.PopulateTensor(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({100, 20, 3, 0, 0, 1})); +} + +} // namespace +} // namespace tflite +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index b93a416351cae34b2df8791e382a8a2cd38dcffb..18ff33bf9f55ac1d25bb3392e714686c5305c2b8 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" @@ -43,6 +44,8 @@ namespace conv { enum KernelType { kReference, kGenericOptimized, // Neon-free + // kMultithreadOptimized is a mixture of an Eigen-based kernel when threads + // are available and kGenericOptimized when we must use only one thread. kMultithreadOptimized, // The kernel uses use CBLAS interface for matrix multiplication. // It's fast when an optimized CBLAS implementation is available (e.g. Apple @@ -61,7 +64,7 @@ struct OpData { TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multipler plus a left shift. + // be represented as a fixed point multiplier plus a left shift. int32_t output_multiplier; int output_shift; // The range of the fused activation layer. For example for kNone and @@ -75,6 +78,8 @@ struct OpData { bool need_hwcn_weights; bool have_weights_been_transposed; bool need_im2col; + + bool run_multithreaded_kernel; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -83,10 +88,12 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // to carry information from Prepare() to Eval(). auto* data = new OpData; gemm_support::IncrementUsageCounter(context); + eigen_support::IncrementUsageCounter(context); return data; } void Free(TfLiteContext* context, void* buffer) { + eigen_support::DecrementUsageCounter(context); gemm_support::DecrementUsageCounter(context); delete reinterpret_cast(buffer); } @@ -137,7 +144,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // buffer to store the results. // This path is only used for float processing, so only create the buffer if // we're running with that data type. - data->need_hwcn_weights = (input->type == kTfLiteFloat32); + data->need_hwcn_weights = + (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel); int temporaries_count = 0; if (data->need_im2col) { @@ -165,6 +173,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); + data->run_multithreaded_kernel = context->recommended_num_threads != 1; + TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); bool hasBias = node->inputs->size == 3; @@ -449,8 +459,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // separate ops to avoid dispatch overhead here. switch (input->type) { // Already know in/outtypes are same. case kTfLiteFloat32: - EvalFloat(context, node, params, data, input, filter, bias, - im2col, hwcn_weights, output); + if (data->run_multithreaded_kernel) { + EvalFloat(context, node, params, data, input, filter, bias, + im2col, hwcn_weights, output); + } else { + EvalFloat(context, node, params, data, input, filter, + bias, im2col, hwcn_weights, output); + } break; case kTfLiteUInt8: EvalQuantized(context, node, params, data, input, filter, diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index 15dbfe08c82befcf001b9ed9a053528b5606053e..cad9ce114c8387047af2b63bee704035fd329330 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -52,7 +52,7 @@ enum KernelType { struct OpData { TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multipler plus a left shift. + // be represented as a fixed point multiplier plus a left shift. int32_t output_multiplier; int output_shift; // The range of the fused activation layer. For example for kNone and diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc new file mode 100644 index 0000000000000000000000000000000000000000..e685f2465f627cf30e02564e6f16e1ec69e208e2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace dequantize { + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpContext op_context(context, node); + + TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8); + + op_context.output->type = kTfLiteFloat32; + return context->ResizeTensor(context, op_context.output, + TfLiteIntArrayCopy(op_context.input->dims)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + auto zero_point = op_context.input->params.zero_point; + auto scale = op_context.input->params.scale; + + optimized_ops::Dequantize(GetTensorData(op_context.input), + GetTensorDims(op_context.input), zero_point, scale, + GetTensorData(op_context.output), + GetTensorDims(op_context.output)); + return kTfLiteOk; +} + +} // namespace dequantize + +TfLiteRegistration* Register_DEQUANTIZE_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare, + dequantize::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEQUANTIZE() { return Register_DEQUANTIZE_OPT(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/dequantize_test.cc b/tensorflow/contrib/lite/kernels/dequantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fcd74206177a0a97db168338e3619d4b95c052a9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/dequantize_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class DequantizeOpModel : public SingleOpModel { + public: + DequantizeOpModel(std::initializer_list shape, float min, float max) { + input_ = AddInput({TensorType_UINT8, shape, min, max}); + output_ = AddOutput({TensorType_FLOAT32, shape}); + SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, + CreateDequantizeOptions(builder_).Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(SplitOpTest, FourDimensional) { + DequantizeOpModel m({2, 5}, -63.5, 64); + + m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index c77a0de9b7d0a61dd7c350649f119db3e32c21b9..ec380c8e4956e5bcd0d7559bfd8f89a52d9d233c 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -106,43 +106,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, #undef TF_LITE_DIV } -template -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteDivParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - - int32_t output_multiplier; - int output_shift; - - double real_multiplier = - input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, - &output_shift); - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_DIV(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ - GetTensorDims(output)); - // The quantized version of Div doesn't support activations, so we - // always use BroadcastDiv. - if (kernel_type == kReference) { - TF_LITE_DIV(reference_ops, BroadcastDiv); - } else { - TF_LITE_DIV(optimized_ops, BroadcastDiv); - } -#undef TF_LITE_DIV -} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -155,9 +119,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalFloat(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, data, input1, input2, - output); } else { context->ReportError(context, "Div only supports FLOAT32 and quantized UINT8 now."); diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/contrib/lite/kernels/div_test.cc index e67e0ec034a585d0f2f06c6681d83d11e618ef47..276b8289fbc1b4dcbf4624b76b854300d0fd4912 100644 --- a/tensorflow/contrib/lite/kernels/div_test.cc +++ b/tensorflow/contrib/lite/kernels/div_test.cc @@ -52,23 +52,6 @@ class FloatDivOpModel : public BaseDivOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; -// For quantized Div, the error shouldn't exceed (2*step + step^2). -// The param min=-1.0 & max=1.0 is used in the following tests. -// The tolerance value is ~0.0157. -const float kQuantizedStep = 2.0 / 255.0; -const float kQuantizedTolerance = - 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; - -class QuantizedDivOpModel : public BaseDivOpModel { - public: - using BaseDivOpModel::BaseDivOpModel; - - std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); - } -}; - TEST(FloatDivOpTest, NoActivation) { FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -125,46 +108,6 @@ TEST(FloatDivOpTest, WithBroadcast) { } } -TEST(QuantizedDivOpTest, NoActivation) { - QuantizedDivOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, - {TensorType_UINT8, {}, -1.0, 1.0}, - ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-0.6, 0.2, 0.9, -0.7}); - m.QuantizeAndPopulate(m.input2(), {0.8, 0.4, 0.9, -0.8}); - m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear({-0.75, 0.5, 1.0, 0.875}, - kQuantizedTolerance))); -} - -// for quantized Div, the error shouldn't exceed 2*step -float GetTolerance(int min, int max) { - float kQuantizedStep = (max - min) / 255.0; - float kQuantizedTolerance = 2.0 * kQuantizedStep; - return kQuantizedTolerance; -} - -TEST(QuantizedDivOpTest, WithBroadcast) { - float kQuantizedTolerance = GetTolerance(-3.0, 3.0); - std::vector> test_shapes = { - {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; - for (int i = 0; i < test_shapes.size(); ++i) { - QuantizedDivOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, - {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar - {TensorType_UINT8, {}, -3.0, 3.0}, - ActivationFunctionType_NONE); - m.QuantizeAndPopulate(m.input1(), {-0.2, 0.2, 0.07, - 0.08, 0.11, -0.123}); - m.QuantizeAndPopulate(m.input2(), {0.1}); - m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear( - {-2.0, 2.0, 0.7, 0.8, 1.1, -1.23}, kQuantizedTolerance))) - << "With shape number " << i; - } -} - } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1fdb42624073717fb70423ff70dfad08e578ca6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/eigen_support.h" + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace eigen_support { + +struct RefCountedEigenContext { + int num_references = 0; +}; + +void IncrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->eigen_context); + if (ptr == nullptr) { + if (context->recommended_num_threads != -1) { + Eigen::setNbThreads(context->recommended_num_threads); + } + ptr = new RefCountedEigenContext; + ptr->num_references = 0; + context->eigen_context = ptr; + } + ptr->num_references++; +} + +void DecrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->eigen_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to DecrementUsageCounter() not preceded by " + "IncrementUsageCounter()"); + } + if (--ptr->num_references == 0) { + delete ptr; + context->eigen_context = nullptr; + } +} + +void SetNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + Eigen::setNbThreads(num_threads); + DecrementUsageCounter(context); +} + +} // namespace eigen_support +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h new file mode 100644 index 0000000000000000000000000000000000000000..aa8c351fd8e8dae45f7d4807ce24d80bb393c41c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace eigen_support { + +// Let the framework know that the op will be using Eigen. If necessary a set of +// temporary Eigen objects might be created and placed in 'context'. +void IncrementUsageCounter(TfLiteContext* context); + +// Let the framework know that the op stopped using Eigen. If there are no more +// usages all temporary Eigen objects will be deleted. +void DecrementUsageCounter(TfLiteContext* context); + +// Set the number of threads that can be used by Eigen. +void SetNumThreads(TfLiteContext* context, int num_threads); + +} // namespace eigen_support +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index a77fe94e499078bc2f0660e8e49fd557ed0f625d..888e67966c0a408257e763a405bf6e928310f4d9 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -48,7 +48,7 @@ enum KernelType { struct OpData { // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multipler plus a left shift. + // be represented as a fixed point multiplier plus a left shift. int32_t output_multiplier; int output_shift; // The range of the fused activation layer. For example for kNone and diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc index eb2b0aacf7ecc3ed5dbde5ccce7a46dcda0a93b3..95f45ea768be7f9bae9570563f161792afbff436 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -29,6 +29,9 @@ void IncrementUsageCounter(TfLiteContext* context) { if (ptr == nullptr) { ptr = new RefCountedGemmContext; ptr->gemm_context_ = new gemmlowp::GemmContext(); + if (context->recommended_num_threads != -1) { + ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads); + } ptr->num_references_ = 0; context->gemm_context = ptr; } @@ -58,7 +61,7 @@ gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { return ptr->gemm_context_; } -void SetMaxNumThreads(TfLiteContext* context, int num_threads) { +void SetNumThreads(TfLiteContext* context, int num_threads) { IncrementUsageCounter(context); GetFromContext(context)->set_max_num_threads(num_threads); DecrementUsageCounter(context); diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index 466781cbcecc7fb851d9078c450cc6c12364d2bb..f033501cb6e341aa014fa4d956b531bd79aa555b 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -45,8 +45,8 @@ void IncrementUsageCounter(TfLiteContext* context); // 'context'. If there are no more usages the GemmContext will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the maximum number threads available for gemmlowp operations. -void SetMaxNumThreads(TfLiteContext* context, int num_threads); +// Set the number of threads that can be used by gemmlowp. +void SetNumThreads(TfLiteContext* context, int num_threads); } // namespace gemm_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index f47fb04cbaa688b75e763ff9d3cb7df44ac3f166..aa3957bee133c8b51a82e9c62884ce365e086d2e 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -10,21 +10,25 @@ tflite_deps_intel = [ "@arm_neon_2_x86_sse", ] +HARD_FP_FLAGS_IF_APPLICABLE = select({ + "//tensorflow:android_arm": ["-mfloat-abi=softfp"], + "//tensorflow:android_arm64": ["-mfloat-abi=softfp"], + "//tensorflow:android_armeabi": ["-mfloat-abi=softfp"], + "//conditions:default": [], +}) + NEON_FLAGS_IF_APPLICABLE = select({ ":arm": [ "-O3", "-mfpu=neon", - "-mfloat-abi=softfp", ], ":armeabi-v7a": [ "-O3", "-mfpu=neon", - "-mfloat-abi=softfp", ], ":armv7a": [ "-O3", "-mfpu=neon", - "-mfloat-abi=softfp", ], "//conditions:default": [ "-O3", @@ -145,6 +149,7 @@ cc_library( "common.h", "optimized/depthwiseconv_float.h", "optimized/depthwiseconv_uint8.h", + "optimized/depthwiseconv_uint8_3x3_filter.h", "optimized/optimized_ops.h", ], copts = tflite_copts(), @@ -208,7 +213,10 @@ cc_library( "compatibility.h", "quantization_util.h", ], - deps = [":round"], + deps = [ + ":round", + ":types", + ], ) cc_test( @@ -283,7 +291,7 @@ cc_library( "optimized/neon_tensor_utils.h", "optimized/tensor_utils_impl.h", ], - copts = NEON_FLAGS_IF_APPLICABLE, + copts = NEON_FLAGS_IF_APPLICABLE + HARD_FP_FLAGS_IF_APPLICABLE, deps = [ ":cpu_check", ":portable_tensor_utils", @@ -305,6 +313,27 @@ cc_library( ], ) +# Audio support classes imported directly from TensorFlow. +cc_library( + name = "audio_utils", + srcs = [ + "mfcc.cc", + "mfcc_dct.cc", + "mfcc_mel_filterbank.cc", + "spectrogram.cc", + ], + hdrs = [ + "mfcc.h", + "mfcc_dct.h", + "mfcc_mel_filterbank.h", + "spectrogram.h", + ], + deps = [ + "//third_party/fft2d:fft2d_headers", + "@fft2d", + ], +) + cc_library( name = "tensor_utils", srcs = [ diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 510395126ce3785b1d44fec1e0eb994c29ff0db7..f142374269606bdd3d4184af013749102666ab89 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -40,5 +40,152 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, hidden_state_ptr_batch); } +void LstmStep( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, + float* cell_state_ptr, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, input_gate_scratch, + /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, forget_gate_scratch, + /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, output_gate_scratch, + /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, + output_ptr_batch, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index 9872d4500b862388ed4b96c97e3755f548e35d35..3ec60ee57a87833959a34ba95d32df15bea188a4 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -35,6 +35,42 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, TfLiteFusedActivation activation, float* hidden_state_ptr_batch, float* output_ptr_batch); +// Performs an LSTM batch inference step for input specified by input_ptr_batch. +// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and +// biases (*_bias_ptr), and buffers (*_scratch), along with additional +// parameters: +// - params: various LSTM params including activation, clipping, etc., +// - n_batch: size of batch, +// - n_cell: number of cells (or units), +// - n_input: the input size, +// - n_output: the output size. +// +// The pointers to the cell and output state and the output are updated. Unless +// projection is specified output and output state contain the same data. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many cell and output states. +void LstmStep( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, + float* cell_state_ptr, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc.cc b/tensorflow/contrib/lite/kernels/internal/mfcc.cc new file mode 100644 index 0000000000000000000000000000000000000000..eafe0c7afee6fabd5a4a258aa5176e23f5e8d62a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/mfcc.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/kernels/internal/mfcc.h" + +namespace tflite { +namespace internal { + +const double kDefaultUpperFrequencyLimit = 4000; +const double kDefaultLowerFrequencyLimit = 20; +const double kFilterbankFloor = 1e-12; +const int kDefaultFilterbankChannelCount = 40; +const int kDefaultDCTCoefficientCount = 13; + +Mfcc::Mfcc() + : initialized_(false), + lower_frequency_limit_(kDefaultLowerFrequencyLimit), + upper_frequency_limit_(kDefaultUpperFrequencyLimit), + filterbank_channel_count_(kDefaultFilterbankChannelCount), + dct_coefficient_count_(kDefaultDCTCoefficientCount) {} + +bool Mfcc::Initialize(int input_length, double input_sample_rate) { + bool initialized = mel_filterbank_.Initialize( + input_length, input_sample_rate, filterbank_channel_count_, + lower_frequency_limit_, upper_frequency_limit_); + initialized &= + dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_); + initialized_ = initialized; + return initialized; +} + +void Mfcc::Compute(const std::vector& spectrogram_frame, + std::vector* output) const { + if (!initialized_) { + // LOG(ERROR) << "Mfcc not initialized."; + return; + } + std::vector working; + mel_filterbank_.Compute(spectrogram_frame, &working); + for (int i = 0; i < working.size(); ++i) { + double val = working[i]; + if (val < kFilterbankFloor) { + val = kFilterbankFloor; + } + working[i] = log(val); + } + dct_.Compute(working, output); +} + +} // namespace internal +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc.h b/tensorflow/contrib/lite/kernels/internal/mfcc.h new file mode 100644 index 0000000000000000000000000000000000000000..d8500ecdcf38e5dcfe9eb89915501678455b3dd9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/mfcc.h @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic class for computing MFCCs from spectrogram slices. + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_ + +#include + +#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" +#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" + +namespace tflite { +namespace internal { + +class Mfcc { + public: + Mfcc(); + bool Initialize(int input_length, double input_sample_rate); + + // Input is a single squared-magnitude spectrogram frame. The input spectrum + // is converted to linear magnitude and weighted into bands using a + // triangular mel filterbank, and a discrete cosine transform (DCT) of the + // values is taken. Output is populated with the lowest dct_coefficient_count + // of these values. + void Compute(const std::vector& spectrogram_frame, + std::vector* output) const; + + void set_upper_frequency_limit(double upper_frequency_limit) { + // CHECK(!initialized_) << "Set frequency limits before calling + // Initialize."; + upper_frequency_limit_ = upper_frequency_limit; + } + + void set_lower_frequency_limit(double lower_frequency_limit) { + // CHECK(!initialized_) << "Set frequency limits before calling + // Initialize."; + lower_frequency_limit_ = lower_frequency_limit; + } + + void set_filterbank_channel_count(int filterbank_channel_count) { + /// CHECK(!initialized_) << "Set channel count before calling Initialize."; + filterbank_channel_count_ = filterbank_channel_count; + } + + void set_dct_coefficient_count(int dct_coefficient_count) { + // CHECK(!initialized_) << "Set coefficient count before calling + // Initialize."; + dct_coefficient_count_ = dct_coefficient_count; + } + + private: + MfccMelFilterbank mel_filterbank_; + MfccDct dct_; + bool initialized_; + double lower_frequency_limit_; + double upper_frequency_limit_; + int filterbank_channel_count_; + int dct_coefficient_count_; +}; + +} // namespace internal +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc b/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0b7d181bdcf01688a387f33a3e64fc904324b50 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" + +#include + +namespace tflite { +namespace internal { + +MfccDct::MfccDct() : initialized_(false) {} + +bool MfccDct::Initialize(int input_length, int coefficient_count) { + coefficient_count_ = coefficient_count; + input_length_ = input_length; + + if (coefficient_count_ < 1) { + return false; + } + + if (input_length < 1) { + return false; + } + + if (coefficient_count_ > input_length_) { + return false; + } + + cosines_.resize(coefficient_count_); + double fnorm = sqrt(2.0 / input_length_); + // Some platforms don't have M_PI, so define a local constant here. + const double pi = atan(1) * 4; + double arg = pi / input_length_; + for (int i = 0; i < coefficient_count_; ++i) { + cosines_[i].resize(input_length_); + for (int j = 0; j < input_length_; ++j) { + cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5)); + } + } + initialized_ = true; + return true; +} + +void MfccDct::Compute(const std::vector &input, + std::vector *output) const { + if (!initialized_) { + return; + } + + output->resize(coefficient_count_); + int length = input.size(); + if (length > input_length_) { + length = input_length_; + } + + for (int i = 0; i < coefficient_count_; ++i) { + double sum = 0.0; + for (int j = 0; j < length; ++j) { + sum += cosines_[i][j] * input[j]; + } + (*output)[i] = sum; + } +} + +} // namespace internal +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h b/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h new file mode 100644 index 0000000000000000000000000000000000000000..a53f5cbd9bb70c7c9dd49672681140bb9cbd2f4f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic minimal DCT class for MFCC speech processing. + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ + +#include + +namespace tflite { +namespace internal { + +class MfccDct { + public: + MfccDct(); + bool Initialize(int input_length, int coefficient_count); + void Compute(const std::vector& input, + std::vector* output) const; + + private: + bool initialized_; + int coefficient_count_; + int input_length_; + std::vector > cosines_; +}; + +} // namespace internal +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc b/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3deb33d91a47bfe54b7c84d2a615df2422f90cc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc @@ -0,0 +1,204 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This code resamples the FFT bins, and smooths then with triangle-shaped +// weights to create a mel-frequency filter bank. For filter i centered at f_i, +// there is a triangular weighting of the FFT bins that extends from +// filter f_i-1 (with a value of zero at the left edge of the triangle) to f_i +// (where the filter value is 1) to f_i+1 (where the filter values returns to +// zero). + +// Note: this code fails if you ask for too many channels. The algorithm used +// here assumes that each FFT bin contributes to at most two channels: the +// right side of a triangle for channel i, and the left side of the triangle +// for channel i+1. If you ask for so many channels that some of the +// resulting mel triangle filters are smaller than a single FFT bin, these +// channels may end up with no contributing FFT bins. The resulting mel +// spectrum output will have some channels that are always zero. + +#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" + +#include + +namespace tflite { +namespace internal { + +MfccMelFilterbank::MfccMelFilterbank() : initialized_(false) {} + +bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate, + int output_channel_count, + double lower_frequency_limit, + double upper_frequency_limit) { + num_channels_ = output_channel_count; + sample_rate_ = input_sample_rate; + input_length_ = input_length; + + if (num_channels_ < 1) { + // LOG(ERROR) << "Number of filterbank channels must be positive."; + return false; + } + + if (sample_rate_ <= 0) { + // LOG(ERROR) << "Sample rate must be positive."; + return false; + } + + if (input_length < 2) { + // LOG(ERROR) << "Input length must greater than 1."; + return false; + } + + if (lower_frequency_limit < 0) { + // LOG(ERROR) << "Lower frequency limit must be nonnegative."; + return false; + } + + if (upper_frequency_limit <= lower_frequency_limit) { + /// LOG(ERROR) << "Upper frequency limit must be greater than " + // << "lower frequency limit."; + return false; + } + + // An extra center frequency is computed at the top to get the upper + // limit on the high side of the final triangular filter. + center_frequencies_.resize(num_channels_ + 1); + const double mel_low = FreqToMel(lower_frequency_limit); + const double mel_hi = FreqToMel(upper_frequency_limit); + const double mel_span = mel_hi - mel_low; + const double mel_spacing = mel_span / static_cast(num_channels_ + 1); + for (int i = 0; i < num_channels_ + 1; ++i) { + center_frequencies_[i] = mel_low + (mel_spacing * (i + 1)); + } + + // Always exclude DC; emulate HTK. + const double hz_per_sbin = + 0.5 * sample_rate_ / static_cast(input_length_ - 1); + start_index_ = static_cast(1.5 + (lower_frequency_limit / hz_per_sbin)); + end_index_ = static_cast(upper_frequency_limit / hz_per_sbin); + + // Maps the input spectrum bin indices to filter bank channels/indices. For + // each FFT bin, band_mapper tells us which channel this bin contributes to + // on the right side of the triangle. Thus this bin also contributes to the + // left side of the next channel's triangle response. + band_mapper_.resize(input_length_); + int channel = 0; + for (int i = 0; i < input_length_; ++i) { + double melf = FreqToMel(i * hz_per_sbin); + if ((i < start_index_) || (i > end_index_)) { + band_mapper_[i] = -2; // Indicate an unused Fourier coefficient. + } else { + while ((center_frequencies_[channel] < melf) && + (channel < num_channels_)) { + ++channel; + } + band_mapper_[i] = channel - 1; // Can be == -1 + } + } + + // Create the weighting functions to taper the band edges. The contribution + // of any one FFT bin is based on its distance along the continuum between two + // mel-channel center frequencies. This bin contributes weights_[i] to the + // current channel and 1-weights_[i] to the next channel. + weights_.resize(input_length_); + for (int i = 0; i < input_length_; ++i) { + channel = band_mapper_[i]; + if ((i < start_index_) || (i > end_index_)) { + weights_[i] = 0.0; + } else { + if (channel >= 0) { + weights_[i] = + (center_frequencies_[channel + 1] - FreqToMel(i * hz_per_sbin)) / + (center_frequencies_[channel + 1] - center_frequencies_[channel]); + } else { + weights_[i] = (center_frequencies_[0] - FreqToMel(i * hz_per_sbin)) / + (center_frequencies_[0] - mel_low); + } + } + } + // Check the sum of FFT bin weights for every mel band to identify + // situations where the mel bands are so narrow that they don't get + // significant weight on enough (or any) FFT bins -- i.e., too many + // mel bands have been requested for the given FFT size. + std::vector bad_channels; + for (int c = 0; c < num_channels_; ++c) { + float band_weights_sum = 0.0; + for (int i = 0; i < input_length_; ++i) { + if (band_mapper_[i] == c - 1) { + band_weights_sum += (1.0 - weights_[i]); + } else if (band_mapper_[i] == c) { + band_weights_sum += weights_[i]; + } + } + // The lowest mel channels have the fewest FFT bins and the lowest + // weights sum. But given that the target gain at the center frequency + // is 1.0, if the total sum of weights is 0.5, we're in bad shape. + if (band_weights_sum < 0.5) { + bad_channels.push_back(c); + } + } + if (!bad_channels.empty()) { + /* + LOG(ERROR) << "Missing " << bad_channels.size() << " bands " + << " starting at " << bad_channels[0] + << " in mel-frequency design. " + << "Perhaps too many channels or " + << "not enough frequency resolution in spectrum. (" + << "input_length: " << input_length + << " input_sample_rate: " << input_sample_rate + << " output_channel_count: " << output_channel_count + << " lower_frequency_limit: " << lower_frequency_limit + << " upper_frequency_limit: " << upper_frequency_limit; + */ + } + initialized_ = true; + return true; +} + +// Compute the mel spectrum from the squared-magnitude FFT input by taking the +// square root, then summing FFT magnitudes under triangular integration windows +// whose widths increase with frequency. +void MfccMelFilterbank::Compute(const std::vector &input, + std::vector *output) const { + if (!initialized_) { + // LOG(ERROR) << "Mel Filterbank not initialized."; + return; + } + + if (input.size() <= end_index_) { + // LOG(ERROR) << "Input too short to compute filterbank"; + return; + } + + // Ensure output is right length and reset all values. + output->assign(num_channels_, 0.0); + + for (int i = start_index_; i <= end_index_; i++) { // For each FFT bin + double spec_val = sqrt(input[i]); + double weighted = spec_val * weights_[i]; + int channel = band_mapper_[i]; + if (channel >= 0) + (*output)[channel] += weighted; // Right side of triangle, downward slope + channel++; + if (channel < num_channels_) + (*output)[channel] += spec_val - weighted; // Left side of triangle + } +} + +double MfccMelFilterbank::FreqToMel(double freq) const { + return 1127.0 * log(1.0 + (freq / 700.0)); +} + +} // namespace internal +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h b/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h new file mode 100644 index 0000000000000000000000000000000000000000..c1db28243eea39a694b7613ac7144dce9b294897 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic class for applying a mel-scale mapping to a power spectrum. + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ + +#include + +namespace tflite { +namespace internal { + +class MfccMelFilterbank { + public: + MfccMelFilterbank(); + bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1. + double input_sample_rate, int output_channel_count, + double lower_frequency_limit, double upper_frequency_limit); + + // Takes a squared-magnitude spectrogram slice as input, computes a + // triangular-mel-weighted linear-magnitude filterbank, and places the result + // in output. + void Compute(const std::vector& input, + std::vector* output) const; + + private: + double FreqToMel(double freq) const; + bool initialized_; + int num_channels_; + double sample_rate_; + int input_length_; + std::vector center_frequencies_; // In mel, for each mel channel. + + // Each FFT bin b contributes to two triangular mel channels, with + // proportion weights_[b] going into mel channel band_mapper_[b], and + // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1. + // Thus, weights_ contains the weighting applied to each FFT bin for the + // upper-half of the triangular band. + std::vector weights_; // Right-side weight for this fft bin. + + // FFT bin i contributes to the upper side of mel channel band_mapper_[i] + std::vector band_mapper_; + int start_index_; // Lowest FFT bin used to calculate mel spectrum. + int end_index_; // Highest FFT bin used to calculate mel spectrum. +}; + +} // namespace internal +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index dbc4f0d6fdca8279072d6ea225334722d6a89eb2..c71b070680ead77769dd8b04d0d7a133ad694abc 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -18,6 +18,7 @@ limitations under the License. #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { @@ -1692,6 +1693,23 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, const int output_width = ArraySize(output_dims, 1); TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); +#ifdef __aarch64__ + // Call kernel optimized for depthwise convolutions using 3x3 filters, + // stride = 1, no padding, depth_multiplier = 1 and depth a multiple of 16. + if (filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && + (stride_width == 1 || stride_width == 2) && + (stride_height == 1 || stride_height == 2) && pad_width == 0 && + pad_height == 0 && (input_depth % 16) == 0) { + DepthwiseConv3by3FilterDepth16( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, depth_multiplier, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); + return; + } +#endif + static const int kAccBufferMaxSize = 2048; int32 acc_buffer[kAccBufferMaxSize]; TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h new file mode 100644 index 0000000000000000000000000000000000000000..9dc76e7608f170fcf21bb188226bf30995df8cda --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -0,0 +1,706 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +#ifdef __aarch64__ + +inline void preload_l1_keep(const uint8* ptr) { +#ifdef GEMMLOWP_ARM_64 + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); +#else + gemmlowp::Prefetch(ptr); +#endif +} + +// Implementation of quantized DepthwiseConv for 3x3 filters. + +// Below are helper structs to remove the use of arrays. +// There is an llvm bug that causes significant slowdown when using arrays for +// NEON intrinsics vector data types. +// See: https://bugs.llvm.org/show_bug.cgi?id=34945 + +struct Int32x16 { + int32x4_t v0, v1, v2, v3; +}; + +struct Int16x16 { + int16x8_t low, high; +}; + +struct Int16x16x3 { + Int16x16 v0, v1, v2; +}; + +struct Filter3x3x16 { + Int16x16x3 r0, r1, r2; +}; + +// Loads 3x3 filter of depth 16 and adds filter offsets. +inline Filter3x3x16 LoadFilterDepth16(const uint8* filter_ptr, + int32 filter_offset, int output_depth) { + Filter3x3x16 filter; + + uint8x8_t temp_u8_0, temp_u8_1, temp_u8_2, temp_u8_3, temp_u8_4, temp_u8_5, + temp_u8_6, temp_u8_7, temp_u8_8, temp_u8_9, temp_u8_10, temp_u8_11, + temp_u8_12, temp_u8_13, temp_u8_14, temp_u8_15, temp_u8_16, temp_u8_17; + int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); + + temp_u8_0 = vld1_u8(filter_ptr + 0 * output_depth); + temp_u8_1 = vld1_u8(filter_ptr + 0 * output_depth + 8); + temp_u8_2 = vld1_u8(filter_ptr + 1 * output_depth); + temp_u8_3 = vld1_u8(filter_ptr + 1 * output_depth + 8); + temp_u8_4 = vld1_u8(filter_ptr + 2 * output_depth); + temp_u8_5 = vld1_u8(filter_ptr + 2 * output_depth + 8); + + temp_u8_6 = vld1_u8(filter_ptr + 3 * output_depth); + temp_u8_7 = vld1_u8(filter_ptr + 3 * output_depth + 8); + temp_u8_8 = vld1_u8(filter_ptr + 4 * output_depth); + temp_u8_9 = vld1_u8(filter_ptr + 4 * output_depth + 8); + temp_u8_10 = vld1_u8(filter_ptr + 5 * output_depth); + temp_u8_11 = vld1_u8(filter_ptr + 5 * output_depth + 8); + + temp_u8_12 = vld1_u8(filter_ptr + 6 * output_depth); + temp_u8_13 = vld1_u8(filter_ptr + 6 * output_depth + 8); + temp_u8_14 = vld1_u8(filter_ptr + 7 * output_depth); + temp_u8_15 = vld1_u8(filter_ptr + 7 * output_depth + 8); + temp_u8_16 = vld1_u8(filter_ptr + 8 * output_depth); + temp_u8_17 = vld1_u8(filter_ptr + 8 * output_depth + 8); + + filter.r0.v0.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_0)); + filter.r0.v0.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_1)); + filter.r0.v1.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_2)); + filter.r0.v1.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_3)); + filter.r0.v2.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_4)); + filter.r0.v2.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_5)); + + filter.r1.v0.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_6)); + filter.r1.v0.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_7)); + filter.r1.v1.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_8)); + filter.r1.v1.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_9)); + filter.r1.v2.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_10)); + filter.r1.v2.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_11)); + + filter.r2.v0.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_12)); + filter.r2.v0.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_13)); + filter.r2.v1.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_14)); + filter.r2.v1.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_15)); + filter.r2.v2.low = vreinterpretq_s16_u16(vmovl_u8(temp_u8_16)); + filter.r2.v2.high = vreinterpretq_s16_u16(vmovl_u8(temp_u8_17)); + + filter.r0.v0.low = vaddq_s16(filter.r0.v0.low, filter_offset_vec); + filter.r0.v0.high = vaddq_s16(filter.r0.v0.high, filter_offset_vec); + filter.r0.v1.low = vaddq_s16(filter.r0.v1.low, filter_offset_vec); + filter.r0.v1.high = vaddq_s16(filter.r0.v1.high, filter_offset_vec); + filter.r0.v2.low = vaddq_s16(filter.r0.v2.low, filter_offset_vec); + filter.r0.v2.high = vaddq_s16(filter.r0.v2.high, filter_offset_vec); + + filter.r1.v0.low = vaddq_s16(filter.r1.v0.low, filter_offset_vec); + filter.r1.v0.high = vaddq_s16(filter.r1.v0.high, filter_offset_vec); + filter.r1.v1.low = vaddq_s16(filter.r1.v1.low, filter_offset_vec); + filter.r1.v1.high = vaddq_s16(filter.r1.v1.high, filter_offset_vec); + filter.r1.v2.low = vaddq_s16(filter.r1.v2.low, filter_offset_vec); + filter.r1.v2.high = vaddq_s16(filter.r1.v2.high, filter_offset_vec); + + filter.r2.v0.low = vaddq_s16(filter.r2.v0.low, filter_offset_vec); + filter.r2.v0.high = vaddq_s16(filter.r2.v0.high, filter_offset_vec); + filter.r2.v1.low = vaddq_s16(filter.r2.v1.low, filter_offset_vec); + filter.r2.v1.high = vaddq_s16(filter.r2.v1.high, filter_offset_vec); + filter.r2.v2.low = vaddq_s16(filter.r2.v2.low, filter_offset_vec); + filter.r2.v2.high = vaddq_s16(filter.r2.v2.high, filter_offset_vec); + + return filter; +} + +// Loads 3 input cells of depth 16 and adds input offsets. +inline Int16x16x3 LoadInputRowDepth16(const uint8* ptr, int input_depth, + int32 input_offset, + Int16x16x3 input_row) { + uint8x8_t temp_0, temp_1; + int16x8_t offset_vec = vdupq_n_s16(input_offset); + + temp_0 = vld1_u8(ptr + 0 * input_depth); + temp_1 = vld1_u8(ptr + 0 * input_depth + 8); + input_row.v0.low = vreinterpretq_s16_u16(vmovl_u8(temp_0)); + input_row.v0.high = vreinterpretq_s16_u16(vmovl_u8(temp_1)); + input_row.v0.low = vaddq_s16(input_row.v0.low, offset_vec); + input_row.v0.high = vaddq_s16(input_row.v0.high, offset_vec); + + temp_0 = vld1_u8(ptr + 1 * input_depth); + temp_1 = vld1_u8(ptr + 1 * input_depth + 8); + input_row.v1.low = vreinterpretq_s16_u16(vmovl_u8(temp_0)); + input_row.v1.high = vreinterpretq_s16_u16(vmovl_u8(temp_1)); + input_row.v1.low = vaddq_s16(input_row.v1.low, offset_vec); + input_row.v1.high = vaddq_s16(input_row.v1.high, offset_vec); + + temp_0 = vld1_u8(ptr + 2 * input_depth); + temp_1 = vld1_u8(ptr + 2 * input_depth + 8); + input_row.v2.low = vreinterpretq_s16_u16(vmovl_u8(temp_0)); + input_row.v2.high = vreinterpretq_s16_u16(vmovl_u8(temp_1)); + input_row.v2.low = vaddq_s16(input_row.v2.low, offset_vec); + input_row.v2.high = vaddq_s16(input_row.v2.high, offset_vec); + + return input_row; +} + +// Performs multiply accumulate on 3 inputs of depth 16. +inline Int32x16 MultiplyAccumulateRowDepth16(Int32x16 output, + const Int16x16x3& filter_row, + const Int16x16x3& input_row) { + output.v0 = vmlal_s16(output.v0, vget_low_s16(filter_row.v0.low), + vget_low_s16(input_row.v0.low)); + output.v1 = vmlal_s16(output.v1, vget_high_s16(filter_row.v0.low), + vget_high_s16(input_row.v0.low)); + output.v2 = vmlal_s16(output.v2, vget_low_s16(filter_row.v0.high), + vget_low_s16(input_row.v0.high)); + output.v3 = vmlal_s16(output.v3, vget_high_s16(filter_row.v0.high), + vget_high_s16(input_row.v0.high)); + + output.v0 = vmlal_s16(output.v0, vget_low_s16(filter_row.v1.low), + vget_low_s16(input_row.v1.low)); + output.v1 = vmlal_s16(output.v1, vget_high_s16(filter_row.v1.low), + vget_high_s16(input_row.v1.low)); + output.v2 = vmlal_s16(output.v2, vget_low_s16(filter_row.v1.high), + vget_low_s16(input_row.v1.high)); + output.v3 = vmlal_s16(output.v3, vget_high_s16(filter_row.v1.high), + vget_high_s16(input_row.v1.high)); + + output.v0 = vmlal_s16(output.v0, vget_low_s16(filter_row.v2.low), + vget_low_s16(input_row.v2.low)); + output.v1 = vmlal_s16(output.v1, vget_high_s16(filter_row.v2.low), + vget_high_s16(input_row.v2.low)); + output.v2 = vmlal_s16(output.v2, vget_low_s16(filter_row.v2.high), + vget_low_s16(input_row.v2.high)); + output.v3 = vmlal_s16(output.v3, vget_high_s16(filter_row.v2.high), + vget_high_s16(input_row.v2.high)); + + return output; +} + +// Applies activation, offset and downquantize on a set of accumulator +// registers of depth 16. Stores results to output. +inline void DownquantizeAndStoreDepth16(Int32x16 acc, int32 output_multiplier, + int output_shift, + int32x4_t output_offset_vec, + int32x4_t output_activation_min_vec, + int32x4_t output_activation_max_vec, + uint8* output_ptr) { + // Fixed-point multiplication. + acc.v0 = vqrdmulhq_n_s32(acc.v0, output_multiplier); + acc.v1 = vqrdmulhq_n_s32(acc.v1, output_multiplier); + acc.v2 = vqrdmulhq_n_s32(acc.v2, output_multiplier); + acc.v3 = vqrdmulhq_n_s32(acc.v3, output_multiplier); + + using gemmlowp::RoundingDivideByPOT; + acc.v0 = RoundingDivideByPOT(acc.v0, output_shift); + acc.v1 = RoundingDivideByPOT(acc.v1, output_shift); + acc.v2 = RoundingDivideByPOT(acc.v2, output_shift); + acc.v3 = RoundingDivideByPOT(acc.v3, output_shift); + + // Add the output offset. + acc.v0 = vaddq_s32(acc.v0, output_offset_vec); + acc.v1 = vaddq_s32(acc.v1, output_offset_vec); + acc.v2 = vaddq_s32(acc.v2, output_offset_vec); + acc.v3 = vaddq_s32(acc.v3, output_offset_vec); + + // Apply the activation function. + acc.v0 = vmaxq_s32(acc.v0, output_activation_min_vec); + acc.v1 = vmaxq_s32(acc.v1, output_activation_min_vec); + acc.v2 = vmaxq_s32(acc.v2, output_activation_min_vec); + acc.v3 = vmaxq_s32(acc.v3, output_activation_min_vec); + + acc.v0 = vminq_s32(acc.v0, output_activation_max_vec); + acc.v1 = vminq_s32(acc.v1, output_activation_max_vec); + acc.v2 = vminq_s32(acc.v2, output_activation_max_vec); + acc.v3 = vminq_s32(acc.v3, output_activation_max_vec); + + // Saturating cast to uint8 and store to destination. + int16x4_t acc_tlla_s16 = vqmovn_s32(acc.v0); + int16x4_t acc_tllb_s16 = vqmovn_s32(acc.v1); + int16x4_t acc_tlha_s16 = vqmovn_s32(acc.v2); + int16x4_t acc_tlhb_s16 = vqmovn_s32(acc.v3); + + int16x8_t res_s16_0 = vcombine_s16(acc_tlla_s16, acc_tllb_s16); + int16x8_t res_s16_1 = vcombine_s16(acc_tlha_s16, acc_tlhb_s16); + uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0); + uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1); + vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1)); +} + +// A kernel that is optimized on the number of output cells in the x and y +// direction, and the stride. Assumes 3x3 filters of 16 depth. +template +struct ConvKernel3x3FilterDepth16 {}; + +template <> +struct ConvKernel3x3FilterDepth16<1, 2, 1> { + static void Run(const Filter3x3x16& filter, const uint8* input_ptr, + int input_depth, int32 input_offset, int input_row_width, + const int32* bias_ptr, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_ptr, int output_depth, int output_width) { + // 16 depth accumulators for the 2 outputs. + Int32x16 acc0, acc1; + + // Accumulators for top filter. + acc0.v0 = vld1q_s32(bias_ptr); + acc0.v1 = vld1q_s32(bias_ptr + 4); + acc0.v2 = vld1q_s32(bias_ptr + 8); + acc0.v3 = vld1q_s32(bias_ptr + 12); + // Accumulators for bottom filter. + acc1.v0 = vld1q_s32(bias_ptr); + acc1.v1 = vld1q_s32(bias_ptr + 4); + acc1.v2 = vld1q_s32(bias_ptr + 8); + acc1.v3 = vld1q_s32(bias_ptr + 12); + + // Main multiply accumulate work. + { + // Load inputs for one filter row at a time. + Int16x16x3 input; + + // Do first row of top filter. + input = LoadInputRowDepth16(input_ptr, input_depth, input_offset, input); + acc0 = MultiplyAccumulateRowDepth16(acc0, filter.r0, input); + + // Do second row of top filter. + input = LoadInputRowDepth16(input_ptr + input_row_width, input_depth, + input_offset, input); + acc0 = MultiplyAccumulateRowDepth16(acc0, filter.r1, input); + + // The inputs to second row of the top filter are also the inputs to the + // first row of the bottom filter. + acc1 = MultiplyAccumulateRowDepth16(acc1, filter.r0, input); + + // Do third row of top filter. + input = LoadInputRowDepth16(input_ptr + 2 * input_row_width, input_depth, + input_offset, input); + acc0 = MultiplyAccumulateRowDepth16(acc0, filter.r2, input); + + // The inputs to third row of the top filter are also the inputs to the + // second row of the bottom filter. + acc1 = MultiplyAccumulateRowDepth16(acc1, filter.r1, input); + + // Do third row of bottom filter. + input = LoadInputRowDepth16(input_ptr + 3 * input_row_width, input_depth, + input_offset, input); + acc1 = MultiplyAccumulateRowDepth16(acc1, filter.r2, input); + } + + // Apply activation, downquantize and store. + int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + int32x4_t output_activation_min_vec = vdupq_n_s32(output_activation_min); + int32x4_t output_activation_max_vec = vdupq_n_s32(output_activation_max); + + DownquantizeAndStoreDepth16(acc0, output_multiplier, output_shift, + output_offset_vec, output_activation_min_vec, + output_activation_max_vec, output_ptr); + + DownquantizeAndStoreDepth16(acc1, output_multiplier, output_shift, + output_offset_vec, output_activation_min_vec, + output_activation_max_vec, + output_ptr + output_depth * output_width); + } +}; + +template <> +struct ConvKernel3x3FilterDepth16<1, 2, 2> { + static void Run(const Filter3x3x16& filter, const uint8* input_ptr, + int input_depth, int32 input_offset, int input_row_width, + const int32* bias_ptr, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_ptr, int output_depth, int output_width) { + // 16 depth accumulators for the 2 outputs. + Int32x16 acc0, acc1; + + // Accumulators for top filter. + acc0.v0 = vld1q_s32(bias_ptr); + acc0.v1 = vld1q_s32(bias_ptr + 4); + acc0.v2 = vld1q_s32(bias_ptr + 8); + acc0.v3 = vld1q_s32(bias_ptr + 12); + // Accumulators for bottom filter. + acc1.v0 = vld1q_s32(bias_ptr); + acc1.v1 = vld1q_s32(bias_ptr + 4); + acc1.v2 = vld1q_s32(bias_ptr + 8); + acc1.v3 = vld1q_s32(bias_ptr + 12); + + // Main multiply accumulate work. + { + // Load inputs for one filter row at a time. + Int16x16x3 input; + + // Do first row of top filter. + input = LoadInputRowDepth16(input_ptr, input_depth, input_offset, input); + acc0 = MultiplyAccumulateRowDepth16(acc0, filter.r0, input); + + // Do second row of top filter. + input = LoadInputRowDepth16(input_ptr + input_row_width, input_depth, + input_offset, input); + acc0 = MultiplyAccumulateRowDepth16(acc0, filter.r1, input); + + // Do third row of top filter. + input = LoadInputRowDepth16(input_ptr + 2 * input_row_width, input_depth, + input_offset, input); + acc0 = MultiplyAccumulateRowDepth16(acc0, filter.r2, input); + + // The inputs to third row of the top filter are also the inputs + // to first row of the bottom filter. + acc1 = MultiplyAccumulateRowDepth16(acc1, filter.r0, input); + + // Do second row of bottom filter. + input = LoadInputRowDepth16(input_ptr + 3 * input_row_width, input_depth, + input_offset, input); + acc1 = MultiplyAccumulateRowDepth16(acc1, filter.r1, input); + + // Do third row of bottom filter. + input = LoadInputRowDepth16(input_ptr + 4 * input_row_width, input_depth, + input_offset, input); + acc1 = MultiplyAccumulateRowDepth16(acc1, filter.r2, input); + } + + // Apply activation, downquantize and store. + int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + int32x4_t output_activation_min_vec = vdupq_n_s32(output_activation_min); + int32x4_t output_activation_max_vec = vdupq_n_s32(output_activation_max); + + DownquantizeAndStoreDepth16(acc0, output_multiplier, output_shift, + output_offset_vec, output_activation_min_vec, + output_activation_max_vec, output_ptr); + + DownquantizeAndStoreDepth16(acc1, output_multiplier, output_shift, + output_offset_vec, output_activation_min_vec, + output_activation_max_vec, + output_ptr + output_depth * output_width); + } +}; + +template <> +struct ConvKernel3x3FilterDepth16<1, 1> { + static void Run(const Filter3x3x16& filter, const uint8* input_ptr, + int input_depth, int32 input_offset, int input_row_width, + const int32* bias_ptr, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_ptr, int output_depth, int output_width) { + Int32x16 acc; + acc.v0 = vld1q_s32(bias_ptr); + acc.v1 = vld1q_s32(bias_ptr + 4); + acc.v2 = vld1q_s32(bias_ptr + 8); + acc.v3 = vld1q_s32(bias_ptr + 12); + + // Main multiply accumulate work. + { + // Load inputs for one filter row at a time. + Int16x16x3 input; + + // Do first row. + input = LoadInputRowDepth16(input_ptr, input_depth, input_offset, input); + acc = MultiplyAccumulateRowDepth16(acc, filter.r0, input); + + // Do second row. + input = LoadInputRowDepth16(input_ptr + input_row_width, input_depth, + input_offset, input); + acc = MultiplyAccumulateRowDepth16(acc, filter.r1, input); + + // Do third row. + input = LoadInputRowDepth16(input_ptr + 2 * input_row_width, input_depth, + input_offset, input); + acc = MultiplyAccumulateRowDepth16(acc, filter.r2, input); + } + + // Apply activation, downquantize and store. + int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + int32x4_t output_activation_min_vec = vdupq_n_s32(output_activation_min); + int32x4_t output_activation_max_vec = vdupq_n_s32(output_activation_max); + + DownquantizeAndStoreDepth16(acc, output_multiplier, output_shift, + output_offset_vec, output_activation_min_vec, + output_activation_max_vec, output_ptr); + } +}; + +inline void DepthwiseConv3by3FilterDepth16( + const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, + const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + + // Algorithm assumes below constraints. It is optimized for depth multiplier + // of 1, 3x3 filter, no padding, strides 1 and 2. + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + TFLITE_DCHECK(depth_multiplier == 1); + TFLITE_DCHECK(filter_height == 3); + TFLITE_DCHECK(filter_width == 3); + TFLITE_DCHECK(pad_height == 0); + TFLITE_DCHECK(pad_width == 0); + TFLITE_DCHECK(stride_width == 1 || stride_width == 2); + TFLITE_DCHECK(stride_height == 1 || stride_height == 2); + + // The number of outputs to process in the main loop. + const int num_x_outputs = 1; + const int num_y_outputs = 2; + + const int input_row_width = output_depth * (input_width + 2 * pad_width); + const int input_batch_size = + input_row_width * (input_height + 2 * pad_height); + const int output_batch_size = output_depth * output_width * output_height; + const int input_ptr_x_increment = input_depth * stride_width; + + // Calculate extents of non-boundary loop. + int out_x_start = 0; + for (; out_x_start < input_width; out_x_start++) { + int in_x = (out_x_start * stride_width) - pad_width; + if (in_x >= 0) { + break; + } + } + int out_x_end = output_width - 1; + for (; out_x_end >= 0; out_x_end--) { + int in_x = (out_x_end * stride_width) - pad_width; + int in_x_end = in_x + filter_width + (num_x_outputs - 1) * stride_width; + if (in_x_end <= input_width) { + out_x_end++; + break; + } + } + int out_y_start = 0; + for (; out_y_start < input_height; out_y_start++) { + int in_y = (out_y_start * stride_height) - pad_height; + if (in_y >= 0) { + break; + } + } + int out_y_end = output_height - 1; + for (; out_y_end >= 0; out_y_end--) { + int in_y = (out_y_end * stride_height) - pad_height; + int in_y_end = in_y + filter_height + (num_y_outputs - 1) * stride_height; + if (in_y_end <= input_height) { + out_y_end++; + break; + } + } + + using dot_product_func_t = + decltype(&ConvKernel3x3FilterDepth16<1, 2, 1>::Run); + dot_product_func_t dot_product_func = nullptr; + + if (stride_width == 1 && stride_height == 1) { + dot_product_func = ConvKernel3x3FilterDepth16<1, 2, 1>::Run; + } else { + dot_product_func = ConvKernel3x3FilterDepth16<1, 2, 2>::Run; + } + + // Offsets for preloading inputs. + const int i0 = 0; + const int i1 = input_depth; + const int i2 = 2 * input_depth; + const int i3 = input_row_width; + const int i4 = input_row_width + input_depth; + const int i5 = input_row_width + 2 * input_depth; + const int i6 = 2 * input_row_width; + const int i7 = 2 * input_row_width + input_depth; + const int i8 = 2 * input_row_width + 2 * input_depth; + const int i9 = 3 * input_row_width; + const int i10 = 3 * input_row_width + input_depth; + const int i11 = 3 * input_row_width + 2 * input_depth; + const int i12 = 4 * input_row_width; + const int i13 = 4 * input_row_width + input_depth; + const int i14 = 4 * input_row_width + 2 * input_depth; + + for (int b = 0; b < batches; ++b) { + const int32* bias_ptr = bias_data; + const uint8* filter_ptr = filter_data; + + const int in_batch_offset = b * input_batch_size; + const int out_batch_offset = b * output_batch_size; + + int depth = 0; + for (; depth <= output_depth - 16; depth += 16) { + Filter3x3x16 filter = + LoadFilterDepth16(filter_ptr, filter_offset, output_depth); + + // Handle 1x2 outputs. + int out_y = out_y_start; + for (; out_y < out_y_end; out_y += num_y_outputs) { + int out_x = out_x_start; + + int in_y_offset = + stride_height * input_row_width * (out_y + pad_height); + int in_x_offset = stride_width * input_depth * (out_x + pad_width); + + const uint8* input_ptr = + input_data + depth + in_x_offset + in_y_offset + in_batch_offset; + + // Preload inputs. If input depth is large, preload every value of the + // input for this depth range. Otherwise, preload only the first values + // of each row. + if (input_depth >= 32) { + preload_l1_keep(input_ptr + i0); + preload_l1_keep(input_ptr + i1); + preload_l1_keep(input_ptr + i2); + preload_l1_keep(input_ptr + i3); + preload_l1_keep(input_ptr + i4); + preload_l1_keep(input_ptr + i5); + preload_l1_keep(input_ptr + i6); + preload_l1_keep(input_ptr + i7); + preload_l1_keep(input_ptr + i8); + preload_l1_keep(input_ptr + i9); + preload_l1_keep(input_ptr + i10); + preload_l1_keep(input_ptr + i11); + + if (stride_height == 2) { + preload_l1_keep(input_ptr + i12); + preload_l1_keep(input_ptr + i13); + preload_l1_keep(input_ptr + i14); + } + } else { + preload_l1_keep(input_ptr + i0); + preload_l1_keep(input_ptr + i3); + preload_l1_keep(input_ptr + i6); + preload_l1_keep(input_ptr + i9); + + if (stride_height == 2) { + preload_l1_keep(input_ptr + i12); + } + } + + uint8* output_ptr = output_data + depth + (out_x * output_depth) + + (output_depth * output_width * out_y) + + out_batch_offset; + + for (; out_x < out_x_end; out_x += num_x_outputs) { + dot_product_func(filter, input_ptr, input_depth, input_offset, + input_row_width, bias_ptr, output_offset, + output_multiplier, output_shift, + output_activation_min, output_activation_max, + output_ptr, output_depth, output_width); + + input_ptr += input_ptr_x_increment * num_x_outputs; + output_ptr += output_depth * num_x_outputs; + + // Preload the next inputs depending on stride. + if (stride_width == 1) { + preload_l1_keep(input_ptr + i2); + preload_l1_keep(input_ptr + i5); + preload_l1_keep(input_ptr + i8); + preload_l1_keep(input_ptr + i11); + } else if (stride_width == 2) { + preload_l1_keep(input_ptr + i1); + preload_l1_keep(input_ptr + i2); + preload_l1_keep(input_ptr + i4); + preload_l1_keep(input_ptr + i5); + preload_l1_keep(input_ptr + i7); + preload_l1_keep(input_ptr + i8); + preload_l1_keep(input_ptr + i10); + preload_l1_keep(input_ptr + i11); + preload_l1_keep(input_ptr + i13); + preload_l1_keep(input_ptr + i14); + } + } + + // Handle the rest of the right side. + for (; out_x < output_width; out_x++) { + // This code path can only be reached if we're handling >1 x outputs + // at a time or support padding. + } + } + + // Handle the rest of the bottom side. + for (; out_y < output_height; out_y++) { + int out_x = out_x_start; + + int in_y_offset = + stride_height * input_row_width * (out_y + pad_height); + int in_x_offset = stride_width * input_depth * (out_x + pad_width); + + const uint8* input_ptr = + input_data + depth + in_x_offset + in_y_offset + in_batch_offset; + + if (input_depth >= 32) { + preload_l1_keep(input_ptr + i0); + preload_l1_keep(input_ptr + i1); + preload_l1_keep(input_ptr + i2); + preload_l1_keep(input_ptr + i3); + preload_l1_keep(input_ptr + i4); + preload_l1_keep(input_ptr + i5); + preload_l1_keep(input_ptr + i6); + preload_l1_keep(input_ptr + i7); + } else { + preload_l1_keep(input_ptr + i0); + preload_l1_keep(input_ptr + i3); + preload_l1_keep(input_ptr + i6); + } + + uint8* output_ptr = output_data + depth + (out_x * output_depth) + + (output_depth * output_width * out_y) + + out_batch_offset; + + for (; out_x < output_width; out_x++) { + ConvKernel3x3FilterDepth16<1, 1>::Run( + filter, input_ptr, input_depth, input_offset, input_row_width, + bias_ptr, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_ptr, + output_depth, output_width); + + input_ptr += input_ptr_x_increment; + output_ptr += output_depth; + + if (stride_width == 1) { + preload_l1_keep(input_ptr + i2); + preload_l1_keep(input_ptr + i5); + preload_l1_keep(input_ptr + i8); + } else if (stride_width == 2) { + preload_l1_keep(input_ptr + i1); + preload_l1_keep(input_ptr + i2); + preload_l1_keep(input_ptr + i4); + preload_l1_keep(input_ptr + i5); + preload_l1_keep(input_ptr + i7); + preload_l1_keep(input_ptr + i8); + } + } + } + filter_ptr += 16; + bias_ptr += 16; + } + } +} + +#endif // __aarch64__ + +} // namespace optimized_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index bbeac0fbfb7532e742a0a5aa421630fcbd677575..aff47d6e4877542cec9a4ba6eb0071421a9fcc07 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -324,6 +324,198 @@ void Gemm(const Eigen::MatrixBase& lhs, const Eigen::MatrixBase& rhs, } } +#ifdef GEMMLOWP_NEON +// In the common case of batch size 1, a fully-connected node degenerates +// to a matrix*vector product. LSTM cells contain a fully-connected node; +// when quantized, this becomes a special type of GEMV operation where +// the output is 16bit-quantized, thus needs its own special path. +inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims, + const uint8* weights_data, + const Dims<4>& weights_dims, + uint8 weights_zero_point, const int32* bias_data, + const Dims<4>& bias_dims, int32 accum_multiplier, + int accum_shift, int16* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0); + // This special fast path for quantized LSTM cells does not try to support + // odd sizes that we haven't encountered in any LSTM cell, that would + // require special code (that would go untested until any LSTM cell + // exercises it). We just guard our assumptions about size evenness with + // the following assertions. + TFLITE_DCHECK(!(output_size % 4)); + TFLITE_DCHECK(!(input_size % 8)); + const int32* bias_ptr = bias_data; + int16* output_ptr = output_data; + for (int out = 0; out < output_size; out += 4) { + int32x4_t acc_0 = vdupq_n_s32(0); + int32x4_t acc_1 = vdupq_n_s32(0); + int32x4_t acc_2 = vdupq_n_s32(0); + int32x4_t acc_3 = vdupq_n_s32(0); + const int16x8_t input_offset_vec = vdupq_n_s16(-128); + const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point); + int in = 0; + // Handle 16 levels of depth at a time. + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + const uint8* weights_ptr = weights_data + in + out * input_size; + uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size); + uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size); + uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size); + uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size); + int16x8_t input_val_0, input_val_1; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val_0 = vaddq_s16(input_val_0, input_offset_vec); + input_val_1 = vaddq_s16(input_val_1, input_offset_vec); + int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0, + weights_val_3_0; + int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1, + weights_val_3_1; + weights_val_0_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))), + weights_offset_vec); + weights_val_0_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))), + weights_offset_vec); + weights_val_1_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))), + weights_offset_vec); + weights_val_1_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))), + weights_offset_vec); + weights_val_2_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))), + weights_offset_vec); + weights_val_2_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))), + weights_offset_vec); + weights_val_3_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))), + weights_offset_vec); + weights_val_3_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))), + weights_offset_vec); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0), + vget_low_s16(input_val_0)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0), + vget_low_s16(input_val_0)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0), + vget_low_s16(input_val_0)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0), + vget_low_s16(input_val_0)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0), + vget_high_s16(input_val_0)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0), + vget_high_s16(input_val_0)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0), + vget_high_s16(input_val_0)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0), + vget_high_s16(input_val_0)); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1), + vget_low_s16(input_val_1)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1), + vget_low_s16(input_val_1)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1), + vget_low_s16(input_val_1)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1), + vget_low_s16(input_val_1)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1), + vget_high_s16(input_val_1)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1), + vget_high_s16(input_val_1)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1), + vget_high_s16(input_val_1)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1), + vget_high_s16(input_val_1)); + } + // Handle 8 levels of depth at a time. + for (; in < input_size; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + const uint8* weights_ptr = weights_data + in + out * input_size; + uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size); + uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size); + uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size); + uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size); + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3; + weights_val_0 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)), + weights_offset_vec); + weights_val_1 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)), + weights_offset_vec); + weights_val_2 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)), + weights_offset_vec); + weights_val_3 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)), + weights_offset_vec); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0), + vget_low_s16(input_val)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1), + vget_low_s16(input_val)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2), + vget_low_s16(input_val)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3), + vget_low_s16(input_val)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0), + vget_high_s16(input_val)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1), + vget_high_s16(input_val)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2), + vget_high_s16(input_val)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3), + vget_high_s16(input_val)); + } + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, + pairwise_reduced_acc_2, pairwise_reduced_acc_3; + pairwise_reduced_acc_0 = + vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0)); + pairwise_reduced_acc_1 = + vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1)); + pairwise_reduced_acc_2 = + vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2)); + pairwise_reduced_acc_3 = + vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3)); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + int left_shift = accum_shift > 0 ? accum_shift : 0; + int right_shift = accum_shift > 0 ? 0 : -accum_shift; + reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, accum_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, right_shift); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + vst1_s16(output_ptr, res16); + output_ptr += 4; + } +} +#endif + inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, const float* weights_data, const Dims<4>& weights_dims, const float* bias_data, @@ -610,6 +802,76 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, input_offset, output_pipeline); } +inline void FullyConnected( + const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, + const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, int16* output_data, const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16"); + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + (void)gemm_context; // only used in properly optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_EQ(output_offset, 0); + + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(filter_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + + // Implementation of the fully connected node suited to the inside of an LSTM + // cell. The operands are 8-bit integers, the accumulators are internally + // 32bit integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. +#ifdef GEMMLOWP_NEON + if (batches == 1 && !(output_depth % 4) && !(accum_depth % 8) && + input_offset == -128 && output_activation_min == -32768 && + output_activation_max == 32767) { + GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims, + filter_offset, bias_data_int32, bias_dims, + output_multiplier, -output_shift, output_data, output_dims); + return; + } +#endif + gemmlowp::MatrixMap weights_matrix( + filter_data, output_depth, accum_depth); + gemmlowp::MatrixMap input_matrix( + input_data, accum_depth, batches); + gemmlowp::MatrixMap output_matrix( + output_data, output_depth, batches); + typedef gemmlowp::VectorMap + ColVectorMap; + ColVectorMap bias_vector(bias_data_int32, output_depth); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage; + scale_stage.result_offset_after_shift = 0; + scale_stage.result_fixedpoint_multiplier = output_multiplier; + // Note that this shift is negated wrt ordinary FC. + scale_stage.result_exponent = -output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto output_pipeline = + std::make_tuple(bias_addition_stage, scale_stage, clamp_stage, + saturating_cast_int16_stage); + gemmlowp::GemmWithOutputPipeline( + gemm_context, weights_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + // legacy, for compatibility with old checked-in code template void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, @@ -768,6 +1030,7 @@ inline void DilatedConv(const float* input_data, const Dims<4>& input_dims, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("DilatedConv"); // This is a copy of the reference Conv implementation. We do not currently // have an optimized path for dilation. (void)im2col_data; // only used in optimized code. @@ -1530,6 +1793,8 @@ inline void Add(int left_shift, const uint8* input1_data, TFLITE_DCHECK_LT(input1_offset, 256); TFLITE_DCHECK_LT(input2_offset, 256); #ifdef USE_NEON + const auto output_activation_min_vector = vdup_n_u8(output_activation_min); + const auto output_activation_max_vector = vdup_n_u8(output_activation_max); for (; i <= size - 8; i += 8) { const auto input1_val_original = vld1_u8(input1_data + i); const auto input2_val_original = vld1_u8(input2_data + i); @@ -1575,7 +1840,10 @@ inline void Add(int left_shift, const uint8* input1_data, const auto s2_narrowed = vmovn_s32(s2); const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(output_offset)); - vst1_u8(output_data + i, vqmovun_s16(s)); + const auto clamped = + vmax_u8(output_activation_min_vector, + vmin_u8(output_activation_max_vector, vqmovun_s16(s))); + vst1_u8(output_data + i, clamped); } #endif // NEON @@ -1598,6 +1866,52 @@ inline void Add(int left_shift, const uint8* input1_data, } } +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/Int16"); + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + + TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); + TFLITE_DCHECK_GE(input1_shift, 0); + TFLITE_DCHECK_GE(input2_shift, 0); + const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; + const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; + const int input_shift = input1_shift == 0 ? input2_shift : input1_shift; + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); + F0 scaled_input = + F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift)); + F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); + const int16 raw_output = result.raw(); + const int16 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = clamped_output; + } +} + template void Add(const int32* input1_data, const Dims<4>& input1_dims, const int32* input2_data, const Dims<4>& input2_dims, @@ -1872,6 +2186,57 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims, } } +inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, + const int16* input2_data, const Dims<4>& input2_dims, + int16* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/Int16"); + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 unclamped_result = + F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); + output_data[i] = unclamped_result.raw(); + } +} + +inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, + const int16* input2_data, const Dims<4>& input2_dims, + int32 output_offset, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 unclamped_result = + F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); + int16 rescaled_result = + gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8); + int16 clamped_result = + std::min(output_activation_max - output_offset, rescaled_result); + clamped_result = + std::max(output_activation_min - output_offset, clamped_result); + output_data[i] = output_offset + clamped_result; + } +} + // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -2181,10 +2546,10 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( shifted_input2_val, input2_multiplier, input2_shift); - const int32 raw_sum = scaled_input1_val - scaled_input2_val; + const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + raw_sub, output_multiplier, output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -2197,34 +2562,6 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } -template -inline void BroadcastSub(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastSub(left_shift, input1_data, input1_dims, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_dims, - input2_offset, input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -2351,198 +2688,6 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, output_state_map.tanh(); } -#ifdef GEMMLOWP_NEON -// In the common case of batch size 1, a fully-connected node degenerates -// to a matrix*vector product. LSTM cells contain a fully-connected node; -// when quantized, this becomes a special type of GEMV operation where -// the output is 16bit-quantized, thus needs its own special path. -inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims, - const uint8* weights_data, - const Dims<4>& weights_dims, - uint8 weights_zero_point, const int32* bias_data, - const Dims<4>& bias_dims, int32 accum_multiplier, - int accum_shift, int16* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * - ArraySize(output_dims, 3), - 1); - const int input_size = input_dims.strides[3]; - const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0); - // This special fast path for quantized LSTM cells does not try to support - // odd sizes that we haven't encountered in any LSTM cell, that would - // require special code (that would go untested until any LSTM cell - // exercises it). We just guard our assumptions about size evenness with - // the following assertions. - TFLITE_DCHECK(!(output_size % 4)); - TFLITE_DCHECK(!(input_size % 8)); - const int32* bias_ptr = bias_data; - int16* output_ptr = output_data; - for (int out = 0; out < output_size; out += 4) { - int32x4_t acc_0 = vdupq_n_s32(0); - int32x4_t acc_1 = vdupq_n_s32(0); - int32x4_t acc_2 = vdupq_n_s32(0); - int32x4_t acc_3 = vdupq_n_s32(0); - const int16x8_t input_offset_vec = vdupq_n_s16(-128); - const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point); - int in = 0; - // Handle 16 levels of depth at a time. - for (; in <= input_size - 16; in += 16) { - const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); - const uint8* weights_ptr = weights_data + in + out * input_size; - uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size); - uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size); - uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size); - uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size); - int16x8_t input_val_0, input_val_1; - const uint8x8_t low = vget_low_u8(input_val_u8); - const uint8x8_t high = vget_high_u8(input_val_u8); - input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low)); - input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high)); - input_val_0 = vaddq_s16(input_val_0, input_offset_vec); - input_val_1 = vaddq_s16(input_val_1, input_offset_vec); - int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0, - weights_val_3_0; - int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1, - weights_val_3_1; - weights_val_0_0 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))), - weights_offset_vec); - weights_val_0_1 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))), - weights_offset_vec); - weights_val_1_0 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))), - weights_offset_vec); - weights_val_1_1 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))), - weights_offset_vec); - weights_val_2_0 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))), - weights_offset_vec); - weights_val_2_1 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))), - weights_offset_vec); - weights_val_3_0 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))), - weights_offset_vec); - weights_val_3_1 = vaddq_s16( - vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))), - weights_offset_vec); - acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0), - vget_low_s16(input_val_0)); - acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0), - vget_low_s16(input_val_0)); - acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0), - vget_low_s16(input_val_0)); - acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0), - vget_low_s16(input_val_0)); - acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0), - vget_high_s16(input_val_0)); - acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0), - vget_high_s16(input_val_0)); - acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0), - vget_high_s16(input_val_0)); - acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0), - vget_high_s16(input_val_0)); - acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1), - vget_low_s16(input_val_1)); - acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1), - vget_low_s16(input_val_1)); - acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1), - vget_low_s16(input_val_1)); - acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1), - vget_low_s16(input_val_1)); - acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1), - vget_high_s16(input_val_1)); - acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1), - vget_high_s16(input_val_1)); - acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1), - vget_high_s16(input_val_1)); - acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1), - vget_high_s16(input_val_1)); - } - // Handle 8 levels of depth at a time. - for (; in < input_size; in += 8) { - const uint8x8_t input_val_u8 = vld1_u8(input_data + in); - const uint8* weights_ptr = weights_data + in + out * input_size; - uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size); - uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size); - uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size); - uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size); - int16x8_t input_val; - input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); - input_val = vaddq_s16(input_val, input_offset_vec); - int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3; - weights_val_0 = - vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)), - weights_offset_vec); - weights_val_1 = - vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)), - weights_offset_vec); - weights_val_2 = - vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)), - weights_offset_vec); - weights_val_3 = - vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)), - weights_offset_vec); - acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0), - vget_low_s16(input_val)); - acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1), - vget_low_s16(input_val)); - acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2), - vget_low_s16(input_val)); - acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3), - vget_low_s16(input_val)); - acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0), - vget_high_s16(input_val)); - acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1), - vget_high_s16(input_val)); - acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2), - vget_high_s16(input_val)); - acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3), - vget_high_s16(input_val)); - } - // Horizontally reduce accumulators - int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, - pairwise_reduced_acc_2, pairwise_reduced_acc_3; - pairwise_reduced_acc_0 = - vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0)); - pairwise_reduced_acc_1 = - vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1)); - pairwise_reduced_acc_2 = - vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2)); - pairwise_reduced_acc_3 = - vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3)); - const int32x2_t reduced_lo = - vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); - const int32x2_t reduced_hi = - vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); - int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); - // Add bias values. - int32x4_t bias_vec = vld1q_s32(bias_ptr); - bias_ptr += 4; - reduced = vaddq_s32(reduced, bias_vec); - int left_shift = accum_shift > 0 ? accum_shift : 0; - int right_shift = accum_shift > 0 ? 0 : -accum_shift; - reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); - // Multiply by the fixed-point multiplier. - reduced = vqrdmulhq_n_s32(reduced, accum_multiplier); - // Rounding-shift-right. - using gemmlowp::RoundingDivideByPOT; - reduced = RoundingDivideByPOT(reduced, right_shift); - // Narrow values down to 16 bit signed. - const int16x4_t res16 = vqmovn_s32(reduced); - vst1_s16(output_ptr, res16); - output_ptr += 4; - } -} -#endif - // Quantized LSTM cell. Currently just a copy of the reference impl in // reference_ops.h. See the big function comment there, not replicating it // here. @@ -3809,6 +3954,28 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8], the input range expected here. + using F3 = gemmlowp::FixedPoint; + + const F3 input = F3::FromRaw(input_data[i]); + F0 output = gemmlowp::logistic(input); + output_data[i] = output.raw(); + } +} + inline void Tanh(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Tanh"); @@ -3967,6 +4134,45 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, output_data[c] = output_val; } } + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Tanh/Int16"); + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + + // Support for shifts is limited until we have a parameterized version of + // SaturatingRoundingMultiplyByPOT(). + TFLITE_DCHECK_GE(input_left_shift, 0); + TFLITE_DCHECK_LE(input_left_shift, 1); + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size); + + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8], the input range expected here. + using F3 = gemmlowp::FixedPoint; + + if (input_left_shift == 0) { + for (int i = 0; i < flat_size; i++) { + F3 input = F3::FromRaw(input_data[i]); + F0 output = gemmlowp::tanh(input); + output_data[i] = output.raw(); + } + } else { + for (int i = 0; i < flat_size; i++) { + F3 input = F3::FromRaw( + gemmlowp::SaturatingRoundingMultiplyByPOT<1>(input_data[i])); + F0 output = gemmlowp::tanh(input); + output_data[i] = output.raw(); + } + } +} + inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -4903,6 +5109,78 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output, } } +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TransposeConv"); + // THIS FUNCTION IS A COPY FROM reference_ops.h. + // To optimize, start by using the conv code with transposed weights for the + // case of stride_height = stride_width = 1. + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + + // Although transpose convolution simplifies to convolution with transposed + // weights for strides of 1, non-unitary striding complicates matters. To + // keep this reference implementation as clear as possible, we use a "scatter" + // access pattern, where we loop through all the input elements, computing + // their influence on the output, rather than looping through the output + // elements in the typical "gather" access pattern of a conv. We therefore + // must initialize the output array to zero. + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + 0.0f; + } + } + } + } + + // Loop through input elements one at a time. + for (int batch = 0; batch < batches; ++batch) { + for (int in_y = 0; in_y < input_height; ++in_y) { + for (int in_x = 0; in_x < input_width; ++in_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + // Loop through the output elements it will influence + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int out_channel = 0; out_channel < input_depth; + ++out_channel) { + // Compute output element location + const int out_x = out_x_origin + filter_x; + const int out_y = out_y_origin + filter_y; + // We cannot accumulate out of bounds + if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && + (out_y < output_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, out_channel, filter_x, + filter_y, in_channel)]; + output_data[Offset(output_dims, out_channel, out_x, out_y, + batch)] += input_value * filter_value; + } + } + } + } + } + } + } + } +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index ba06bc0975b6847b24592daa60efe99983d03707..9a04b76e56b2527b06f5b0ec1e75e991fd1cbdea 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -12,13 +12,156 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ -#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#include #include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { +// Given the min and max values of a float array, return +// reasonable quantization parameters to use for this array. +template +QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { + const T qmin = std::numeric_limits::min(); + const T qmax = std::numeric_limits::max(); + const double qmin_double = qmin; + const double qmax_double = qmax; + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_CHECK_LE(rmin, 0.); + TFLITE_CHECK_GE(rmax, 0.); + if (rmin == rmax) { + // Special case where the min,max range is a point. Should be {0}. + TFLITE_CHECK_EQ(rmin, 0.); + TFLITE_CHECK_EQ(rmax, 0.); + QuantizationParams quantization_params; + quantization_params.zero_point = 0; + quantization_params.scale = 0.; + return quantization_params; + } + + // General case. + // + // First determine the scale. + const double scale = (rmax - rmin) / (qmax_double - qmin_double); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const double zero_point_from_min = qmin_double - rmin / scale; + const double zero_point_from_max = qmax_double - rmax / scale; + const double zero_point_from_min_error = + std::abs(qmin_double) + std::abs(rmin / scale); + const double zero_point_from_max_error = + std::abs(qmax_double) + std::abs(rmax / scale); + + const double zero_point_double = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + T nudged_zero_point = 0; + if (zero_point_double < qmin_double) { + nudged_zero_point = qmin; + } else if (zero_point_double > qmax_double) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = static_cast(round(zero_point_double)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_CHECK_GE(nudged_zero_point, qmin); + TFLITE_CHECK_LE(nudged_zero_point, qmax); + + // Finally, store the result nudged quantization params. + QuantizationParams quantization_params; + quantization_params.zero_point = nudged_zero_point; + quantization_params.scale = scale; + return quantization_params; +} + +// Converts a floating-point number to an integer. For all inputs x where +// static_cast(x) is legal according to the C++ standard, the result +// is identical to that cast (i.e. the result is x with its fractional part +// truncated whenever that is representable as IntOut). +// +// static_cast would cause undefined behavior for the following cases, which +// have well-defined behavior for this function: +// +// 1. If x is NaN, the result is zero. +// +// 2. If the truncated form of x is above the representable range of IntOut, +// the result is std::numeric_limits::max(). +// +// 3. If the truncated form of x is below the representable range of IntOut, +// the result is std::numeric_limits::min(). +// +// Note that cases #2 and #3 cover infinities as well as finite numbers. +// +// The range of FloatIn must include the range of IntOut, otherwise +// the results are undefined. +// TODO(sfeuz): Replace by absl::SafeCast once available. +template +IntOut SafeCast(FloatIn x) { + static_assert(!std::numeric_limits::is_integer, + "FloatIn is integer"); + static_assert(std::numeric_limits::is_integer, + "IntOut is not integer"); + static_assert(std::numeric_limits::radix == 2, "IntOut is base 2"); + + // Special case NaN, for which the logic below doesn't work. + if (std::isnan(x)) { + return 0; + } + + // Negative values all clip to zero for unsigned results. + if (!std::numeric_limits::is_signed && x < 0) { + return 0; + } + + // Handle infinities. + if (std::isinf(x)) { + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); + } + + // Set exp such that x == f * 2^exp for some f with |f| in [0.5, 1.0), + // unless x is zero in which case exp == 0. Note that this implies that the + // magnitude of x is strictly less than 2^exp. + int exp = 0; + std::frexp(x, &exp); + + // Let N be the number of non-sign bits in the representation of IntOut. If + // the magnitude of x is strictly less than 2^N, the truncated version of x + // is representable as IntOut. The only representable integer for which this + // is not the case is kMin for signed types (i.e. -2^N), but that is covered + // by the fall-through below. + if (exp <= std::numeric_limits::digits) { + return x; + } + + // Handle numbers with magnitude >= 2^N. + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); +} + // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of NEGATIVE its exponent --- // this is intended as a RIGHT-shift. @@ -57,10 +200,10 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, // Calculate the largest input that will result in a within-bounds intermediate // result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, // it must not overflow before we reduce the value by multiplication by the -// input multiplier. The negative radius is used as the minimum difference -// in Softmax. +// input multiplier. The negative radius is used as the minimum difference in +// Softmax. int CalculateInputRadius(int input_integer_bits, int input_left_shift); } // namespace tflite -#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 19b1b408ec74b0939065b0ad10b91ecfc2cd4765..3e9a3c29ee26e96612bb05eb9cd1e1badad10c7a 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -22,6 +22,177 @@ namespace { using ::testing::Pair; +template +void RunSafeCastTests() { + const IntOut imax = std::numeric_limits::max(); + EXPECT_GT(imax, 0); + const IntOut imin = std::numeric_limits::min(); + const bool s = std::numeric_limits::is_signed; + if (s) { + EXPECT_LT(imin, 0); + } else { + EXPECT_EQ(0, imin); + } + + // Some basic tests. + EXPECT_EQ(SafeCast(static_cast(0.0)), 0); + EXPECT_EQ(SafeCast(static_cast(-0.0)), 0); + EXPECT_EQ(SafeCast(static_cast(0.99)), 0); + EXPECT_EQ(SafeCast(static_cast(1.0)), 1); + EXPECT_EQ(SafeCast(static_cast(1.01)), 1); + EXPECT_EQ(SafeCast(static_cast(1.99)), 1); + EXPECT_EQ(SafeCast(static_cast(2.0)), 2); + EXPECT_EQ(SafeCast(static_cast(2.01)), 2); + EXPECT_EQ(SafeCast(static_cast(-0.99)), 0); + EXPECT_EQ(SafeCast(static_cast(-1.0)), s ? -1 : 0); + EXPECT_EQ(SafeCast(static_cast(-1.01)), s ? -1 : 0); + EXPECT_EQ(SafeCast(static_cast(-1.99)), s ? -1 : 0); + EXPECT_EQ(SafeCast(static_cast(-2.0)), s ? -2 : 0); + EXPECT_EQ(SafeCast(static_cast(-2.01)), s ? -2 : 0); + EXPECT_EQ(SafeCast(static_cast(117.9)), 117); + EXPECT_EQ(SafeCast(static_cast(118.0)), 118); + EXPECT_EQ(SafeCast(static_cast(118.1)), 118); + EXPECT_EQ(SafeCast(static_cast(-117.9)), s ? -117 : 0); + EXPECT_EQ(SafeCast(static_cast(-118.0)), s ? -118 : 0); + EXPECT_EQ(SafeCast(static_cast(-118.1)), s ? -118 : 0); + + // Some edge cases. + EXPECT_EQ(SafeCast(std::numeric_limits::max()), imax); + EXPECT_EQ(SafeCast(std::numeric_limits::lowest()), imin); + EXPECT_EQ(SafeCast(std::numeric_limits::infinity()), imax); + EXPECT_EQ(SafeCast(-std::numeric_limits::infinity()), imin); + EXPECT_EQ(SafeCast(std::numeric_limits::quiet_NaN()), 0); + + // Some larger numbers. + if (sizeof(IntOut) >= 4 && sizeof(FloatIn) > 4) { + EXPECT_EQ(SafeCast(static_cast(0x76543210)), 0x76543210); + } + + if (sizeof(FloatIn) > sizeof(IntOut)) { + // Check values near imax. + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) + 0.1)), + imax); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) + 0.99)), + imax); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) + 1.0)), + imax); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) + 1.99)), + imax); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) + 2.0)), + imax); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 0.1)), + imax - 1); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 0.99)), + imax - 1); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 1.0)), + imax - 1); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 1.01)), + imax - 2); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 1.99)), + imax - 2); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 2.0)), + imax - 2); + EXPECT_EQ(SafeCast( + static_cast(static_cast(imax) - 2.01)), + imax - 3); + } + + // Check values considerably larger in magnitude than imin and imax + EXPECT_EQ( + SafeCast(static_cast(static_cast(imax) * 2)), + imax); + EXPECT_EQ( + SafeCast(static_cast(static_cast(imax) * 20)), + imax); + EXPECT_EQ( + SafeCast(static_cast(static_cast(imax) * 100)), + imax); + EXPECT_EQ( + SafeCast(static_cast(static_cast(imin) * 2)), + imin); + EXPECT_EQ( + SafeCast(static_cast(static_cast(imin) * 20)), + imin); + EXPECT_EQ( + SafeCast(static_cast(static_cast(imin) * 100)), + imin); +} + +TEST(QuantizationUtilTest, SafeCast) { + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); +} + +// Example taken from http://www.tensorflow.org/performance/quantization +// +// Quantized | Float +// --------- | ----- +// 0 | -10.0 +// 255 | 30.0 +// 128 | 10.0 +TEST(QuantizationUtilTest, ChooseQuantizationParams) { + QuantizationParams qp = ChooseQuantizationParams(-10.0, 30.0); + EXPECT_NEAR(qp.scale, 0.156863, 1e-5); + EXPECT_EQ(qp.zero_point, 64); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) { + QuantizationParams qp = ChooseQuantizationParams(0.0, 30.0); + EXPECT_NEAR(qp.scale, 0.117647, 1e-5); + EXPECT_EQ(qp.zero_point, 0); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) { + // Assumption is that zero is within the range. + EXPECT_DEATH(ChooseQuantizationParams(10.0, 30.0), ""); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) { + // Assumption is that zero is within the range. + EXPECT_DEATH(ChooseQuantizationParams(30.0, 30.0), ""); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) { + QuantizationParams qp = ChooseQuantizationParams(0.0, 0.0); + EXPECT_NEAR(qp.scale, 0.0, 1e-5); + EXPECT_EQ(qp.zero_point, 0); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) { + QuantizationParams qp = ChooseQuantizationParams(-10.0, 0.0); + EXPECT_NEAR(qp.scale, 0.039216, 1e-5); + EXPECT_EQ(qp.zero_point, 255); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { + EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); +} + TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { auto quantize = [](double d) { int32_t q; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 192e169d0ac163255038f159c5abe70bc788673d..4509db06fd757c73a2eb4edc81e205e25738cb06 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -404,6 +404,7 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, const int in_d = out_d + ((out_h % block_size) * block_size + out_w % block_size) * output_depth; + const int in_w = out_w / block_size; const int in_h = out_h / block_size; const int in_b = out_b; @@ -551,6 +552,55 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } } +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, int16* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_EQ(output_offset, 0); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(filter_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32 accum = bias_data[out_c]; + // Accumulation loop. + for (int d = 0; d < accum_depth; ++d) { + int16 input_val = input_data[b * accum_depth + d] + input_offset; + int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset; + accum += filter_val * input_val; + } + // Down-scale the final int32 accumulator to the scale used by our + // (16-bit, typically 3 integer bits) fixed-point format. The quantized + // multiplier and shift here have been pre-computed offline + // (e.g. by toco). + accum = MultiplyByQuantizedMultiplier(accum, output_multiplier, + -output_shift); + // Saturate, cast to int16, and store to output array. + accum = std::max(accum, output_activation_min - output_offset); + accum = std::min(accum, output_activation_max - output_offset); + accum += output_offset; + output_data[out_c + output_depth * b] = accum; + } + } +} + // legacy, for compatibility with old checked-in code template void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, @@ -903,6 +953,49 @@ inline void Add(int left_shift, const uint8* input1_data, } } +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + + TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); + TFLITE_DCHECK_GE(input1_shift, 0); + TFLITE_DCHECK_GE(input2_shift, 0); + const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; + const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; + const int input_shift = input1_shift == 0 ? input2_shift : input1_shift; + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); + F0 scaled_input = + F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift)); + F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); + const int16 raw_output = result.raw(); + const int16 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = clamped_output; + } +} + // TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -1184,6 +1277,53 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, } } +inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, + const int16* input2_data, const Dims<4>& input2_dims, + int16* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/Int16"); + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 unclamped_result = + F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); + output_data[i] = unclamped_result.raw(); + } +} + +inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, + const int16* input2_data, const Dims<4>& input2_dims, + int32 output_offset, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + using F0 = gemmlowp::FixedPoint; + + F0 unclamped_result = + F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]); + int16 rescaled_result = + gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8); + int16 clamped_result = + std::min(output_activation_max - output_offset, rescaled_result); + clamped_result = + std::max(output_activation_min - output_offset, clamped_result); + output_data[i] = output_offset + clamped_result; + } +} + // legacy, for compatibility with old checked-in code template inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, @@ -1199,6 +1339,47 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } +// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastDiv"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] / + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + inline void Div(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, @@ -1425,10 +1606,10 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( shifted_input2_val, input2_multiplier, input2_shift); - const int32 raw_sum = scaled_input1_val - scaled_input2_val; + const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + raw_sub, output_multiplier, output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1441,34 +1622,6 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } -template -inline void BroadcastSub(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastSub(left_shift, input1_data, input1_dims, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_dims, - input2_offset, input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -2533,11 +2686,13 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + // Convert from Q0.31 to Q23.8. using gemmlowp::RoundingDivideByPOT; int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); if (output_val_s32 == 256) { output_val_s32 = 255; } + // Reinterpret as U0.8. TFLITE_DCHECK_GE(output_val_s32, 0); TFLITE_DCHECK_LE(output_val_s32, 255); output_val = static_cast(output_val_s32); @@ -2549,6 +2704,25 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size); + + for (int i = 0; i < flat_size; i++) { + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8], the input range expected here. + using F3 = gemmlowp::FixedPoint; + + const F3 input = F3::FromRaw(input_data[i]); + F0 output = gemmlowp::logistic(input); + output_data[i] = output.raw(); + } +} + inline void Tanh(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); @@ -2598,13 +2772,14 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); - + // Convert from Q0.31 to Q24.7. using gemmlowp::RoundingDivideByPOT; int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); output_val_s32 += output_zero_point; if (output_val_s32 == 256) { output_val_s32 = 255; } + // Reinterpret as Q0.7, encoded in uint8. TFLITE_DCHECK_GE(output_val_s32, 0); TFLITE_DCHECK_LE(output_val_s32, 255); output_val = static_cast(output_val_s32); @@ -2616,6 +2791,40 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + // Support for shifts is limited until we have a parameterized version of + // SaturatingRoundingMultiplyByPOT(). + TFLITE_DCHECK_GE(input_left_shift, 0); + TFLITE_DCHECK_LE(input_left_shift, 1); + + const int flat_size = RequiredBufferSizeForDims(output_dims); + TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size); + + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8], the input range expected here. + using F3 = gemmlowp::FixedPoint; + + if (input_left_shift == 0) { + for (int i = 0; i < flat_size; i++) { + F3 input = F3::FromRaw(input_data[i]); + F0 output = gemmlowp::tanh(input); + output_data[i] = output.raw(); + } + } else { + for (int i = 0; i < flat_size; i++) { + F3 input = F3::FromRaw( + gemmlowp::SaturatingRoundingMultiplyByPOT<1>(input_data[i])); + F0 output = gemmlowp::tanh(input); + output_data[i] = output.raw(); + } + } +} + inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -3115,9 +3324,11 @@ inline void Mean(T* input_data, const int* input_dims, const int input_num_dims, for (int idx = 0; idx < num_resolved_axis; ++idx) { num_elements_in_axis *= static_cast(input_dims[resolved_axis[idx]]); } - for (size_t idx = 0; idx < num_outputs; ++idx) { - output_data[idx] = static_cast(static_cast(output_data[idx]) / - num_elements_in_axis); + if (num_elements_in_axis > 0) { + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = static_cast(static_cast(output_data[idx]) / + num_elements_in_axis); + } } } @@ -3235,6 +3446,30 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } } +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + auto out_idx = Offset(output_dims, c, x, y, b); + auto in1_idx = SubscriptToIndex(desc1, c, x, y, b); + auto in2_idx = SubscriptToIndex(desc2, c, x, y, b); + auto in1_val = input1_data[in1_idx]; + auto in2_val = input2_data[in2_idx]; + output_data[out_idx] = in1_val > in2_val ? in1_val : in2_val; + } + } + } + } +} + template void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, T2* output_data, const Dims<4>& output_dims) { @@ -3298,6 +3533,67 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output, } } +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + + // Although transpose convolution simplifies to convolution with transposed + // weights for strides of 1, non-unitary striding complicates matters. To + // keep this reference implementation as clear as possible, we use a "scatter" + // access pattern, where we loop through all the input elements, computing + // their influence on the output, rather than looping through the output + // elements in the typical "gather" access pattern of a conv. We therefore + // must initialize the output array to zero. + for (int i = 0; i < RequiredBufferSizeForDims(output_dims); i++) { + output_data[i] = 0.0f; + } + + // Loop through input elements one at a time. + for (int batch = 0; batch < batches; ++batch) { + for (int in_y = 0; in_y < input_height; ++in_y) { + for (int in_x = 0; in_x < input_width; ++in_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + // Loop through the output elements it will influence + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int out_channel = 0; out_channel < output_depth; + ++out_channel) { + // Compute output element location + const int out_x = out_x_origin + filter_x; + const int out_y = out_y_origin + filter_y; + // We cannot accumulate out of bounds + if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && + (out_y < output_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, out_channel, filter_x, + filter_y, in_channel)]; + output_data[Offset(output_dims, out_channel, out_x, out_y, + batch)] += input_value * filter_value; + } + } + } + } + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc new file mode 100644 index 0000000000000000000000000000000000000000..4eddf7bf0a2cbca695dae20ba8ba56a9cd72e4ba --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc @@ -0,0 +1,244 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h" + +#include +#include + +#include "third_party/fft2d/fft.h" + +namespace tflite { +namespace internal { + +using std::complex; + +namespace { +// Returns the default Hann window function for the spectrogram. +void GetPeriodicHann(int window_length, std::vector* window) { + // Some platforms don't have M_PI, so define a local constant here. + const double pi = std::atan(1) * 4; + window->resize(window_length); + for (int i = 0; i < window_length; ++i) { + (*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length); + } +} +} // namespace + +bool Spectrogram::Initialize(int window_length, int step_length) { + std::vector window; + GetPeriodicHann(window_length, &window); + return Initialize(window, step_length); +} + +inline int Log2Floor(uint n) { + if (n == 0) return -1; + int log = 0; + uint value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + return log; +} + +inline int Log2Ceiling(uint n) { + int floor = Log2Floor(n); + if (n == (n & ~(n - 1))) // zero or a power of two + return floor; + else + return floor + 1; +} + +inline uint NextPowerOfTwo(uint value) { + int exponent = Log2Ceiling(value); + // DCHECK_LT(exponent, std::numeric_limits::digits); + return 1 << exponent; +} + +bool Spectrogram::Initialize(const std::vector& window, + int step_length) { + window_length_ = window.size(); + window_ = window; // Copy window. + if (window_length_ < 2) { + // LOG(ERROR) << "Window length too short."; + initialized_ = false; + return false; + } + + step_length_ = step_length; + if (step_length_ < 1) { + // LOG(ERROR) << "Step length must be positive."; + initialized_ = false; + return false; + } + + fft_length_ = NextPowerOfTwo(window_length_); + // CHECK(fft_length_ >= window_length_); + output_frequency_channels_ = 1 + fft_length_ / 2; + + // Allocate 2 more than what rdft needs, so we can rationalize the layout. + fft_input_output_.assign(fft_length_ + 2, 0.0); + + int half_fft_length = fft_length_ / 2; + fft_double_working_area_.assign(half_fft_length, 0.0); + fft_integer_working_area_.assign(2 + static_cast(sqrt(half_fft_length)), + 0); + // Set flag element to ensure that the working areas are initialized + // on the first call to cdft. It's redundant given the assign above, + // but keep it as a reminder. + fft_integer_working_area_[0] = 0; + input_queue_.clear(); + samples_to_next_step_ = window_length_; + initialized_ = true; + return true; +} + +template +bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>* output) { + if (!initialized_) { + // LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call + // " + // << "to Initialize()."; + return false; + } + // CHECK(output); + output->clear(); + int input_start = 0; + while (GetNextWindowOfSamples(input, &input_start)) { + // DCHECK_EQ(input_queue_.size(), window_length_); + ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_. + // Add a new slice vector onto the output, to save new result to. + output->resize(output->size() + 1); + // Get a reference to the newly added slice to fill in. + auto& spectrogram_slice = output->back(); + spectrogram_slice.resize(output_frequency_channels_); + for (int i = 0; i < output_frequency_channels_; ++i) { + // This will convert double to float if it needs to. + spectrogram_slice[i] = complex( + fft_input_output_[2 * i], fft_input_output_[2 * i + 1]); + } + } + return true; +} +// Instantiate it four ways: +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, std::vector>>*); +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>*); +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>*); +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>*); + +template +bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, + std::vector>* output) { + if (!initialized_) { + // LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before " + // << "successful call to Initialize()."; + return false; + } + // CHECK(output); + output->clear(); + int input_start = 0; + while (GetNextWindowOfSamples(input, &input_start)) { + // DCHECK_EQ(input_queue_.size(), window_length_); + ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_. + // Add a new slice vector onto the output, to save new result to. + output->resize(output->size() + 1); + // Get a reference to the newly added slice to fill in. + auto& spectrogram_slice = output->back(); + spectrogram_slice.resize(output_frequency_channels_); + for (int i = 0; i < output_frequency_channels_; ++i) { + // Similar to the Complex case, except storing the norm. + // But the norm function is known to be a performance killer, + // so do it this way with explicit real and imagninary temps. + const double re = fft_input_output_[2 * i]; + const double im = fft_input_output_[2 * i + 1]; + // Which finally converts double to float if it needs to. + spectrogram_slice[i] = re * re + im * im; + } + } + return true; +} +// Instantiate it four ways: +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); + +// Return true if a full window of samples is prepared; manage the queue. +template +bool Spectrogram::GetNextWindowOfSamples(const std::vector& input, + int* input_start) { + auto input_it = input.begin() + *input_start; + int input_remaining = input.end() - input_it; + if (samples_to_next_step_ > input_remaining) { + // Copy in as many samples are left and return false, no full window. + input_queue_.insert(input_queue_.end(), input_it, input.end()); + *input_start += input_remaining; // Increases it to input.size(). + samples_to_next_step_ -= input_remaining; + return false; // Not enough for a full window. + } else { + // Copy just enough into queue to make a new window, then trim the + // front off the queue to make it window-sized. + input_queue_.insert(input_queue_.end(), input_it, + input_it + samples_to_next_step_); + *input_start += samples_to_next_step_; + input_queue_.erase( + input_queue_.begin(), + input_queue_.begin() + input_queue_.size() - window_length_); + // DCHECK_EQ(window_length_, input_queue_.size()); + samples_to_next_step_ = step_length_; // Be ready for next time. + return true; // Yes, input_queue_ now contains exactly a window-full. + } +} + +void Spectrogram::ProcessCoreFFT() { + for (int j = 0; j < window_length_; ++j) { + fft_input_output_[j] = input_queue_[j] * window_[j]; + } + // Zero-pad the rest of the input buffer. + for (int j = window_length_; j < fft_length_; ++j) { + fft_input_output_[j] = 0.0; + } + const int kForwardFFT = 1; // 1 means forward; -1 reverse. + // This real FFT is a fair amount faster than using cdft here. + rdft(fft_length_, kForwardFFT, &fft_input_output_[0], + &fft_integer_working_area_[0], &fft_double_working_area_[0]); + // Make rdft result look like cdft result; + // unpack the last real value from the first position's imag slot. + fft_input_output_[fft_length_] = fft_input_output_[1]; + fft_input_output_[fft_length_ + 1] = 0; + fft_input_output_[1] = 0; +} + +} // namespace internal +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.h b/tensorflow/contrib/lite/kernels/internal/spectrogram.h new file mode 100644 index 0000000000000000000000000000000000000000..b77a68f7dfe6edb07ec4e5db540c673b2d6f6d6e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.h @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Class for generating spectrogram slices from a waveform. +// Initialize() should be called before calls to other functions. Once +// Initialize() has been called and returned true, The Compute*() functions can +// be called repeatedly with sequential input data (ie. the first element of the +// next input vector directly follows the last element of the previous input +// vector). Whenever enough audio samples are buffered to produce a +// new frame, it will be placed in output. Output is cleared on each +// call to Compute*(). This class is thread-unsafe, and should only be +// called from one thread at a time. +// With the default parameters, the output of this class should be very +// close to the results of the following MATLAB code: +// overlap_samples = window_length_samples - step_samples; +// window = hann(window_length_samples, 'periodic'); +// S = abs(spectrogram(audio, window, overlap_samples)).^2; + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ + +#include +#include +#include + +#include "third_party/fft2d/fft.h" + +namespace tflite { +namespace internal { + +class Spectrogram { + public: + Spectrogram() : initialized_(false) {} + ~Spectrogram() {} + + // Initializes the class with a given window length and step length + // (both in samples). Internally a Hann window is used as the window + // function. Returns true on success, after which calls to Process() + // are possible. window_length must be greater than 1 and step + // length must be greater than 0. + bool Initialize(int window_length, int step_length); + + // Initialize with an explicit window instead of a length. + bool Initialize(const std::vector& window, int step_length); + + // Processes an arbitrary amount of audio data (contained in input) + // to yield complex spectrogram frames. After a successful call to + // Initialize(), Process() may be called repeatedly with new input data + // each time. The audio input is buffered internally, and the output + // vector is populated with as many temporally-ordered spectral slices + // as it is possible to generate from the input. The output is cleared + // on each call before the new frames (if any) are added. + // + // The template parameters can be float or double. + template + bool ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>* output); + + // This function works as the one above, but returns the power + // (the L2 norm, or the squared magnitude) of each complex value. + template + bool ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, + std::vector>* output); + + // Return reference to the window function used internally. + const std::vector& GetWindow() const { return window_; } + + // Return the number of frequency channels in the spectrogram. + int output_frequency_channels() const { return output_frequency_channels_; } + + private: + template + bool GetNextWindowOfSamples(const std::vector& input, + int* input_start); + void ProcessCoreFFT(); + + int fft_length_; + int output_frequency_channels_; + int window_length_; + int step_length_; + bool initialized_; + int samples_to_next_step_; + + std::vector window_; + std::vector fft_input_output_; + std::deque input_queue_; + + // Working data areas for the FFT routines. + std::vector fft_integer_working_area_; + std::vector fft_double_working_area_; +}; + +} // namespace internal +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index afe131b06ec41201395e80aa5415fd7db990f8d4..293538fcbb6406d6065d8efd25adb3b163638c92 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -21,6 +21,22 @@ namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; +// Quantization parameters, determining the mapping of quantized values +// to real values (i.e. determining how quantized values are mathematically +// interpreted). +// +// The correspondence is as follows: +// +// real_value = scale * (quantized_value - zero_point); +// +// In other words, zero_point designates which quantized value corresponds to +// the real 0 value, and scale designates the difference between the real values +// corresponding to consecutive quantized values differing by 1. +struct QuantizationParams { + int32 zero_point = 0; + double scale = 0.0; +}; + template struct Dims { int sizes[N]; diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 28f53b9fbbc5620f2fab5c73e40bed8af4af5f1e..2f407b5da31594335dba31b3057737e67a974057 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -53,13 +53,13 @@ inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, } // Determines whether tensor is constant. -inline bool IsConstantTensor(TfLiteTensor* tensor) { +inline bool IsConstantTensor(const TfLiteTensor* tensor) { return tensor->allocation_type == kTfLiteMmapRo; } // Determines whether tensor is dynamic. Note that a tensor can be non-const and -// not dynamic. This function specificially checks for a dynamic tensor. -inline bool IsDynamicTensor(TfLiteTensor* tensor) { +// not dynamic. This function specifically checks for a dynamic tensor. +inline bool IsDynamicTensor(const TfLiteTensor* tensor) { return tensor->allocation_type == kTfLiteDynamic; } diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc index 5f73b56ed9790b216adc788490faebaabd2bc756..0ee35775d50b8750455572f789d7b92481655a95 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// LSH Projection projects an input to a bit vector via locality senstive +// LSH Projection projects an input to a bit vector via locality sensitive // hashing. // // Options: diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 6c06264d845c24e71647b6fd2374734be32383ef..8cf1165135bdb0d4669bb97fd2d98e3dc044b4d9 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -212,9 +213,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // present. // 2) If projection weight is present, then projection bias is optional. // TODO(ghodrat): make sure this is correct. - const bool projecton_tensors_consistent = + const bool projection_tensors_consistent = ((projection_weights != nullptr) || (projection_bias == nullptr)); - TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + TF_LITE_ENSURE(context, projection_tensors_consistent == true); return kTfLiteOk; } @@ -356,7 +357,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. + // check the existence of only one to get the condition. const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); @@ -377,127 +378,54 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; } - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, - n_batch, output_gate_scratch); - - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, - cell_state->data.f, n_batch * n_cell, - cell_state->data.f); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, - cell_state->data.f); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, - params->cell_clip, cell_state->data.f); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights != nullptr); - const bool use_projection_bias = (projection_bias != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, - n_batch, output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights->data.f, n_output, n_cell, output_gate_scratch, - n_batch, output->data.f, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output->data.f, n_batch * n_output, - params->proj_clip, output->data.f); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output->data.f); - } - tensor_utils::CopyVector(output->data.f, n_batch * n_output, - output_state->data.f); + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; + const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; + const float* input_to_output_weights_ptr = input_to_output_weights->data.f; + const float* recurrent_to_forget_weights_ptr = + recurrent_to_forget_weights->data.f; + const float* recurrent_to_cell_weights_ptr = + recurrent_to_cell_weights->data.f; + const float* recurrent_to_output_weights_ptr = + recurrent_to_output_weights->data.f; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + cell_to_input_weights_ptr, cell_to_forget_weights_ptr, + cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_batch); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/maximum.cc b/tensorflow/contrib/lite/kernels/maximum.cc new file mode 100644 index 0000000000000000000000000000000000000000..9fdf2b47eaf421bda11e7474ad819692106a90ac --- /dev/null +++ b/tensorflow/contrib/lite/kernels/maximum.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace maximum { + +// This file has a reference implemenation of TFMaximum. +enum KernelType { + kReference, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +struct MaximumContext { + MaximumContext(TfLiteContext* context, TfLiteNode* node) { + input1 = GetInput(context, node, kInputTensor1); + input2 = GetInput(context, node, kInputTensor2); + output = GetOutput(context, node, kOutputTensor); + } + TfLiteTensor* input1; + TfLiteTensor* input2; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MaximumContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type); + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input2->dims); + op_context.output->type = op_context.input2->type; + return context->ResizeTensor(context, op_context.output, output_dims); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + MaximumContext op_context(context, node); + +#define TF_LITE_MAXIMUM(kernel_type, data_type) \ + kernel_type::TensorFlowMaximum( \ + GetTensorData(op_context.input1), \ + GetTensorDims(op_context.input1), \ + GetTensorData(op_context.input2), \ + GetTensorDims(op_context.input2), \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + + if (kernel_type == kReference) { + switch (op_context.output->type) { + case kTfLiteFloat32: + TF_LITE_MAXIMUM(reference_ops, float); + break; + default: + context->ReportError(context, + "Type %d is currently not supported by Maximum.", + op_context.output->type); + return kTfLiteError; + } + } else { + context->ReportError(context, + "Type %d is currently not supported by Maximum.", + op_context.output->type); + return kTfLiteError; + } +#undef TF_LITE_MAXIMUM + return kTfLiteOk; +} + +} // namespace maximum + +TfLiteRegistration* Register_MAXIMUM_REF() { + static TfLiteRegistration r = {nullptr, nullptr, maximum::Prepare, + maximum::Eval}; + return &r; +} + +TfLiteRegistration* Register_MAXIMUM() { return Register_MAXIMUM_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/maximum_test.cc b/tensorflow/contrib/lite/kernels/maximum_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3fd7d4e6f40e53db51edf2e7594662629302add --- /dev/null +++ b/tensorflow/contrib/lite/kernels/maximum_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class MaximumOpModel : public SingleOpModel { + public: + MaximumOpModel(const TensorData& input1, const TensorData& input2, + const TensorType& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MAXIMUM, BuiltinOptions_MaximumOptions, + CreateMaximumOptions(builder_).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + template + void SetInput1(std::initializer_list data) { + PopulateTensor(input1_, data); + } + + template + void SetInput2(std::initializer_list data) { + PopulateTensor(input2_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(MaximumOpTest, FloatTest) { + std::initializer_list data1 = {1.0, 0.0, -1.0, 11.0, -2.0, -1.44}; + std::initializer_list data2 = {-1.0, 0.0, 1.0, 12.0, -3.0, -1.43}; + MaximumOpModel m({TensorType_FLOAT32, {3, 1, 2}}, + {TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32); + m.SetInput1(data1); + m.SetInput2(data2); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({1.0, 0.0, 1.0, 12.0, -2.0, -1.43}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc index c4c53c2ded351849e7c458fc754c36395a25ebd0..2d6d4bc2da4b75289ee27c3f2a12787216716d44 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/mean_test.cc @@ -74,7 +74,7 @@ class MeanOpDynamicModel : public BaseMeanOpModel { } }; -TEST(ConstMeanOpTest, NotKeepDims) { +TEST(ConstFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -86,7 +86,7 @@ TEST(ConstMeanOpTest, NotKeepDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } -TEST(ConstMeanOpTest, KeepDims) { +TEST(ConstFloatMeanOpTest, KeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -99,7 +99,7 @@ TEST(ConstMeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } -TEST(DynamicMeanOpTest, NotKeepDims) { +TEST(DynamicFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -114,7 +114,7 @@ TEST(DynamicMeanOpTest, NotKeepDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } -TEST(DynamicMeanOpTest, KeepDims) { +TEST(DynamicFloatMeanOpTest, KeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -130,6 +130,70 @@ TEST(DynamicMeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } +TEST(DynamicFloatMeanOpTest, Scale) { + std::initializer_list data = {9.527}; + MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8MeanOpTest, NotKeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({12, 13})); +} + +TEST(ConstUint8MeanOpTest, KeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 12, 14})); +} + +TEST(DynamicUint8MeanOpTest, NotKeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}}, + {TensorType_INT32, {4}}, false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({12, 13})); +} + +TEST(DynamicUint8MeanOpTest, KeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}}, + {TensorType_INT32, {2}}, true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 12, 14})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc new file mode 100644 index 0000000000000000000000000000000000000000..018db0dc54c5d281bf3fb3ff8a1f111b427fe76b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mfcc.cc @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/mfcc.h" +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" +#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace mfcc { + +enum KernelType { + kReference, +}; + +typedef struct { + float upper_frequency_limit; + float lower_frequency_limit; + int filterbank_channel_count; + int dct_coefficient_count; +} TfLiteMfccParams; + +constexpr int kInputTensorWav = 0; +constexpr int kInputTensorRate = 1; +constexpr int kOutputTensor = 0; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new TfLiteMfccParams; + + const uint8_t* buffer_t = reinterpret_cast(buffer); + + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + data->upper_frequency_limit = m["upper_frequency_limit"].AsInt64(); + data->lower_frequency_limit = m["lower_frequency_limit"].AsInt64(); + data->filterbank_channel_count = m["filterbank_channel_count"].AsInt64(); + data->dct_coefficient_count = m["dct_coefficient_count"].AsInt64(); + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav); + TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(inputWav), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(inputRate), 1); + + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, inputWav->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); + output_size->data[0] = inputWav->dims->data[0]; + output_size->data[1] = inputWav->dims->data[1]; + output_size->data[2] = params->dct_coefficient_count; + + return context->ResizeTensor(context, output, output_size); +} + +// Input is a single squared-magnitude spectrogram frame. The input spectrum +// is converted to linear magnitude and weighted into bands using a +// triangular mel filterbank, and a discrete cosine transform (DCT) of the +// values is taken. Output is populated with the lowest dct_coefficient_count +// of these values. +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->user_data); + + TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav); + TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const int32 sample_rate = *GetTensorData(inputRate); + + const int spectrogram_channels = inputWav->dims->data[2]; + const int spectrogram_samples = inputWav->dims->data[1]; + const int audio_channels = inputWav->dims->data[0]; + + internal::Mfcc mfcc; + mfcc.set_upper_frequency_limit(params->upper_frequency_limit); + mfcc.set_lower_frequency_limit(params->lower_frequency_limit); + mfcc.set_filterbank_channel_count(params->filterbank_channel_count); + mfcc.set_dct_coefficient_count(params->dct_coefficient_count); + + mfcc.Initialize(spectrogram_channels, sample_rate); + + const float* spectrogram_flat = GetTensorData(inputWav); + float* output_flat = GetTensorData(output); + + for (int audio_channel = 0; audio_channel < audio_channels; ++audio_channel) { + for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples; + ++spectrogram_sample) { + const float* sample_data = + spectrogram_flat + + (audio_channel * spectrogram_samples * spectrogram_channels) + + (spectrogram_sample * spectrogram_channels); + std::vector mfcc_input(sample_data, + sample_data + spectrogram_channels); + std::vector mfcc_output; + mfcc.Compute(mfcc_input, &mfcc_output); + TF_LITE_ENSURE_EQ(context, params->dct_coefficient_count, + mfcc_output.size()); + float* output_data = output_flat + + (audio_channel * spectrogram_samples * + params->dct_coefficient_count) + + (spectrogram_sample * params->dct_coefficient_count); + for (int i = 0; i < params->dct_coefficient_count; ++i) { + output_data[i] = mfcc_output[i]; + } + } + } + + return kTfLiteOk; +} + +} // namespace mfcc + +TfLiteRegistration* Register_MFCC() { + static TfLiteRegistration r = {mfcc::Init, mfcc::Free, mfcc::Prepare, + mfcc::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0291ca8c1c58ea6ab3bb7c22bc436ed3404cba74 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_MFCC(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class BaseMfccOpModel : public SingleOpModel { + public: + BaseMfccOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("upper_frequency_limit", 4000); + fbb.Int("lower_frequency_limit", 20); + fbb.Int("filterbank_channel_count", 40); + fbb.Int("dct_coefficient_count", 13); + }); + fbb.Finish(); + SetCustomOp("Mfcc", fbb.GetBuffer(), Register_MFCC); + + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(MfccOpTest, SimpleTest) { + BaseMfccOpModel m({TensorType_FLOAT32, {1, 1, 513}}, {TensorType_INT32, {1}}, + {TensorType_FLOAT32, {}}); + + std::vector data(513); + for (int i = 0; i < data.size(); ++i) { + data[i] = i + 1; + } + m.PopulateTensor(m.input1(), 0, data.data(), + data.data() + data.size()); + m.PopulateTensor(m.input2(), {22050}); + + m.Invoke(); + + std::vector output_shape = m.GetOutputShape(); + EXPECT_THAT(output_shape, ElementsAre(1, 1, 13)); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {29.13970072, -6.41568601, -0.61903012, -0.96778652, -0.26819878, + -0.40907028, -0.15614748, -0.23203119, -0.10481487, -0.1543029, + -0.0769791, -0.10806114, -0.06047613}, + 1e-3))); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index aea6f8d9d34420363cc1045425f3d27b12af449e..0f98154b904b1f776016e6bbee3263027f815244 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -17,6 +17,14 @@ limitations under the License. namespace tflite { namespace ops { + +namespace custom { + +TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); +TfLiteRegistration* Register_MFCC(); + +} // namespace custom + namespace builtin { TfLiteRegistration* Register_RELU(); @@ -65,6 +73,10 @@ TfLiteRegistration* Register_STRIDED_SLICE(); TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); TfLiteRegistration* Register_LOG_SOFTMAX(); +TfLiteRegistration* Register_CAST(); +TfLiteRegistration* Register_DEQUANTIZE(); +TfLiteRegistration* Register_PRELU(); +TfLiteRegistration* Register_MAXIMUM(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -119,6 +131,16 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); + AddBuiltin(BuiltinOperator_CAST, Register_CAST()); + AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); + AddBuiltin(BuiltinOperator_PRELU, Register_PRELU()); + AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); + + // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that + // custom ops aren't always included by default. + AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); + AddCustom("AudioSpectrogram", + tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index f3e6ddc9f480e3863cac52157ae28b7329ee2088..438f70d3115130efe477a3ceeccd2e77108c979a 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -49,20 +49,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); int num_output_elements = 1; - int strech_dim = -1; + int stretch_dim = -1; for (int i = 0; i < params->num_dimensions; ++i) { int value = params->shape[i]; if (value == -1) { - TF_LITE_ENSURE_EQ(context, strech_dim, -1); - strech_dim = i; + TF_LITE_ENSURE_EQ(context, stretch_dim, -1); + stretch_dim = i; } else { num_output_elements *= value; output_size->data[i] = value; } } - if (strech_dim != -1) { - output_size->data[strech_dim] = num_input_elements / num_output_elements; - num_output_elements *= output_size->data[strech_dim]; + if (stretch_dim != -1) { + output_size->data[stretch_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_size->data[stretch_dim]; } TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index 0fbcf6e6aa311d2cac491336ee54ccf58bbda8fd..aecbd0399f7454045e8189072f45b695b0525204 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -60,7 +60,7 @@ TEST(ReshapeOpTest, TooManyDimensions) { TEST(ReshapeOpTest, TooManySpecialDimensions) { EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}), - "strech_dim != -1"); + "stretch_dim != -1"); } TEST(ReshapeOpTest, SimpleTest) { diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index fb1e11e0ca00abb36d7f29d562711a7bbcbeca1c..eb374d903182f46b40f5c80bfd769a19a5594742 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -48,7 +48,7 @@ struct StridedSliceContext { output = GetOutput(context, node, kOutputTensor); dims = NumDimensions(input); } - TfLiteStridedSliceParams* params; + const TfLiteStridedSliceParams* params; TfLiteTensor* input; TfLiteTensor* begin; TfLiteTensor* end; @@ -199,19 +199,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { strides.emplace_back(1); } - op_context.params->begin_mask = + int begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); - op_context.params->end_mask = - ReverseMaskBits(op_context.params->end_mask, op_context.dims); - op_context.params->shrink_axis_mask = + int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); + int shrink_axis_mask = ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice( \ - GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), op_context.params->begin_mask, \ - op_context.params->end_mask, op_context.params->shrink_axis_mask, \ - starts, stops, strides, GetTensorData(op_context.output), \ +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index 5cac04b38364958c5b0794c21742e8b592372ae9..5c98c5f43181fe75f35716dae5682113bde883ec 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -522,6 +522,28 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { EXPECT_TRUE(m.GetOutputShape().empty()); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } + +// This tests catches a very subtle bug that was fixed by cl/188403234. +TEST(StridedSliceOpTest, RunTwice) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + + auto setup_inputs = [&m]() { + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + }; + + setup_inputs(); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); + + setup_inputs(); + m.Invoke(); + // Prior to cl/188403234 this was {4, 5}. + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index 410585a2939c8e817c969015c96bbc266fc5641b..5acb3561817f2989d2db7fd0b0bf2dac5a100389 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -78,10 +78,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void EvalSubFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteSubParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, - TfLiteTensor* output) { +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteSubParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); @@ -107,10 +107,10 @@ void EvalSubFloat(TfLiteContext* context, TfLiteNode* node, } template -void EvalSubQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteSubParams* params, const OpData* data, - TfLiteTensor* input1, TfLiteTensor* input2, - TfLiteTensor* output) { +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteSubParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -169,11 +169,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalSubFloat(context, node, params, data, input1, input2, - output); + EvalFloat(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8) { - EvalSubQuantized(context, node, params, data, input1, input2, - output); + EvalQuantized(context, node, params, data, input1, input2, + output); } else { context->ReportError(context, "Inputs and outputs not all float|unit8 types."); diff --git a/tensorflow/contrib/lite/kernels/sub_test.cc b/tensorflow/contrib/lite/kernels/sub_test.cc index 1fd0ee2a0e8fe2c4e482bb097b1a51cf61d5b786..ff07aeec49dbfcc0e1f65df3d674d5ec30f1b54c 100644 --- a/tensorflow/contrib/lite/kernels/sub_test.cc +++ b/tensorflow/contrib/lite/kernels/sub_test.cc @@ -31,7 +31,7 @@ class BaseSubOpModel : public SingleOpModel { input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_Sub, BuiltinOptions_SubOptions, + SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions, CreateSubOptions(builder_, activation_type).Union()); BuildInterpreter({GetShape(input1_), GetShape(input2_)}); } @@ -76,7 +76,8 @@ TEST(FloatSubOpModel, NoActivation) { m.PopulateTensor(m.input1(), {-2.0, 0.2, 1.7, 0.5}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.8}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({-2.1, 0.0, 1.4, -0.3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-2.1, 0.0, 1.4, -0.3}))); } TEST(FloatSubOpModel, ActivationRELU_N1_TO_1) { @@ -86,7 +87,8 @@ TEST(FloatSubOpModel, ActivationRELU_N1_TO_1) { m.PopulateTensor(m.input1(), {-2.0, 0.2, 1.7, 0.5}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.8}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.0, 1.0, -0.3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-1.0, 0.0, 1.0, -0.3}))); } TEST(FloatSubOpModel, VariousInputShapes) { @@ -99,8 +101,9 @@ TEST(FloatSubOpModel, VariousInputShapes) { m.PopulateTensor(m.input1(), {-2.0, 0.2, 1.7, 0.5, -1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.8, -1.1, 0.1}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({-2.1, 0.0, 1.4, -0.3, 0.0, 1.9})) + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-2.1, 0.0, 1.4, -0.3, 0.0, 1.9}))) << "With shape number " << i; } } @@ -125,17 +128,13 @@ TEST(FloatSubOpModel, WithBroadcast) { TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { - {0.1, 0.2, 0.3, 0.4}, - {-0.2, 0.2, 0.4, 0.7}, - {-0.01, 0.2, 0.7, 0.3}}; + {0.1, 0.2, 0.3, 0.4}, {-0.2, 0.2, 0.4, 0.7}, {-0.01, 0.2, 0.7, 0.3}}; std::vector> inputs2 = { - {0.6, 0.4, 0.3, 0.1}, - {0.6, 0.4, 0.5, -0.2}, - {0.6, 0.4, -0.18, 0.5}}; + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.2}, {0.6, 0.4, -0.18, 0.5}}; std::vector> results = { - {-0.5, -0.2, 0.0, 0.3}, - {-0.8, -0.2, -0.1, 0.9}, - {-0.61, -0.2, 0.88, -0.2}}; + {-0.5, -0.2, 0.0, 0.3}, + {-0.8, -0.2, -0.1, 0.9}, + {-0.61, -0.2, 0.88, -0.2}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedSubOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, @@ -185,9 +184,8 @@ TEST(QuantizedSubOpModel, QuantizedVariousInputShapes) { m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); m.Invoke(); EXPECT_THAT(m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear({-2.1, -0.1, 0.4, - 0.3, 0.0, 1.9}, - kQuantizedTolerance))) + ElementsAreArray(ArrayFloatNear( + {-2.1, -0.1, 0.4, 0.3, 0.0, 1.9}, kQuantizedTolerance))) << "With shape number " << i; } } @@ -205,9 +203,8 @@ TEST(QuantizedSubOpModel, QuantizedWithBroadcast) { m.QuantizeAndPopulate(m.input2(), {0.7}); m.Invoke(); EXPECT_THAT(m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear({-2.7, -0.5, 0.0, - 0.1, 0.4, 1.3}, - kQuantizedTolerance))) + ElementsAreArray(ArrayFloatNear( + {-2.7, -0.5, 0.0, 0.1, 0.4, 1.3}, kQuantizedTolerance))) << "With shape number " << i; } } diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 373310bd87370a670a847cf5328633956028a850..0bb28b50b2a5e5a9fd803ecf1b0928026f63881e 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -141,8 +141,8 @@ void SingleOpModel::SetBuiltinOp(BuiltinOperator type, void SingleOpModel::SetCustomOp( const string& name, const std::vector& custom_option, - const std::function& registeration) { - custom_registrations_[name] = registeration; + const std::function& registration) { + custom_registrations_[name] = registration; opcodes_.push_back( CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data())); operators_.push_back(CreateOperator( diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 7d476ba1eaffbb24fb77390c0e71c32d60b6411e..a9064d54e7704d52eefa34f6bf446ec1cfe68fe1 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -39,10 +39,10 @@ inline std::vector Quantize(const std::vector& data, float scale, int32_t zero_point) { std::vector q; for (float f : data) { - q.push_back(std::max( + q.push_back(static_cast(std::max( std::numeric_limits::min(), - std::min(std::numeric_limits::max(), - static_cast(std::round(zero_point + (f / scale)))))); + std::min(std::numeric_limits::max(), + std::round(zero_point + (f / scale)))))); } return q; } diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/contrib/lite/kernels/test_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e10e89061213b6fcabd404310893dd97a51d83f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +TEST(TestUtilTest, QuantizeVector) { + std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; + auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 1, 1, 255}; + EXPECT_THAT(q_data, ElementsAreArray(expected)); +} + +TEST(TestUtilTest, QuantizeVectorScalingDown) { + std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; + auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 0, 0, 100}; + EXPECT_THAT(q_data, ElementsAreArray(expected)); +} + +TEST(TestUtilTest, QuantizeVectorScalingUp) { + std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; + auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 5, 10, 255}; + EXPECT_THAT(q_data, ElementsAreArray(expected)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 9cdb58714edb5fee771fc45f3c53a570f8fb28d1..42941a97db70adb37c20500c8f9438adfea25389 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -359,7 +360,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. + // check the existence of only one to get the condition. const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); @@ -380,135 +381,57 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; } + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; + const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; + const float* input_to_output_weights_ptr = input_to_output_weights->data.f; + const float* recurrent_to_forget_weights_ptr = + recurrent_to_forget_weights->data.f; + const float* recurrent_to_cell_weights_ptr = + recurrent_to_cell_weights->data.f; + const float* recurrent_to_output_weights_ptr = + recurrent_to_output_weights->data.f; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + for (int t = 0; t < max_time; t++) { - const float* input_ptr_time = input->data.f + t * n_batch * n_input; - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, - n_batch, output_gate_scratch); - - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights->data.f, n_cell, n_input, input_ptr_time, - n_batch, input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time, - n_batch, forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights->data.f, n_cell, n_input, input_ptr_time, - n_batch, output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, input_gate_scratch, - /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, forget_gate_scratch, - /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, output_gate_scratch, - /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, - cell_state->data.f, n_batch * n_cell, - cell_state->data.f); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, - cell_state->data.f); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, - cell_state->data.f); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, - params->cell_clip, cell_state->data.f); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, - output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights != nullptr); - const bool use_projection_bias = (projection_bias != nullptr); - float* output_ptr_time = output->data.f + t * n_batch * n_output; - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, - n_batch, output_ptr_time); - } else { - tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights->data.f, n_output, n_cell, output_gate_scratch, - n_batch, output_ptr_time, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_time, n_batch * n_output, - params->proj_clip, output_ptr_time); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_time); - } - tensor_utils::CopyVector(output_ptr_time, n_batch * n_output, - output_state->data.f); + const float* input_ptr_batch = input->data.f + t * n_batch * n_input; + float* output_ptr_batch = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, + input_to_forget_weights_ptr, input_to_cell_weights_ptr, + input_to_output_weights_ptr, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr, + recurrent_to_output_weights_ptr, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, n_output, output_state_ptr, + cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, output_ptr_batch); } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h index 5cd6c208500f3ea84ab8146f7f136e8b7851ff03..0294ec815c4820d41361b8cd4a814b74c3c1d770 100644 --- a/tensorflow/contrib/lite/memory_planner.h +++ b/tensorflow/contrib/lite/memory_planner.h @@ -34,8 +34,8 @@ class MemoryPlanner { // [first_node, last_node]. virtual TfLiteStatus ExecuteAllocations(int first_node, int last_node) = 0; - // Invalidates allocations made earliers. This is called when tensors sizes - // have change. All planned allocations remain, but can't be used until + // Invalidates allocations made earlier. This is called when tensors sizes + // have changed. All planned allocations remain, but can't be used until // ExecuteAllocations() is called. virtual TfLiteStatus ResetAllocations() = 0; }; diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 725f2838c574fcc2ba389401f92575279ebc144c..791d1378f393594ceb6f1fcec7cc5aadaa81dab3 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -32,11 +32,46 @@ namespace tflite { const char* kEmptyTensorName = ""; +// Loads a model from `filename`. If `mmap_file` is true then use mmap, +// otherwise make a copy of the model in a buffer. +std::unique_ptr GetAllocationFromFile(const char* filename, + bool mmap_file, + ErrorReporter* error_reporter, + bool use_nnapi) { + std::unique_ptr allocation; + if (mmap_file) { + if (use_nnapi && NNAPIExists()) + allocation.reset(new NNAPIAllocation(filename, error_reporter)); + else + allocation.reset(new MMAPAllocation(filename, error_reporter)); + } else { + allocation.reset(new FileCopyAllocation(filename, error_reporter)); + } + return allocation; +} + std::unique_ptr FlatBufferModel::BuildFromFile( const char* filename, ErrorReporter* error_reporter) { std::unique_ptr model; - model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter, - /*use_nnapi=*/true)); + auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, + error_reporter, /*use_nnapi=*/true); + model.reset(new FlatBufferModel(allocation.release(), error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::VerifyAndBuildFromFile( + const char* filename, TfLiteVerifier* verifier, + ErrorReporter* error_reporter) { + std::unique_ptr model; + auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, + error_reporter, /*use_nnapi=*/true); + if (verifier && + !verifier->Verify(static_cast(allocation->base()), + allocation->bytes(), error_reporter)) { + return model; + } + model.reset(new FlatBufferModel(allocation.release(), error_reporter)); if (!model->initialized()) model.reset(); return model; } @@ -44,7 +79,9 @@ std::unique_ptr FlatBufferModel::BuildFromFile( std::unique_ptr FlatBufferModel::BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { std::unique_ptr model; - model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter)); + Allocation* allocation = + new MemoryAllocation(buffer, buffer_size, error_reporter); + model.reset(new FlatBufferModel(allocation, error_reporter)); if (!model->initialized()) model.reset(); return model; } @@ -57,23 +94,6 @@ std::unique_ptr FlatBufferModel::BuildFromModel( return model; } -FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, - ErrorReporter* error_reporter, bool use_nnapi) - : error_reporter_(error_reporter ? error_reporter - : DefaultErrorReporter()) { - if (mmap_file) { - if (use_nnapi && NNAPIExists()) - allocation_ = new NNAPIAllocation(filename, error_reporter); - else - allocation_ = new MMAPAllocation(filename, error_reporter); - } else { - allocation_ = new FileCopyAllocation(filename, error_reporter); - } - if (!allocation_->valid() || !CheckModelIdentifier()) return; - - model_ = ::tflite::GetModel(allocation_->base()); -} - bool FlatBufferModel::CheckModelIdentifier() const { if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); @@ -85,21 +105,21 @@ bool FlatBufferModel::CheckModelIdentifier() const { return true; } -FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, +FlatBufferModel::FlatBufferModel(const Model* model, ErrorReporter* error_reporter) : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) { - allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); - if (!allocation_->valid()) return; - - model_ = ::tflite::GetModel(allocation_->base()); + model_ = model; } -FlatBufferModel::FlatBufferModel(const Model* model, +FlatBufferModel::FlatBufferModel(Allocation* allocation, ErrorReporter* error_reporter) : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) { - model_ = model; + allocation_ = allocation; + if (!allocation_->valid() || !CheckModelIdentifier()) return; + + model_ = ::tflite::GetModel(allocation_->base()); } FlatBufferModel::~FlatBufferModel() { delete allocation_; } @@ -287,6 +307,9 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_EXP: case BuiltinOperator_TOPK_V2: case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_CAST: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_PRELU: break; case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = @@ -574,6 +597,9 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_MAXIMUM: { + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); @@ -657,9 +683,27 @@ TfLiteStatus InterpreterBuilder::ParseTensors( // but we really only support one value for the whole tensor. // TODO(aselle): This breaks as well if these are nullptr's. // TODO(aselle): This assumes non per-channel quantization. - if (q_params->scale()) quantization.scale = q_params->scale()->Get(0); - if (q_params->zero_point()) + + if (q_params->scale()) { + if (q_params->scale()->size() != 1) { + error_reporter_->Report( + "QuantizationParam has %d scale values (only 1 is supported).", + q_params->scale()->size()); + return kTfLiteError; + } + quantization.scale = q_params->scale()->Get(0); + } + + if (q_params->zero_point()) { + if (q_params->zero_point()->size() != 1) { + error_reporter_->Report( + "QuantizationParam has %d zero_point values" + " (only 1 is supported).", + q_params->zero_point()->size()); + return kTfLiteError; + } quantization.zero_point = q_params->zero_point()->Get(0); + } } TfLiteType type; @@ -737,6 +781,11 @@ TfLiteStatus InterpreterBuilder::ParseTensors( TfLiteStatus InterpreterBuilder::operator()( std::unique_ptr* interpreter) { + return operator()(interpreter, /*num_threads=*/-1); +} + +TfLiteStatus InterpreterBuilder::operator()( + std::unique_ptr* interpreter, int num_threads) { if (!interpreter) { error_reporter_->Report( "Null output pointer passed to InterpreterBuilder."); @@ -791,9 +840,8 @@ TfLiteStatus InterpreterBuilder::operator()( if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { return cleanup_and_error(); } - - (**interpreter).set_model(model_); - + // Set num threads + (**interpreter).SetNumThreads(num_threads); // Parse inputs/outputs (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index a467df5bb4eee3f6ce814512cb8b74bf09a6a4e7..036dc46e03f565c40791aee55d4158cef5c832e0 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -41,6 +41,17 @@ limitations under the License. namespace tflite { +// Abstract interface that verifies whether a given model is legit. +// It facilitates the use-case to verify and build a model without loading it +// twice. +class TfLiteVerifier { + public: + // Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + // An RAII object that represents a read-only tflite model, copied from disk, // or mmapped. This uses flatbuffers as the serialization format. class FlatBufferModel { @@ -50,6 +61,12 @@ class FlatBufferModel { const char* filename, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Verifies whether the content of the file is legit, then builds a model + // based on the file. Returns a nullptr in case of failure. + static std::unique_ptr VerifyAndBuildFromFile( + const char* filename, TfLiteVerifier* verifier = nullptr, + ErrorReporter* error_reporter = DefaultErrorReporter()); + // Builds a model based on a pre-loaded flatbuffer. The caller retains // ownership of the buffer and should keep it alive until the returned object // is destroyed. Returns a nullptr in case of failure. @@ -64,7 +81,7 @@ class FlatBufferModel { const tflite::Model* model_spec, ErrorReporter* error_reporter = DefaultErrorReporter()); - // Releases memory or unmaps mmaped meory. + // Releases memory or unmaps mmaped memory. ~FlatBufferModel(); // Copying or assignment is disallowed to simplify ownership semantics. @@ -82,23 +99,9 @@ class FlatBufferModel { bool CheckModelIdentifier() const; private: - // Loads a model from `filename`. If `mmap_file` is true then use mmap, - // otherwise make a copy of the model in a buffer. - // - // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be - // used. - explicit FlatBufferModel( - const char* filename, bool mmap_file = true, - ErrorReporter* error_reporter = DefaultErrorReporter(), - bool use_nnapi = false); - - // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has - // to remain alive and unchanged until the end of this flatbuffermodel's - // lifetime. - // - // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be - // used. - FlatBufferModel(const char* ptr, size_t num_bytes, + // Loads a model from a given allocation. FlatBufferModel will take over the + // ownership of `allocation`, and delete it in desctructor. + FlatBufferModel(Allocation* allocation, ErrorReporter* error_reporter = DefaultErrorReporter()); // Loads a model from Model flatbuffer. The `model` has to remain alive and @@ -151,6 +154,8 @@ class InterpreterBuilder { InterpreterBuilder(const InterpreterBuilder&) = delete; InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; TfLiteStatus operator()(std::unique_ptr* interpreter); + TfLiteStatus operator()(std::unique_ptr* interpreter, + int num_threads); private: TfLiteStatus BuildLocalIndexToRegistrationMapping(); diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 66f22fd66a9ae0d35553a1f780ef73a5c5994c99..ae6c1ece18963f11f48a6f07bea4065ce39687e0 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -209,6 +209,38 @@ TEST(BasicFlatBufferModel, TestNullModel) { ASSERT_EQ(interpreter.get(), nullptr); } +// Mocks the verifier by setting the result in ctor. +class FakeVerifier : public tflite::TfLiteVerifier { + public: + explicit FakeVerifier(bool result) : result_(result) {} + bool Verify(const char* data, int length, + tflite::ErrorReporter* reporter) override { + return result_; + } + + private: + bool result_; +}; + +TEST(BasicFlatBufferModel, TestWithTrueVerifier) { + FakeVerifier verifier(true); + ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin", + &verifier)); +} + +TEST(BasicFlatBufferModel, TestWithFalseVerifier) { + FakeVerifier verifier(false); + ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin", + &verifier)); +} + +TEST(BasicFlatBufferModel, TestWithNullVerifier) { + ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin", nullptr)); +} + struct TestErrorReporter : public ErrorReporter { int Report(const char* format, va_list args) override { calls++; diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h index 76032771af2c8e099aed498b2071816646f3b606..bd49d327c995ef53dc6cf9f8301ab749c925b2c7 100644 --- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -569,7 +569,7 @@ enum { ANEURALNETWORKS_LOGISTIC = 14, /** - * Projects an input to a bit vector via locality senstive hashing. + * Projects an input to a bit vector via locality sensitive hashing. * * Inputs: * * 0: Hash functions. Dim.size == 2, DataType: Float. diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index e631ffd845d3b31232070b935c12aa8a2e8ce05e..decaf9f160ad35b66f0ed56d0840634c610e4246 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -346,7 +346,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: case tflite::BuiltinOperator_LOG_SOFTMAX: + case tflite::BuiltinOperator_DEQUANTIZE: case tflite::BuiltinOperator_DELEGATE: + case tflite::BuiltinOperator_CAST: + case tflite::BuiltinOperator_PRELU: + case tflite::BuiltinOperator_MAXIMUM: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 82feae0f0041997949212613c654a5695f468d56..411d5c0d272c07b710fe987d25a79f2614bbab4e 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -4,6 +4,38 @@ package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "py_test") +filegroup( + name = "interpreter_test_data", + srcs = glob(["**/testdata/*"]), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "interpreter", + srcs = [ + "interpreter.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", + ], +) + +py_test( + name = "interpreter_test", + srcs = ["interpreter_test.py"], + data = [":interpreter_test_data"], + srcs_version = "PY2AND3", + tags = ["no_oss"], + deps = [ + ":interpreter", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "lite", srcs = ["lite.py"], @@ -37,7 +69,10 @@ py_test( name = "lite_test", srcs = ["lite_test.py"], srcs_version = "PY2AND3", - tags = ["no_oss"], + tags = [ + "no-internal-py3", + "no_oss", + ], deps = [ ":lite", ":op_hint", @@ -49,6 +84,41 @@ py_test( ], ) +py_binary( + name = "convert_saved_model", + srcs = ["convert_saved_model.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + "//tensorflow/contrib/saved_model:saved_model_py", + "//tensorflow/python:graph_util", + "//tensorflow/python/tools:freeze_graph_lib", + ], +) + +py_test( + name = "convert_saved_model_test", + srcs = ["convert_saved_model_test.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":convert_saved_model", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/saved_model", + ], +) + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "tf_lite_py_pip", + deps = [ + ":convert_saved_model", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b5ef488ec1feb455b2c8d5d1c4005c3b2f60d6 --- /dev/null +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -0,0 +1,262 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""TensorFlow Lite flatbuffer generation from saved_models. + +Example: + +bazel run third_party/tensorflow/contrib/lite/python:convert_saved_model -- \ + --saved_model_dir=/tmp/test_saved_model/1519865537 \ + --output_tflite=/tmp/test.lite + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.saved_model.python.saved_model import reader +from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.core.framework import types_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework import ops +from tensorflow.python.platform import app +from tensorflow.python.platform import flags +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants + +flags.DEFINE_string("saved_model_dir", "", "Saved model directory to convert.") +flags.DEFINE_string("output_tflite", None, "File path to write flatbuffer.") +flags.DEFINE_string("output_arrays", None, + "List of output tensor names, the default value is None, " + "which means the conversion will keep all outputs.") +flags.DEFINE_integer("batch_size", 1, + "If input tensor shape has None at first dimension, " + "e.g. (None,224,224,3), replace None with batch_size.") +flags.DEFINE_string("tag_set", tag_constants.SERVING, + "Group of tag(s) of the MetaGraphDef in the saved_model, " + "in string format, separated by ','. For tag-set contains " + "multiple tags, all tags must be passed in.") +flags.DEFINE_string("signature_key", + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + "This is signature key to extract inputs, outputs.") + + +def log_tensor_details(tensor_info): + """Log tensor details: name, shape, and type.""" + for key in tensor_info: + val = tensor_info[key] + dtype = types_pb2.DataType.Name(val.dtype) + if val.tensor_shape.unknown_rank: + shape = "unknown_rank" + else: + dims = [str(dim.size) for dim in val.tensor_shape.dim] + shape = "({})".format(", ".join(dims)) + + logging.info("Tensor's key in saved_model's tensor_map: %s", key) + logging.info(" tensor name: %s, shape: %s, type: %s", val.name, shape, + dtype) + + +def get_meta_graph_def(saved_model_dir, tag_set): + """Validate saved_model and extract MetaGraphDef. + + Args: + saved_model_dir: saved_model path to convert. + tag_set: Set of tag(s) of the MetaGraphDef to load. + + Returns: + The meta_graph_def used for tflite conversion. + + Raises: + ValueError: No valid MetaGraphDef for given tag_set. + """ + saved_model = reader.read_saved_model(saved_model_dir) + tag_sets = [] + result_meta_graph_def = None + for meta_graph_def in saved_model.meta_graphs: + meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) + tag_sets.append(meta_graph_tag_set) + if meta_graph_tag_set == tag_set: + result_meta_graph_def = meta_graph_def + logging.info("The given saved_model contains the following tags: %s", + tag_sets) + if result_meta_graph_def is not None: + return result_meta_graph_def + else: + raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " + "values are '{}'. ".format(tag_set, tag_sets)) + + +def get_signature_def(meta_graph, signature_key): + """Get the signature def from meta_graph with given signature_key. + + Args: + meta_graph: meta_graph_def. + signature_key: signature_def in the meta_graph_def. + + Returns: + The signature_def used for tflite conversion. + + Raises: + ValueError: Given signature_key is not valid for this meta_graph. + """ + signature_def_map = meta_graph.signature_def + signature_def_keys = set(signature_def_map.keys()) + logging.info( + "The given saved_model MetaGraphDef contains SignatureDefs with the " + "following keys: %s", signature_def_keys) + if signature_key not in signature_def_keys: + raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible " + "values are '{}'. ".format(signature_key, + signature_def_keys)) + signature_def = signature_def_utils.get_signature_def_by_key( + meta_graph, signature_key) + return signature_def + + +def get_inputs_outputs(signature_def): + """Get inputs and outputs from signature def. + + Args: + signature_def: signatuer def in the meta_graph_def for conversion. + + Returns: + The inputs and outputs in the graph for conversion. + """ + inputs_tensor_info = signature_def.inputs + outputs_tensor_info = signature_def.outputs + logging.info("input tensors info: ") + log_tensor_details(inputs_tensor_info) + logging.info("output tensors info: ") + log_tensor_details(outputs_tensor_info) + + def gather_names(tensor_info): + return [tensor_info[key].name for key in tensor_info] + + inputs = gather_names(inputs_tensor_info) + outputs = gather_names(outputs_tensor_info) + return inputs, outputs + + +def convert(saved_model_dir, + output_tflite=None, + output_arrays=None, + tag_set=None, + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + batch_size=1): + """Convert a saved_model to tflite flatbuffer. + + Args: + saved_model_dir: Saved model directory to convert. + output_tflite: File path to write result flatbuffer. + output_arrays: List of output tensor names, the default value is None, which + means conversion keeps all output tensors. This is also used to filter + tensors that are from Op currently not supported in tflite, e.g., Argmax). + tag_set: This is the set of tags to get meta_graph_def in saved_model. + signature_key: This is the signature key to extract inputs, outputs. + batch_size: If input tensor shape has None at first dimension, + e.g. (None,224,224,3), replace None with batch_size. + + Returns: + The converted data. For example if tflite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + ValueError: If tag_set does not indicate any meta_graph_def in saved_model, + or signature_key is not in relevant meta_graph_def, + or input shape has None beyond 1st dimension, e.g., (1,None, None, 3), + or given output_arrays are not valid causing empty outputs. + """ + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + + meta_graph = get_meta_graph_def(saved_model_dir, tag_set) + signature_def = get_signature_def(meta_graph, signature_key) + inputs, outputs = get_inputs_outputs(signature_def) + + graph = ops.Graph() + with session.Session(graph=graph) as sess: + + loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) + + in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs] + + # Users can use output_arrays to filter output tensors for conversion. + # If output_arrays is None, we keep all output tensors. In future, we may + # use tflite supported Op list and check whether op is custom Op to + # automatically filter output arrays. + # TODO(zhixianyan): Use tflite supported Op list to filter outputs. + if output_arrays is not None: + output_arrays = output_arrays.split(",") + out_tensors = [ + graph.get_tensor_by_name(output) + for output in outputs + if output.split(":")[0] in output_arrays + ] + else: + out_tensors = [graph.get_tensor_by_name(output) for output in outputs] + + output_names = [node.split(":")[0] for node in outputs] + + if not out_tensors: + raise ValueError( + "No valid output tensors for '{}', possible values are '{}'".format( + output_arrays, output_names)) + + frozen_graph_def = tf_graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), output_names) + + # Toco requires fully defined tensor shape, for input tensor with None in + # their shape, e.g., (None, 224, 224, 3), we need to replace first None with + # a given batch size. For shape with more None, e.g. (None, None, None, 3), + # still be able to replace and convert, but require further investigation. + # TODO(zhixianyan): Add supports for input tensor with more None in shape. + for i in range(len(in_tensors)): + shape = in_tensors[i].get_shape().as_list() + if shape[0] is None: + shape[0] = batch_size + if None in shape[1:]: + raise ValueError( + "Only support None shape at 1st dim as batch_size. But tensor " + "'{}' 's shape '{}' has None at other dimension. ".format( + inputs[i], shape)) + in_tensors[i].set_shape(shape) + + result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors) + + if output_tflite is not None: + with gfile.Open(output_tflite, "wb") as f: + f.write(result) + logging.info("Successfully converted to: %s", output_tflite) + + return result + + +def main(_): + convert( + saved_model_dir=flags.FLAGS.saved_model_dir, + output_tflite=flags.FLAGS.output_tflite, + output_arrays=flags.FLAGS.output_arrays, + batch_size=flags.FLAGS.batch_size, + tag_set=set(flags.FLAGS.tag_set.split(",")), + signature_key=flags.FLAGS.signature_key) + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d87fbeb91cc3d2779c0ae01aff488f88bd340c1c --- /dev/null +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -0,0 +1,276 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TF Lite SavedModel Conversion test cases. + + - test on generated saved_models from simple graphs (sanity check) + - test mnist savedmodel generated on-the-fly + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from tensorflow.contrib.lite.python import convert_saved_model +from tensorflow.python import estimator +from tensorflow.python import keras +from tensorflow.python import layers +from tensorflow.python import losses +from tensorflow.python import nn +from tensorflow.python import saved_model +from tensorflow.python import train +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): + + def _createSimpleSavedModel(self, shape): + """Create a simple savedmodel on the fly.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") + with session.Session() as sess: + in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + inputs = {"x": in_tensor} + outputs = {"y": out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def testSimpleSavedModel(self): + """Test a simple savedmodel created on the fly.""" + # Create a simple savedmodel + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Convert to tflite + result = convert_saved_model.convert(saved_model_dir=saved_model_dir) + self.assertTrue(result) + + def testSimpleSavedModelWithNoneBatchSizeInShape(self): + """Test a simple savedmodel, with None in input tensor's shape.""" + saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) + result = convert_saved_model.convert(saved_model_dir=saved_model_dir) + self.assertTrue(result) + + def testSimpleSavedModelWithMoreNoneInShape(self): + """Test a simple savedmodel, fail as more None in input shape.""" + saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3]) + # Convert to tflite: this should raise ValueError, as 3rd dim is None. + with self.assertRaises(ValueError): + convert_saved_model.convert(saved_model_dir=saved_model_dir) + + def testSimpleSavedModelWithWrongSignatureKey(self): + """Test a simple savedmodel, fail as given signature is invalid.""" + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Convert to tflite: this should raise ValueError, as + # signature_key does not exit in the saved_model. + with self.assertRaises(ValueError): + convert_saved_model.convert( + saved_model_dir=saved_model_dir, signature_key="wrong-key") + + def testSimpleSavedModelWithWrongOutputArray(self): + """Test a simple savedmodel, fail as given output_arrays is invalid.""" + # Create a simple savedmodel + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Convert to tflite: this should raise ValueError, as + # output_arrays is not valid for the saved_model. + with self.assertRaises(ValueError): + convert_saved_model.convert( + saved_model_dir=saved_model_dir, output_arrays="wrong-output") + + def testMultipleMetaGraphDef(self): + """Test saved model with multiple MetaGraphDef.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") + builder = saved_model.builder.SavedModelBuilder(saved_model_dir) + with session.Session(graph=ops.Graph()) as sess: + # MetaGraphDef 1 + in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sig_input_tensor = saved_model.utils.build_tensor_info(in_tensor) + sig_input_tensor_signature = {"x": sig_input_tensor} + sig_output_tensor = saved_model.utils.build_tensor_info(out_tensor) + sig_output_tensor_signature = {"y": sig_output_tensor} + predict_signature_def = ( + saved_model.signature_def_utils.build_signature_def( + sig_input_tensor_signature, sig_output_tensor_signature, + saved_model.signature_constants.PREDICT_METHOD_NAME)) + signature_def_map = { + saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + predict_signature_def + } + builder.add_meta_graph_and_variables( + sess, + tags=[saved_model.tag_constants.SERVING, "additional_test_tag"], + signature_def_map=signature_def_map) + # MetaGraphDef 2 + builder.add_meta_graph(tags=["tflite"]) + builder.save(True) + + # Convert to tflite + convert_saved_model.convert( + saved_model_dir=saved_model_dir, + tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) + + +class Model(keras.Model): + """Model to recognize digits in the MNIST dataset. + + Train and export savedmodel, used for testOnflyTrainMnistSavedModel + + Network structure is equivalent to: + https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py + and + https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py + + But written as a ops.keras.Model using the layers API. + """ + + def __init__(self, data_format): + """Creates a model for classifying a hand-written digit. + + Args: + data_format: Either "channels_first" or "channels_last". + "channels_first" is typically faster on GPUs while "channels_last" is + typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Model, self).__init__() + self._input_shape = [-1, 28, 28, 1] + + self.conv1 = layers.Conv2D( + 32, 5, padding="same", data_format=data_format, activation=nn.relu) + self.conv2 = layers.Conv2D( + 64, 5, padding="same", data_format=data_format, activation=nn.relu) + self.fc1 = layers.Dense(1024, activation=nn.relu) + self.fc2 = layers.Dense(10) + self.dropout = layers.Dropout(0.4) + self.max_pool2d = layers.MaxPooling2D( + (2, 2), (2, 2), padding="same", data_format=data_format) + + def __call__(self, inputs, training): + """Add operations to classify a batch of input images. + + Args: + inputs: A Tensor representing a batch of input images. + training: A boolean. Set to True to add operations required only when + training the classifier. + + Returns: + A logits Tensor with shape [, 10]. + """ + y = array_ops.reshape(inputs, self._input_shape) + y = self.conv1(y) + y = self.max_pool2d(y) + y = self.conv2(y) + y = self.max_pool2d(y) + y = layers.flatten(y) + y = self.fc1(y) + y = self.dropout(y, training=training) + return self.fc2(y) + + +def model_fn(features, labels, mode, params): + """The model_fn argument for creating an Estimator.""" + model = Model(params["data_format"]) + image = features + if isinstance(image, dict): + image = features["image"] + + if mode == estimator.ModeKeys.PREDICT: + logits = model(image, training=False) + predictions = { + "classes": math_ops.argmax(logits, axis=1), + "probabilities": nn.softmax(logits), + } + return estimator.EstimatorSpec( + mode=estimator.ModeKeys.PREDICT, + predictions=predictions, + export_outputs={ + "classify": estimator.export.PredictOutput(predictions) + }) + + elif mode == estimator.ModeKeys.TRAIN: + optimizer = train.AdamOptimizer(learning_rate=1e-4) + + logits = model(image, training=True) + loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + return estimator.EstimatorSpec( + mode=estimator.ModeKeys.TRAIN, + loss=loss, + train_op=optimizer.minimize(loss, train.get_or_create_global_step())) + + elif mode == estimator.ModeKeys.EVAL: + logits = model(image, training=False) + loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + return estimator.EstimatorSpec( + mode=estimator.ModeKeys.EVAL, + loss=loss, + eval_metric_ops={ + "accuracy": + ops.metrics.accuracy( + labels=labels, predictions=math_ops.argmax(logits, axis=1)), + }) + + +def dummy_input_fn(): + image = random_ops.random_uniform([100, 784]) + labels = random_ops.random_uniform([100, 1], maxval=9, dtype=dtypes.int32) + return image, labels + + +class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): + + def testTrainedMnistSavedModel(self): + """Test mnist savedmodel, trained with dummy data and small steps.""" + # Build classifier + classifier = estimator.Estimator( + model_fn=model_fn, + params={ + "data_format": "channels_last" # tflite format + }) + + # Train and pred for serving + classifier.train(input_fn=dummy_input_fn, steps=2) + image = array_ops.placeholder(dtypes.float32, [None, 28, 28]) + pred_input_fn = estimator.export.build_raw_serving_input_receiver_fn({ + "image": image, + }) + + # Export savedmodel + saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel") + classifier.export_savedmodel(saved_model_dir, pred_input_fn) + + # Convert to tflite and test output + saved_model_name = os.listdir(saved_model_dir)[0] + saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) + output_tflite = os.path.join(saved_model_dir, + saved_model_final_dir + ".lite") + # TODO(zhixianyan): no need to limit output_arrays to `Softmax' + # once b/74205001 fixed and argmax implemented in tflite. + result = convert_saved_model.convert( + saved_model_dir=saved_model_final_dir, + output_arrays="Softmax", + output_tflite=output_tflite) + + self.assertTrue(result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..b8638007f7e49737726d9939a00e8cb1d6a41281 --- /dev/null +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -0,0 +1,151 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python TF-Lite interpreter.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.lite.python.interpreter_wrapper import tensorflow_wrap_interpreter_wrapper as interpreter_wrapper + + +class Interpreter(object): + """Interpreter inferace for TF-Lite Models.""" + + def __init__(self, model_path=None, model_content=None): + """Constructor. + + Args: + model_path: Path to TF-Lite Flatbuffer file. + model_content: Content of model. + + Raises: + ValueError: If the interpreter was unable to create. + """ + if model_path and not model_content: + self._interpreter = ( + interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile( + model_path)) + if not self._interpreter: + raise ValueError('Failed to open {}'.format(model_path)) + elif model_content and not model_path: + self._interpreter = ( + interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( + model_content, len(model_content))) + if not self._interpreter: + raise ValueError( + 'Failed to create model from {} bytes'.format(len(model_content))) + elif not model_path and not model_path: + raise ValueError('`model_path` or `model_content` must be specified.') + else: + raise ValueError('Can\'t both provide `model_path` and `model_content`') + + def allocate_tensors(self): + if not self._interpreter.AllocateTensors(): + raise ValueError('Failed to allocate tensors') + + def _get_tensor_details(self, tensor_index): + """Gets tensor details. + + Args: + tensor_index: Tensor index of tensor to query. + + Returns: + a dictionary containing the name, index, shape and type of the tensor. + + Raises: + ValueError: If tensor_index is invalid. + """ + tensor_index = int(tensor_index) + tensor_name = self._interpreter.TensorName(tensor_index) + tensor_size = self._interpreter.TensorSize(tensor_index) + tensor_type = self._interpreter.TensorType(tensor_index) + tensor_quantization = self._interpreter.TensorQuantization(tensor_index) + + if not tensor_name or not tensor_type: + raise ValueError('Could not get tensor details') + + details = { + 'name': tensor_name, + 'index': tensor_index, + 'shape': tensor_size, + 'dtype': tensor_type, + 'quantization': tensor_quantization, + } + + return details + + def get_input_details(self): + """Gets model input details. + + Returns: + A list of input details. + """ + return [ + self._get_tensor_details(i) for i in self._interpreter.InputIndices() + ] + + def set_tensor(self, tensor_index, value): + """Sets the value of the input. + + Args: + tensor_index: Tensor index of tensor to set. This value can be gotten from + the 'index' field in get_input_details. + value: Value of tensor to set. + + Raises: + ValueError: If the interpreter could not set the tensor. + """ + if not self._interpreter.SetTensor(tensor_index, value): + raise ValueError('Failed to set tensor') + + def resize_tensor_input(self, input_index, tensor_size): + """Resizes an input tensor. + + Args: + input_index: Tensor index of input to set. This value can be gotten from + the 'index' field in get_input_details. + tensor_size: The tensor_shape to resize the input to. + + Raises: + ValueError: If the interpreter could not resize the input tensor. + """ + if not self.ResizeInputTensor.SetTensor(input_index, tensor_size): + raise ValueError('Failed to set input') + + def get_output_details(self): + """Gets model output details. + + Returns: + A list of output details. + """ + return [ + self._get_tensor_details(i) for i in self._interpreter.OutputIndices() + ] + + def get_tensor(self, tensor_index): + """Sets the value of the input. + + Args: + tensor_index: Tensor index of tensor to get. This value can be gotten from + the 'index' field in get_output_details. + + Returns: + a numpy array. + """ + return self._interpreter.GetTensor(tensor_index) + + def invoke(self): + if not self._interpreter.Invoke(): + raise ValueError('Failed to invoke TFLite model') diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2386f5263f24e1e034015ec6880e71f0608c7c --- /dev/null +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -0,0 +1,92 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Python Interface: Sanity check.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import numpy as np + +from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper +from tensorflow.python.framework import test_util +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class InterpreterTest(test_util.TensorFlowTestCase): + + def testFloat(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 4] == input_details[0]['shape']).all()) + self.assertEqual((0.0, 0), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0.0, 0), output_details[0]['quantization']) + + test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) + expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) + + def testUint8(self): + model_path = resource_loader.get_path_to_datafile( + 'testdata/permute_uint8.tflite') + with io.open(model_path, 'rb') as model_file: + data = model_file.read() + + interpreter = interpreter_wrapper.Interpreter(model_content=data) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 4] == input_details[0]['shape']).all()) + self.assertEqual((1.0, 0), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((1.0, 0), output_details[0]['quantization']) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) + expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..453eda6e7345762666917fd501b69c7181c349e8 --- /dev/null +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -0,0 +1,32 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +cc_library( + name = "interpreter_wrapper_lib", + srcs = ["interpreter_wrapper.cc"], + hdrs = ["interpreter_wrapper.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/core:lib", + "//tensorflow/python:numpy_lib", + "//util/python:python_headers", + "@com_google_absl//absl/memory", + ], +) + +tf_py_wrap_cc( + name = "tensorflow_wrap_interpreter_wrapper", + srcs = [ + "interpreter_wrapper.i", + ], + deps = [ + ":interpreter_wrapper_lib", + "//util/python:python_headers", + ], +) diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..35ad226b78c906f0819afd5b029a1a0d438d69af --- /dev/null +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -0,0 +1,337 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/python/lib/core/numpy.h" + +#if PY_MAJOR_VERSION >= 3 +#define PY_TO_CPPSTRING PyBytes_AsStringAndSize +#define CPP_TO_PYSTRING PyBytes_FromStringAndSize +#else +#define PY_TO_CPPSTRING PyString_AsStringAndSize +#define CPP_TO_PYSTRING PyString_FromStringAndSize +#endif + +namespace tflite { +namespace interpreter_wrapper { + +namespace { +std::unique_ptr CreateInterpreter( + const tflite::FlatBufferModel* model, + const tflite::ops::builtin::BuiltinOpResolver& resolver) { + if (!model) { + return nullptr; + } + + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (interpreter) { + for (const int input_index : interpreter->inputs()) { + const TfLiteTensor* tensor = interpreter->tensor(input_index); + CHECK(tensor); + const TfLiteIntArray* dims = tensor->dims; + if (!dims) { + continue; + } + + std::vector input_dims(dims->data, dims->data + dims->size); + interpreter->ResizeInputTensor(input_index, input_dims); + } + } + return interpreter; +} + +int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { + switch (tf_lite_type) { + case kTfLiteFloat32: + return NPY_FLOAT32; + case kTfLiteInt32: + return NPY_INT32; + case kTfLiteUInt8: + return NPY_UINT8; + case kTfLiteInt64: + return NPY_INT64; + case kTfLiteString: + return NPY_OBJECT; + case kTfLiteNoType: + return -1; + } + LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type; + return -1; +} + +TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { + int pyarray_type = PyArray_TYPE(array); + switch (pyarray_type) { + case NPY_FLOAT32: + return kTfLiteFloat32; + case NPY_INT32: + return kTfLiteInt32; + case NPY_UINT8: + return kTfLiteUInt8; + case NPY_INT64: + return kTfLiteInt64; + case NPY_OBJECT: + case NPY_STRING: + case NPY_UNICODE: + return kTfLiteString; + } + LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; + return kTfLiteNoType; +} + +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +PyObject* PyArrayFromIntVector(const int* data, npy_intp size) { + void* pydata = malloc(size * sizeof(int)); + memcpy(pydata, data, size * sizeof(int)); + return PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata); +} + +PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { + PyObject* result = PyTuple_New(2); + PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale)); + PyTuple_SET_ITEM(result, 1, PyInt_FromLong(param.zero_point)); + return result; +} + +} // namespace + +InterpreterWrapper::InterpreterWrapper( + std::unique_ptr model) + : model_(std::move(model)), + resolver_(absl::make_unique()), + interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + +InterpreterWrapper::~InterpreterWrapper() {} + +bool InterpreterWrapper::AllocateTensors() { + if (!interpreter_) { + LOG(ERROR) << "Cannot allocate tensors: invalid interpreter."; + return false; + } + + if (interpreter_->AllocateTensors() != kTfLiteOk) { + LOG(ERROR) << "Unable to allocate tensors."; + return false; + } + + return true; +} + +bool InterpreterWrapper::Invoke() { + return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false; +} + +PyObject* InterpreterWrapper::InputIndices() const { + PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(), + interpreter_->inputs().size()); + + return PyArray_Return(reinterpret_cast(np_array)); +} + +PyObject* InterpreterWrapper::OutputIndices() const { + PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(), + interpreter_->outputs().size()); + + return PyArray_Return(reinterpret_cast(np_array)); +} + +bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { + if (!interpreter_) { + LOG(ERROR) << "Invalid interpreter."; + return false; + } + + std::unique_ptr array_safe( + PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); + if (!array_safe) { + LOG(ERROR) << "Failed to convert value into readable tensor."; + return false; + } + + PyArrayObject* array = reinterpret_cast(array_safe.get()); + + if (PyArray_NDIM(array) != 1) { + LOG(ERROR) << "Expected 1-D defining input shape."; + return false; + } + + if (PyArray_TYPE(array) != NPY_INT32) { + LOG(ERROR) << "Shape must be an int32 array"; + return false; + } + + std::vector dims(PyArray_SHAPE(array)[0]); + memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); + + return interpreter_->ResizeInputTensor(i, dims); +} + +std::string InterpreterWrapper::TensorName(int i) const { + if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { + return ""; + } + + const TfLiteTensor* tensor = interpreter_->tensor(i); + return tensor->name; +} + +PyObject* InterpreterWrapper::TensorType(int i) const { + if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { + return nullptr; + } + + const TfLiteTensor* tensor = interpreter_->tensor(i); + int typenum = TfLiteTypeToPyArrayType(tensor->type); + return PyArray_TypeObjectFromType(typenum); +} + +PyObject* InterpreterWrapper::TensorSize(int i) const { + if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { + Py_INCREF(Py_None); + return Py_None; + } + + const TfLiteTensor* tensor = interpreter_->tensor(i); + PyObject* np_array = + PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); + + return PyArray_Return(reinterpret_cast(np_array)); +} + +PyObject* InterpreterWrapper::TensorQuantization(int i) const { + if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { + Py_INCREF(Py_None); + return Py_None; + } + + const TfLiteTensor* tensor = interpreter_->tensor(i); + return PyTupleFromQuantizationParam(tensor->params); +} + +bool InterpreterWrapper::SetTensor(int i, PyObject* value) { + if (!interpreter_) { + LOG(ERROR) << "Invalid interpreter."; + return false; + } + + if (i >= interpreter_->tensors_size()) { + LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " + << interpreter_->tensors_size(); + return false; + } + + std::unique_ptr array_safe( + PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); + if (!array_safe) { + LOG(ERROR) << "Failed to convert value into readable tensor."; + return false; + } + + PyArrayObject* array = reinterpret_cast(array_safe.get()); + const TfLiteTensor* tensor = interpreter_->tensor(i); + + if (TfLiteTypeFromPyArray(array) != tensor->type) { + LOG(ERROR) << "Cannot set tensor:" + << " Got tensor of type " << TfLiteTypeFromPyArray(array) + << " but expected type " << tensor->type << " for input " << i; + return false; + } + + if (PyArray_NDIM(array) != tensor->dims->size) { + LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; + return false; + } + + for (int j = 0; j < PyArray_NDIM(array); j++) { + if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { + LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; + return false; + } + } + + size_t size = PyArray_NBYTES(array); + DCHECK_EQ(size, tensor->bytes); + memcpy(tensor->data.raw, PyArray_DATA(array), size); + return true; +} + +PyObject* InterpreterWrapper::GetTensor(int i) const { + if (!interpreter_) { + LOG(ERROR) << "Invalid interpreter."; + Py_INCREF(Py_None); + return Py_None; + } + + if (i >= interpreter_->tensors_size()) { + LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " + << interpreter_->inputs().size(); + Py_INCREF(Py_None); + return Py_None; + } + + const TfLiteTensor* output_tensor = interpreter_->tensor(i); + const int tensor_size = output_tensor->bytes; + if (tensor_size <= 0) { + LOG(ERROR) << "Invalid tensor size"; + Py_INCREF(Py_None); + return Py_None; + } + + int type_num = TfLiteTypeToPyArrayType(output_tensor->type); + if (type_num == -1) { + LOG(ERROR) << "Unknown tensor type " << output_tensor->type; + Py_INCREF(Py_None); + return Py_None; + } + + void* data = malloc(tensor_size); + memcpy(data, output_tensor->data.raw, tensor_size); + + const TfLiteIntArray* output_dims = output_tensor->dims; + std::vector dims(output_dims->data, + output_dims->data + output_dims->size); + PyObject* np_array = + PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); + + return PyArray_Return(reinterpret_cast(np_array)); +} + +InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( + const char* model_path) { + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(model_path); + return model ? new InterpreterWrapper(std::move(model)) : nullptr; +} + +InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( + const char* data, size_t len) { + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromBuffer(data, len); + return model ? new InterpreterWrapper(std::move(model)) : nullptr; +} + +} // namespace interpreter_wrapper +} // namespace tflite diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..0972c572595f5044a305a81afaccbea5f131247c --- /dev/null +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +#define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ + +#include +#include +#include + +#include + +// We forward declare TFLite classes here to avoid exposing them to SWIG. +namespace tflite { +namespace ops { +namespace builtin { +class BuiltinOpResolver; +} // namespace builtin +} // namespace ops + +class FlatBufferModel; +class Interpreter; + +namespace interpreter_wrapper { + +class InterpreterWrapper { + public: + // SWIG caller takes ownership of pointer. + static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); + + // SWIG caller takes ownership of pointer. + static InterpreterWrapper* CreateWrapperCPPFromBuffer(const char* data, + size_t len); + + ~InterpreterWrapper(); + bool AllocateTensors(); + bool Invoke(); + + PyObject* InputIndices() const; + PyObject* OutputIndices() const; + bool ResizeInputTensor(int i, PyObject* value); + + std::string TensorName(int i) const; + PyObject* TensorType(int i) const; + PyObject* TensorSize(int i) const; + PyObject* TensorQuantization(int i) const; + bool SetTensor(int i, PyObject* value); + PyObject* GetTensor(int i) const; + + private: + InterpreterWrapper(std::unique_ptr model); + + // InterpreterWrapper is not copyable or assignable. We avoid the use of + // InterpreterWrapper() = delete here for SWIG compatibility. + InterpreterWrapper(); + InterpreterWrapper(const InterpreterWrapper& rhs); + + const std::unique_ptr model_; + const std::unique_ptr resolver_; + const std::unique_ptr interpreter_; +}; + +} // namespace interpreter_wrapper +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i new file mode 100644 index 0000000000000000000000000000000000000000..7f51f9f00d1b2fe057052f7b7bd52bcb65231164 --- /dev/null +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "std_string.i" + + +%{ +#define SWIG_FILE_WITH_INIT +#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +%} + + +%include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 5d2f21653762a405a57288a7ba38323e5e42b3e1..ed6dd036f9fd9f39b74e902498d815793943924b 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -25,9 +25,9 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import subprocess -import tempfile +import os as _os +import subprocess as _subprocess +import tempfile as _tempfile # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs @@ -74,7 +74,7 @@ else: _toco_from_proto_bin = _resource_loader.get_path_to_datafile( "../toco/python/toco_from_protos") -if _toco_from_proto_bin and not os.path.exists(_toco_from_proto_bin): +if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin): _toco_from_proto_bin = "toco_from_protos" @@ -102,10 +102,10 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): return _toco_python.TocoConvert( model_flags_str, toco_flags_str, input_data_str) - with tempfile.NamedTemporaryFile() as fp_toco, \ - tempfile.NamedTemporaryFile() as fp_model, \ - tempfile.NamedTemporaryFile() as fp_input, \ - tempfile.NamedTemporaryFile() as fp_output: + with _tempfile.NamedTemporaryFile() as fp_toco, \ + _tempfile.NamedTemporaryFile() as fp_model, \ + _tempfile.NamedTemporaryFile() as fp_input, \ + _tempfile.NamedTemporaryFile() as fp_output: fp_model.write(model_flags_str) fp_toco.write(toco_flags_str) fp_input.write(input_data_str) @@ -118,11 +118,11 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): fp_output.name ] cmdline = " ".join(cmd) - proc = subprocess.Popen( + proc = _subprocess.Popen( cmdline, shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, + stdout=_subprocess.PIPE, + stderr=_subprocess.STDOUT, close_fds=True) stdout, stderr = proc.communicate() exitcode = proc.returncode @@ -202,11 +202,12 @@ def toco_convert(input_data, input_array.name = _tensor_name(input_tensor) input_array.shape.dims.extend(map(int, input_tensor.get_shape())) - toco.inference_input_type = tflite_input_type for output_tensor in output_tensors: model.output_arrays.append(_tensor_name(output_tensor)) + # TODO(aselle): Consider handling the case of allowing quantized + # inputs to be converted to float (via the toco.inference_input_type field). data = toco_convert_protos(model.SerializeToString(), toco.SerializeToString(), input_data.SerializeToString()) diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py index 9a3971228a683211e84b4c55d3a3e8d574b5ed94..7908689ce4a719ab15bd49a368a87f9cad7c6d61 100644 --- a/tensorflow/contrib/lite/python/op_hint.py +++ b/tensorflow/contrib/lite/python/op_hint.py @@ -119,8 +119,10 @@ class OpHint(object): def _setattr(self, dest_op, name, value): tensor_value = _ops.convert_to_tensor(value) - dest_op.op.node_def.attr[name].tensor.CopyFrom( - tensor_value.op.node_def.attr["value"].tensor) + # pylint: disable=protected-access + dest_op.op._set_attr(name, _attr_value_pb2.AttrValue( + tensor=tensor_value.op.node_def.attr["value"].tensor)) + # pylint: enable=protected-access def add_inputs(self, *args): """Add a sequence of inputs to the function invocation. diff --git a/tensorflow/contrib/lite/rpi_makefile.inc b/tensorflow/contrib/lite/rpi_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..832ef5824bea86a368184bd7e3d17915739e9d46 --- /dev/null +++ b/tensorflow/contrib/lite/rpi_makefile.inc @@ -0,0 +1,33 @@ +# Settings for Raspberry Pi. +ifeq ($(TARGET), RPI) + ifeq ($(TARGET_ARCH), armv7) + CXXFLAGS += \ + -march=armv7-a \ + -mfpu=neon-vfpv4 \ + -funsafe-math-optimizations \ + -ftree-vectorize + + CCFLAGS += \ + -march=armv7-a \ + -mfpu=neon-vfpv4 \ + -funsafe-math-optimizations \ + -ftree-vectorize + + LDFLAGS := \ + -Wl,--no-export-dynamic \ + -Wl,--exclude-libs,ALL \ + -Wl,--gc-sections \ + -Wl,--as-needed + endif + + LIBS := \ + -lstdc++ \ + -lpthread \ + -lm \ + -ldl + + OBJDIR := $(OBJDIR)rpi_$(TARGET_ARCH)/ + LIBDIR := $(LIBDIR)rpi_$(TARGET_ARCH)/ + BINDIR := $(BINDIR)rpi_$(TARGET_ARCH)/ + DEPDIR := $(DEPDIR)rpi_$(TARGET_ARCH)/ +endif diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 54167ddd9a5a003d0ff21e6627a1dbe94afa3e87..da65ec659c7ab39348d2b7911aceaa9dbdd2654b 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -5,6 +5,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") py_binary( name = "upgrade_schema", @@ -80,3 +81,5 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc index 08bcfe451685f488be2c3bc180f2dfc43dfe4f05..ac408d2f94b98d505afe4c951d7cc2ff960606fb 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -46,8 +46,7 @@ extern "C" { #endif // __cplusplus // The enum for builtin operators. -// Note: CUSTOM and DELEGATE are 2 special ops which are not real biultin -// ops. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. typedef enum { )"; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 98ac0469d1b885aa8047d35c8d814da4b61eff0c..7d2e00fe329a5da77af7bf091eaa99badbd1022a 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -75,7 +75,7 @@ enum BuiltinOperator : byte { CONV_2D = 3, DEPTHWISE_CONV_2D = 4, // DEPTH_TO_SPACE = 5, - // DEQUANTIZE = 6, + DEQUANTIZE = 6, EMBEDDING_LOOKUP = 7, // FLOOR = 8, FULLY_CONNECTED = 9, @@ -129,6 +129,9 @@ enum BuiltinOperator : byte { // WARNING: Experimental interface, subject to change DELEGATE = 51, BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, } // Options for the builtin operators. @@ -169,6 +172,9 @@ union BuiltinOptions { TopKV2Options, SplitOptions, LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumOptions, } enum Padding : byte { SAME, VALID } @@ -374,6 +380,15 @@ table StridedSliceOptions { table LogSoftmaxOptions { } +table CastOptions { +} + +table DequantizeOptions { +} + +table MaximumOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 99e1accaa71ffc92514595a745fcb60115ef61a0..66a97a1460d12b48102f53f975cb1e25e7735111 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -139,6 +139,15 @@ struct StridedSliceOptionsT; struct LogSoftmaxOptions; struct LogSoftmaxOptionsT; +struct CastOptions; +struct CastOptionsT; + +struct DequantizeOptions; +struct DequantizeOptionsT; + +struct MaximumOptions; +struct MaximumOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -201,6 +210,7 @@ enum BuiltinOperator { BuiltinOperator_CONCATENATION = 2, BuiltinOperator_CONV_2D = 3, BuiltinOperator_DEPTHWISE_CONV_2D = 4, + BuiltinOperator_DEQUANTIZE = 6, BuiltinOperator_EMBEDDING_LOOKUP = 7, BuiltinOperator_FULLY_CONNECTED = 9, BuiltinOperator_HASHTABLE_LOOKUP = 10, @@ -246,17 +256,21 @@ enum BuiltinOperator { BuiltinOperator_LOG_SOFTMAX = 50, BuiltinOperator_DELEGATE = 51, BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52, + BuiltinOperator_CAST = 53, + BuiltinOperator_PRELU = 54, + BuiltinOperator_MAXIMUM = 55, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM + BuiltinOperator_MAX = BuiltinOperator_MAXIMUM }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[50] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[54] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, BuiltinOperator_CONCATENATION, BuiltinOperator_CONV_2D, BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_DEQUANTIZE, BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOperator_FULLY_CONNECTED, BuiltinOperator_HASHTABLE_LOOKUP, @@ -301,7 +315,10 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[50] { BuiltinOperator_SPLIT, BuiltinOperator_LOG_SOFTMAX, BuiltinOperator_DELEGATE, - BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_CAST, + BuiltinOperator_PRELU, + BuiltinOperator_MAXIMUM }; return values; } @@ -314,7 +331,7 @@ inline const char **EnumNamesBuiltinOperator() { "CONV_2D", "DEPTHWISE_CONV_2D", "", - "", + "DEQUANTIZE", "EMBEDDING_LOOKUP", "", "FULLY_CONNECTED", @@ -361,6 +378,9 @@ inline const char **EnumNamesBuiltinOperator() { "LOG_SOFTMAX", "DELEGATE", "BIDIRECTIONAL_SEQUENCE_LSTM", + "CAST", + "PRELU", + "MAXIMUM", nullptr }; return names; @@ -409,11 +429,14 @@ enum BuiltinOptions { BuiltinOptions_TopKV2Options = 34, BuiltinOptions_SplitOptions = 35, BuiltinOptions_LogSoftmaxOptions = 36, + BuiltinOptions_CastOptions = 37, + BuiltinOptions_DequantizeOptions = 38, + BuiltinOptions_MaximumOptions = 39, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_LogSoftmaxOptions + BuiltinOptions_MAX = BuiltinOptions_MaximumOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[37] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[40] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -451,7 +474,10 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[37] { BuiltinOptions_ExpOptions, BuiltinOptions_TopKV2Options, BuiltinOptions_SplitOptions, - BuiltinOptions_LogSoftmaxOptions + BuiltinOptions_LogSoftmaxOptions, + BuiltinOptions_CastOptions, + BuiltinOptions_DequantizeOptions, + BuiltinOptions_MaximumOptions }; return values; } @@ -495,6 +521,9 @@ inline const char **EnumNamesBuiltinOptions() { "TopKV2Options", "SplitOptions", "LogSoftmaxOptions", + "CastOptions", + "DequantizeOptions", + "MaximumOptions", nullptr }; return names; @@ -653,6 +682,18 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CastOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MaximumOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -972,6 +1013,30 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_LogSoftmaxOptions ? reinterpret_cast(value) : nullptr; } + CastOptionsT *AsCastOptions() { + return type == BuiltinOptions_CastOptions ? + reinterpret_cast(value) : nullptr; + } + const CastOptionsT *AsCastOptions() const { + return type == BuiltinOptions_CastOptions ? + reinterpret_cast(value) : nullptr; + } + DequantizeOptionsT *AsDequantizeOptions() { + return type == BuiltinOptions_DequantizeOptions ? + reinterpret_cast(value) : nullptr; + } + const DequantizeOptionsT *AsDequantizeOptions() const { + return type == BuiltinOptions_DequantizeOptions ? + reinterpret_cast(value) : nullptr; + } + MaximumOptionsT *AsMaximumOptions() { + return type == BuiltinOptions_MaximumOptions ? + reinterpret_cast(value) : nullptr; + } + const MaximumOptionsT *AsMaximumOptions() const { + return type == BuiltinOptions_MaximumOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -3635,6 +3700,126 @@ inline flatbuffers::Offset CreateLogSoftmaxOptions( flatbuffers::Offset CreateLogSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct CastOptionsT : public flatbuffers::NativeTable { + typedef CastOptions TableType; + CastOptionsT() { + } +}; + +struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CastOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + CastOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CastOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CastOptionsBuilder &operator=(const CastOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCastOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + CastOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateCastOptions(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DequantizeOptionsT : public flatbuffers::NativeTable { + typedef DequantizeOptions TableType; + DequantizeOptionsT() { + } +}; + +struct DequantizeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DequantizeOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + DequantizeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DequantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DequantizeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit DequantizeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DequantizeOptionsBuilder &operator=(const DequantizeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDequantizeOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + DequantizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateDequantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MaximumOptionsT : public flatbuffers::NativeTable { + typedef MaximumOptions TableType; + MaximumOptionsT() { + } +}; + +struct MaximumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MaximumOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + MaximumOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MaximumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MaximumOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit MaximumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MaximumOptionsBuilder &operator=(const MaximumOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMaximumOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + MaximumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateMaximumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -3860,6 +4045,15 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const { return builtin_options_type() == BuiltinOptions_LogSoftmaxOptions ? static_cast(builtin_options()) : nullptr; } + const CastOptions *builtin_options_as_CastOptions() const { + return builtin_options_type() == BuiltinOptions_CastOptions ? static_cast(builtin_options()) : nullptr; + } + const DequantizeOptions *builtin_options_as_DequantizeOptions() const { + return builtin_options_type() == BuiltinOptions_DequantizeOptions ? static_cast(builtin_options()) : nullptr; + } + const MaximumOptions *builtin_options_as_MaximumOptions() const { + return builtin_options_type() == BuiltinOptions_MaximumOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -4030,6 +4224,18 @@ template<> inline const LogSoftmaxOptions *Operator::builtin_options_as inline const CastOptions *Operator::builtin_options_as() const { + return builtin_options_as_CastOptions(); +} + +template<> inline const DequantizeOptions *Operator::builtin_options_as() const { + return builtin_options_as_DequantizeOptions(); +} + +template<> inline const MaximumOptions *Operator::builtin_options_as() const { + return builtin_options_as_MaximumOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5512,6 +5718,75 @@ inline flatbuffers::Offset CreateLogSoftmaxOptions(flatbuffer _fbb); } +inline CastOptionsT *CastOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CastOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CastOptions::UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset CastOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateCastOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateCastOptions(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateCastOptions( + _fbb); +} + +inline DequantizeOptionsT *DequantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new DequantizeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void DequantizeOptions::UnPackTo(DequantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset DequantizeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateDequantizeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateDequantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DequantizeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateDequantizeOptions( + _fbb); +} + +inline MaximumOptionsT *MaximumOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new MaximumOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void MaximumOptions::UnPackTo(MaximumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset MaximumOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateMaximumOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateMaximumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MaximumOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateMaximumOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -5836,6 +6111,18 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MaximumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -5998,6 +6285,18 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MaximumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -6148,6 +6447,18 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateLogSoftmaxOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(value); + return CreateCastOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(value); + return CreateDequantizeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MaximumOptions: { + auto ptr = reinterpret_cast(value); + return CreateMaximumOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -6298,6 +6609,18 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new LogSoftmaxOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_CastOptions: { + value = new CastOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DequantizeOptions: { + value = new DequantizeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MaximumOptions: { + value = new MaximumOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -6485,6 +6808,21 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MaximumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/contrib/lite/schema/upgrade_schema.py index 94f5730be5d991ae13fb019e4d035e23f76fe441..e0b36d3d3ee94b00cccd3968d14c63fe19c3c27c 100644 --- a/tensorflow/contrib/lite/schema/upgrade_schema.py +++ b/tensorflow/contrib/lite/schema/upgrade_schema.py @@ -39,8 +39,8 @@ import tensorflow as tf from tensorflow.python.platform import resource_loader parser = argparse.ArgumentParser( - description="Script to move TFLite models from pre-release schema to" - " new schema.") + description="Script to move TFLite models from pre-release schema to " + "new schema.") parser.add_argument( "input", type=str, @@ -48,7 +48,7 @@ parser.add_argument( parser.add_argument( "output", type=str, - help="Output json or bin TensorFlow lite model compliant with" + help="Output json or bin TensorFlow lite model compliant with " "the new schema. Extension must be `.json`, `.bin` or `.tflite`.") @@ -258,7 +258,7 @@ class Converter(object): # Check if builtin_code is the appropriate string type # use type("") instead of str or unicode. for py2and3 if not isinstance(operator_code["builtin_code"], type(u"")): - raise ValueError("builtin_code %r is non-string. this usually means" + raise ValueError("builtin_code %r is non-string. this usually means " "your model has consistency problems." % (operator_code["builtin_code"])) operator_code["builtin_code"] = (RemapOperator( diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 4aab244989ca5300fbe74162e03deaac89af60ad..2f2004f56bcad5b56f9dd6d4bc824ec14d79e795 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -113,21 +113,21 @@ TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) { underlying_buffer_size_ = required_size; underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr; } - commited_ = true; + committed_ = true; return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError; } TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context, const ArenaAlloc& alloc, char** output_ptr) { - TF_LITE_ENSURE(context, commited_); + TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; return kTfLiteOk; } TfLiteStatus SimpleMemoryArena::Clear() { - commited_ = false; + committed_ = false; high_water_mark_ = 0; allocs_.clear(); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 0535522374c63459d029c252ebe94628cf3122d5..5faf78b59e3755d22e4e866d433e622baa6c66c1 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { // This little structure holds the offset and the size for a dynamic memory -// allocation in the memory arena. When the arena is commited and the +// allocation in the memory arena. When the arena is committed and the // underlying buffer is set, the alloc can be resolved into an actual memory // pointer. struct ArenaAlloc { @@ -43,7 +43,7 @@ struct ArenaAlloc { class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment) - : commited_(false), + : committed_(false), arena_alignment_(arena_alignment), high_water_mark_(0), underlying_buffer_size_(0), @@ -73,7 +73,7 @@ class SimpleMemoryArena { } private: - bool commited_; + bool committed_; size_t arena_alignment_; size_t high_water_mark_; std::unique_ptr underlying_buffer_; diff --git a/tensorflow/contrib/lite/special_rules.bzl b/tensorflow/contrib/lite/special_rules.bzl new file mode 100644 index 0000000000000000000000000000000000000000..54083c49182c707620cbd231b957405cfe24be92 --- /dev/null +++ b/tensorflow/contrib/lite/special_rules.bzl @@ -0,0 +1,6 @@ +"""External versions of build rules that differ outside of Google.""" + +def tflite_portable_test_suite(**kwargs): + """This is a no-op outside of Google.""" + _ignore = [kwargs] + pass diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b5960d6f8d97a440323f01bda1c4976fe5584ac5..12b7b3c35088a0560213e2e1431f23427d4fe640 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -8,6 +8,7 @@ load( "//tensorflow/contrib/lite:build_def.bzl", "gen_zipped_test_files", ) +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -34,11 +35,12 @@ gen_zipped_test_files( "l2norm.zip", "local_response_norm.zip", "log_softmax.zip", - "lstm.zip", "max_pool.zip", + "maximum.zip", "mean.zip", "mul.zip", "pad.zip", + "prelu.zip", "relu.zip", "relu1.zip", "relu6.zip", @@ -236,6 +238,9 @@ cc_test( size = "small", srcs = ["tf_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], + tags = [ + "tflite_not_portable", + ], deps = [ ":tf_driver", "@com_google_googletest//:gtest_main", @@ -259,6 +264,9 @@ cc_test( name = "generate_testspec_test", size = "small", srcs = ["generate_testspec_test.cc"], + tags = [ + "tflite_not_portable", + ], deps = [ ":generate_testspec", "@com_google_googletest//:gtest_main", @@ -317,7 +325,11 @@ tf_cc_test( "//tensorflow/contrib/lite:testdata/multi_add.bin", "//tensorflow/contrib/lite:testdata/multi_add.pb", ], - tags = ["no_oss"], + tags = [ + "no_cuda_on_cpu_tap", + "no_oss", + "tflite_not_portable", + ], deps = [ ":tflite_diff_flags", ":tflite_diff_util", @@ -336,7 +348,10 @@ tf_cc_test( ], data = [":optest"], shard_count = 20, - tags = ["no_oss"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":parse_testdata_lib", ":tflite_driver", @@ -370,3 +385,5 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 5488b71fcf644070710acc4b2b2886e9a96facb6..68bce19aa372280219fb2be9ebe3bef2ad03ec05 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -617,6 +617,54 @@ def make_relu6_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_prelu_tests(zip_path): + """Make a set of tests to do PReLU.""" + + test_parameters = [{ + # The canonical case for image processing is having a 4D `input` (NHWC) + # and `shared_axes`=[1, 2], so the alpha parameter is per channel. + "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], + "shared_axes": [[1, 2], [1]], + }] + + def build_graph(parameters): + """Build the graph for the test case.""" + + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"]) + out = prelu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build the inputs for the test case.""" + + input_shape = parameters["input_shape"] + input_values = create_tensor_data( + np.float32, input_shape, min_value=-10, max_value=10) + shared_axes = parameters["shared_axes"] + + alpha_shape = [] + for dim in range(1, len(input_shape)): + alpha_shape.append(1 if dim in shared_axes else input_shape[dim]) + + alpha_values = create_tensor_data(np.float32, alpha_shape) + + with tf.variable_scope("", reuse=True): + alpha = tf.get_variable("p_re_lu/alpha") + sess.run(alpha.assign(alpha_values)) + + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + use_frozen_graph=True) + + # This function tests various TensorFLow functions that generates Const op, # including `tf.ones`, `tf.zeros` and random functions. def make_constant_tests(zip_path): @@ -706,7 +754,7 @@ def make_mean_tests(zip_path): [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], "const_axis": [True, False], - "keep_dims": [True, False], + "keepdims": [True, False], }, { "input_dtype": [tf.float32, tf.int32, tf.int64], "input_shape": [[1, 224, 224, 3]], @@ -717,7 +765,7 @@ def make_mean_tests(zip_path): [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], "const_axis": [True, False], - "keep_dims": [True, False], + "keepdims": [True, False], }] def build_graph(parameters): @@ -740,7 +788,7 @@ def make_mean_tests(zip_path): input_tensors = [input_tensor, axis] out = tf.reduce_mean( - input_tensor, axis=axis, keep_dims=parameters["keep_dims"]) + input_tensor, axis=axis, keepdims=parameters["keepdims"]) return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): @@ -814,6 +862,41 @@ def make_log_softmax_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_maximum_tests(zip_path): + """Make a set of tests to do maximum.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + }] + + def build_graph(parameters): + """Build the maximum op testing graph.""" + input_tensor_1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input_1", + shape=parameters["input_shape_1"]) + input_tensor_2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input_2", + shape=parameters["input_shape_2"]) + + out = tf.maximum(input_tensor_1, input_tensor_2) + return [input_tensor_1, input_tensor_2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], + parameters["input_shape_1"]), + create_tensor_data(parameters["input_dtype"], + parameters["input_shape_2"]) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_binary_op_tests_func(binary_operator): """Return a function that does a test on a binary operator.""" return lambda zip_path: make_binary_op_tests(zip_path, binary_operator) @@ -1571,7 +1654,7 @@ def make_transpose_tests(zip_path): }, { "dtype": [tf.float32], "input_shape": [[1, 2, 3, 4, 5]], - "perm": [[0, 1, 2, 3, 4]], + "perm": [[4, 3, 2, 1, 0]], "constant_perm": [True, False], }] @@ -1911,6 +1994,7 @@ def main(unused_args): "relu.zip": make_relu_tests, "relu1.zip": make_relu1_tests, "relu6.zip": make_relu6_tests, + "prelu.zip": make_prelu_tests, "l2_pool.zip": make_pool_tests(make_l2_pool), "avg_pool.zip": make_pool_tests(tf.nn.avg_pool), "max_pool.zip": make_pool_tests(tf.nn.max_pool), @@ -1929,6 +2013,7 @@ def main(unused_args): "exp.zip": make_exp_tests, "log_softmax.zip": make_log_softmax_tests, "lstm.zip": make_lstm_tests, + "maximum.zip": make_maximum_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index c160b250c8a8c6a8a21bd438d0ee862288c2c08e..e9d505a76d15c8eaf1d3b6ba55bffe512532585e 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -47,7 +47,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Add only supports float32. (and "constant" tests use Add) {R"(^\/adda.*int32)", "68808744"}, {R"(^\/constant.*int32)", "68808744"}, @@ -89,6 +88,9 @@ std::map kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, + + // PRelu only supports 4D input with (1, 1, channels) 3D alpha now. + {R"(^\/prelu.*shared_axes=\[1\])", "75975192"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -248,13 +250,14 @@ INSTANTIATE_TESTS(l2_pool) INSTANTIATE_TESTS(l2norm) INSTANTIATE_TESTS(local_response_norm) INSTANTIATE_TESTS(log_softmax) -INSTANTIATE_TESTS(lstm) +INSTANTIATE_TESTS(maximum) INSTANTIATE_TESTS(max_pool) INSTANTIATE_TESTS(mean) INSTANTIATE_TESTS(mul) INSTANTIATE_TESTS(pad) INSTANTIATE_TESTS(relu) INSTANTIATE_TESTS(relu1) +INSTANTIATE_TESTS(prelu) INSTANTIATE_TESTS(relu6) INSTANTIATE_TESTS(reshape) INSTANTIATE_TESTS(resize_bilinear) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 17407f3db27ead984d1cfffc3f0085ac86f5318f..102740ee4725904918ce551d1a3e233ee6f8cc57 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -124,6 +124,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -167,12 +168,48 @@ cc_library( ], ) +cc_library( + name = "toco_saved_model", + srcs = [ + "toco_saved_model.cc", + ], + hdrs = [ + "toco_saved_model.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_flags_proto_cc", + ":types_proto_cc", + "//tensorflow/cc/tools:freeze_saved_model", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "toco_saved_model_test", + srcs = ["toco_saved_model_test.cc"], + deps = [ + ":model_cmdline_flags", + ":toco_cmdline_flags", + ":toco_saved_model", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "graph_transformations", srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", "graph_transformations/convert_pure_conv_to_depthwise.cc", "graph_transformations/convert_reorder_axes.cc", + "graph_transformations/convert_squeeze_to_reshape.cc", "graph_transformations/convert_trivial_addn_to_add.cc", "graph_transformations/convert_trivial_stack_to_reshape.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", @@ -192,9 +229,11 @@ cc_library( "graph_transformations/identify_lstm.cc", "graph_transformations/identify_lstm_merge_inputs.cc", "graph_transformations/identify_lstm_split_inputs.cc", + "graph_transformations/identify_prelu.cc", "graph_transformations/identify_relu1.cc", "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", + "graph_transformations/propagate_activation_function_into_constants.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_fixed_sizes.cc", "graph_transformations/quantize.cc", @@ -218,6 +257,7 @@ cc_library( "graph_transformations/resolve_constant_concatenation.cc", "graph_transformations/resolve_constant_fake_quant.cc", "graph_transformations/resolve_constant_fill.cc", + "graph_transformations/resolve_constant_gather.cc", "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_stack.cc", @@ -240,6 +280,7 @@ cc_library( "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", "graph_transformations/unfuse_activation_functions.cc", + "graph_transformations/unpartition_embedding_lookup.cc", "graph_transformations/unroll_batch_matmul.cc", ], hdrs = [ @@ -328,6 +369,7 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@protobuf_archive//:protobuf_headers", @@ -357,6 +399,7 @@ tf_cc_binary( ":toco_cmdline_flags", ":toco_flags_proto_cc", ":toco_port", + ":toco_saved_model", ":toco_tooling", ":types_proto_cc", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc index 49cc1fc2aa365925cde86ceb658ff2b354d06911..621fbcb98db049f819ebbbda8816ad4e30538530 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -248,29 +248,49 @@ void AllocateTransientArrays(Model* model, op_index++) { const auto& op = model->operators[op_index]; // Allocate those arrays whose lifespan starts exactly here. + std::vector arrays_to_allocate; for (const auto& input : op->inputs) { if (StartsAt(array_lifespans[input], op_index)) { - AllocateTransientArray(*model, input, &allocator, - transient_data_alignment); + if (std::find(arrays_to_allocate.begin(), arrays_to_allocate.end(), + input) == arrays_to_allocate.end()) { + arrays_to_allocate.push_back(input); + } } } for (const auto& output : op->outputs) { if (StartsAt(array_lifespans[output], op_index)) { - AllocateTransientArray(*model, output, &allocator, - transient_data_alignment); + if (std::find(arrays_to_allocate.begin(), arrays_to_allocate.end(), + output) == arrays_to_allocate.end()) { + arrays_to_allocate.push_back(output); + } } } + for (const string& array : arrays_to_allocate) { + AllocateTransientArray(*model, array, &allocator, + transient_data_alignment); + } + // Deallocate those arrays whose lifespan ends exactly here. + std::vector arrays_to_deallocate; for (const auto& input : op->inputs) { if (EndsAt(array_lifespans[input], op_index)) { - DeallocateTransientArray(*model, input, &allocator); + if (std::find(arrays_to_deallocate.begin(), arrays_to_deallocate.end(), + input) == arrays_to_deallocate.end()) { + arrays_to_deallocate.push_back(input); + } } } for (const auto& output : op->outputs) { if (EndsAt(array_lifespans[output], op_index)) { - DeallocateTransientArray(*model, output, &allocator); + if (std::find(arrays_to_deallocate.begin(), arrays_to_deallocate.end(), + output) == arrays_to_deallocate.end()) { + arrays_to_deallocate.push_back(output); + } } } + for (const string& array : arrays_to_deallocate) { + DeallocateTransientArray(*model, array, &allocator); + } } // Just out of curiosity (not used in the actual allocation process) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 59a6115920614d38900c0370708324c122384420..7b71792ff79604a61e0693415815bc86c8d6d1bc 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -190,6 +190,7 @@ struct ParsedModelFlags { Arg output_array; Arg output_arrays; Arg input_shapes; + Arg batch_size = Arg(1); Arg mean_value = Arg(0.f); Arg mean_values; Arg std_value = Arg(1.f); @@ -215,9 +216,11 @@ struct ParsedModelFlags { // you want). See toco_cmdline_flags.cc for details. struct ParsedTocoFlags { Arg input_file; + Arg savedmodel_directory; Arg output_file; - Arg input_format; - Arg output_format; + Arg input_format = Arg("TENSORFLOW_GRAPHDEF"); + Arg output_format = Arg("TFLITE"); + Arg savedmodel_tagset; // TODO(aselle): command_line_flags doesn't support doubles Arg default_ranges_min = Arg(0.); Arg default_ranges_max = Arg(0.); diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6900468ec6484d5c1896752286a2fa72f4d38c07..22a23357b36c16ea937e726f1e49aa95d7f964e3 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -548,6 +548,38 @@ void ConvertDepthwiseConvOperator(const Model& model, } } +void ConvertTransposeConvOperator(const Model& model, + const TransposeConvOperator& src_op, + GraphDef* tensorflow_graph) { + auto* conv2d_op = tensorflow_graph->add_node(); + conv2d_op->set_op("Conv2DBackpropInput"); + conv2d_op->set_name(src_op.outputs[0]); + *conv2d_op->add_input() = src_op.inputs[0]; + *conv2d_op->add_input() = src_op.inputs[1]; + *conv2d_op->add_input() = src_op.inputs[2]; + (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT); + const string& weights_array_name = WalkUpToConstantArray( + model, src_op.inputs[TransposeConvOperator::WEIGHTS]); + const auto& weights_array = model.GetArray(weights_array_name); + CHECK(weights_array.buffer->type == ArrayDataType::kFloat); + ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, + AxesOrder::kHWIO, tensorflow_graph); + auto& strides = (*conv2d_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*conv2d_op->mutable_attr())["padding"].set_s(padding); +} + void ConvertDepthToSpaceOperator(const Model& model, const DepthToSpaceOperator& src_op, GraphDef* tensorflow_graph) { @@ -1622,9 +1654,11 @@ void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); - auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"]; - for (int i : src_op.squeeze_dims) { - squeeze_dims.mutable_list()->add_i(i); + if (!src_op.squeeze_dims.empty()) { + auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"]; + for (int i : src_op.squeeze_dims) { + squeeze_dims.mutable_list()->add_i(i); + } } } @@ -1859,6 +1893,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertExpandDimsOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTransposeConv) { + ConvertTransposeConvOperator( + model, static_cast(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index 440f9c367c25726e20aa8828e3050cd1dc1b230d..36e2d9c37238bb6184ec99c567810b1bcb9a68ce 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -28,7 +28,7 @@ val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) out = tf.identity(val, name="out") with tf.Session() as sess: tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("test.tflite", "wb").write(tflite_modeL) + open("test.tflite", "wb").write(tflite_model) ``` **NOTE** Currently, the TOCO command will cause a fatal error to the Python diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..81cedb5dad751aacbbb32326db73de386aba282d --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Replaces a tf.squeeze operator with a reshape. +// Squeeze removes dimensions == 1 (if in the list of squeeze_dims). This +// means that the data layout will never change with this op, just the shape. +// By converting these to reshapes once we have run shape propagation we allow +// standard reshape optimization transforms to do their magic. +bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { + auto squeeze_it = model->operators.begin() + op_index; + if (squeeze_it->get()->type != OperatorType::kSqueeze) { + return false; + } + auto squeeze_op = static_cast(squeeze_it->get()); + CHECK_EQ(squeeze_op->inputs.size(), 1); + CHECK_EQ(squeeze_op->outputs.size(), 1); + + const auto& input_array = model->GetArray(squeeze_op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return false; + } + if (input_array.shape().dimensions_count() == 0) { + // Input array cannot be 0-D. + return false; + } + if (!model->HasArray(squeeze_op->outputs[0]) || + !model->GetArray(squeeze_op->outputs[0]).has_shape()) { + // Yield until shape propagation has set the output shape for us. + return false; + } + + // We use the output shape that has been calculated by shape propagation. + const auto& output_shape = model->GetArray(squeeze_op->outputs[0]).shape(); + + // Empty shapes will not work as empty data arrays. + if (output_shape.dimensions_count() == 0) { + return false; + } + + auto* reshape_op = new TensorFlowReshapeOperator; + reshape_op->inputs = { + squeeze_op->inputs[0], + CreateInt32Array(model, squeeze_op->outputs[0] + "_shape", + output_shape.dims()), + }; + reshape_op->outputs = squeeze_op->outputs; + + AddMessageF("Replacing %s with %s", LogName(*squeeze_op), + LogName(*reshape_op)); + + // Replace the operator in the graph. + const auto reshape_it = model->operators.emplace(squeeze_it, reshape_op); + squeeze_it = reshape_it + 1; + CHECK_EQ(squeeze_it->get(), squeeze_op); + model->operators.erase(squeeze_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index c2b166033c33b777bad88cb712adf8517be1762a..5a36a90b3841504d6f018832777e50bac95218d7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -21,6 +21,33 @@ limitations under the License. namespace toco { +namespace { + +bool TransposeAffectsMemoryOrder(std::vector perm, + std::vector in_shape) { + CHECK_EQ(perm.size(), in_shape.size()); + // See what the ordering of the non-unary columns are before and after + // transpose permutation. If the major indices stay in the same order (not + // just the shape) then the flat buffer representation shouldn't change. + std::vector old_major_index_ordering; + std::vector new_major_index_ordering; + for (int i = 0; i < in_shape.size(); i++) { + if (in_shape[i] != 1) { + old_major_index_ordering.push_back(i); + } + + if (in_shape[perm[i]] != 1) { + new_major_index_ordering.push_back(perm[i]); + } + } + + CHECK_EQ(new_major_index_ordering.size(), old_major_index_ordering.size()); + + return old_major_index_ordering != new_major_index_ordering; +} + +} // namespace + bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { auto transpose_it = model->operators.begin() + op_index; if (transpose_it->get()->type != OperatorType::kTranspose) { @@ -29,23 +56,26 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { TransposeOperator* transpose_op = static_cast(transpose_it->get()); + const auto& input_array = model->GetArray(transpose_op->inputs[0]); const auto& output_array = model->GetArray(transpose_op->outputs[0]); - if (!output_array.has_shape()) { + if (!input_array.has_shape() || !output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. return false; } // Note: We can assume we have error checked inputs in PropagateFixedSizes. - // This transpose is trivial if we only have one non-unitary dimension. - std::vector const& dims = output_array.shape().dims(); - unsigned non_unitary_axis_count = 0; - for (int i = 0; i < dims.size(); i++) { - if (dims[i] != 1) { - non_unitary_axis_count++; - } + // Check that the permutation has propogated. + std::vector const& perm = transpose_op->perm; + if (perm.empty()) { + return false; } - if (non_unitary_axis_count > 1) { - // Transpose is not trivial + + // This transpose is trivial if non-unitary dimensions remain in the same + // order. + std::vector const& input_dims = input_array.shape().dims(); + std::vector const& output_dims = output_array.shape().dims(); + + if (TransposeAffectsMemoryOrder(perm, input_dims)) { return false; } @@ -61,11 +91,11 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { string shape_array_name = toco::AvailableArrayName(*model, perm_array_name); Array& shape_array = model->GetOrCreateArray(shape_array_name); *(shape_array.mutable_shape()->mutable_dims()) = { - 1, static_cast(dims.size())}; + 1, static_cast(output_dims.size())}; reshape_op->inputs.push_back(shape_array_name); shape_array.data_type = ArrayDataType::kInt32; auto& shape_buffer = shape_array.GetMutableBuffer(); - shape_buffer.data = dims; + shape_buffer.data = output_dims; // Delete perm array if unused if (IsDiscardableArray(*model, perm_array_name) && diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index ab943f72d1dd87ae9ff4bd53a807cd4923a88c38..c5ce3fcd95eb0aaf63dcc7f43b96d8a13ed93929 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -42,9 +42,9 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { if (CountTrueOutputs(*model, *op) > 1) { AddMessageF( - "Not fusing activation function into %s because it has more than one " - " consumed output", - LogName(*op)); + "Not fusing activation function %s into %s because it has more than " + "one consumed output", + LogName(*ac_op), LogName(*op)); return false; } @@ -56,22 +56,31 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing activation function into %s because it is consumed by more " "than 1 other operator", - LogName(*op)); + LogName(*ac_op), LogName(*op)); + return false; + } + + if (!IsDiscardableArray(*model, op->outputs[0])) { + AddMessageF( + "Not fusing activation function %s into %s because output %s it is not " + "discardable", + LogName(*ac_op), LogName(*op), op->outputs[0]); return false; } if (op->fused_activation_function != FusedActivationFunctionType::kNone) { AddMessageF( - "Not fusing activation function into %s because it already has a fused " - "activation function", - LogName(*op)); + "Not fusing activation function %s into %s because it already has a " + "fused activation function", + LogName(*ac_op), LogName(*op)); return false; } if (!OperatorSupportsFusedActivation(op->type)) { AddMessageF( - "Not fusing activation function because the %s op doesn't support it", - LogName(*op)); + "Not fusing activation function %s because the %s op doesn't support " + "it", + LogName(*ac_op), LogName(*op)); return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 5b57178b18d2d60e1f301a1a8b257d8057618550..76c6be00d407ca30b898d088c9fa34cd7f76f656 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -50,7 +50,17 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, // TODO(b/62904716): Bias array should become 1-D when padding removed. const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1); - CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1)); + int operand_channel_increment = 0; + if (operand_shape.dimensions_count() >= 1 && + operand_shape.dims(operand_shape.dimensions_count() - 1) == + bias_shape.dims(bias_shape.dimensions_count() - 1)) { + operand_channel_increment = 1; + } else if (operand_shape.dimensions_count() == 0 || + operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { + operand_channel_increment = 0; + } else { + LOG(FATAL) << "Operand shape mismatch."; + } enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias }; @@ -60,9 +70,10 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, ? OpType::BiasMinusOperand : OpType::OperandMinusBias; + int operand_channel = 0; for (int i = 0; i < depth; i++) { float& bias_val = bias_data[i]; - const float operand_val = operand_data[i]; + const float operand_val = operand_data[operand_channel]; if (optype == OpType::BiasPlusOperand) { bias_val += operand_val; } else if (optype == OpType::BiasMinusOperand) { @@ -72,6 +83,7 @@ void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, } else { LOG(FATAL) << "Should not get here."; } + operand_channel += operand_channel_increment; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index f2c81ebc81c2928ae60d66bfcd7f643c5412f196..640afc7c74d7284fb9e212ab23d74a8215314add 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -114,6 +114,7 @@ void RunGraphTransformations(Model* model, const string& message, // List of all graph transformations DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) +DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) @@ -128,8 +129,10 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) +DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu) DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) @@ -175,8 +178,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather) DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) +DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) class ResolveReshapeAttributes : public GraphTransformation { public: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 938d76386d6f315abfe6fe55b133cb4d19014f01..5cc82da5d544846cc095046ceccf0664525aae41 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -326,9 +326,12 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kStridedSlice: case OperatorType::kSqueeze: case OperatorType::kTensorFlowReshape: case OperatorType::kPad: + case OperatorType::kGather: + case OperatorType::kTranspose: changed = HardcodeMinMaxFromFirstInput(model, op); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc new file mode 100644 index 0000000000000000000000000000000000000000..30be4ac0aa5e9f639bbf0630e142c2806faa3260 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +// This transformation rule tries to identify the PRelu structure generated by +// Keras, and convert it to a single op. +// +// The formula of PReLU is: +// f(x) = alpha * x for x < 0, f(x) = x for x >= 0. +// +// `x` is the input, and `alpha` is a trainable tensor which can be broadcasted +// to the shape of `x`. +// +// There's no native PRelu op in TensorFlow, so Keras generates the following +// structure which does the equivalent calculation: +// f(x) = Relu(x) + (-alpha * Relu(-x)) +// +// Practically, alpha is always a constant in the inference graph, and Toco have +// other graph transformations which fold the activation functions to other ops. +// Therefore, we're looking for the structure: +// +// f(x) = Relu(x) + (negative_alpha * Neg(x, activation=Relu)) + +namespace toco { + +bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { + const auto add_op_it = model->operators.begin() + op_index; + const auto* add_op = add_op_it->get(); + if (add_op == nullptr || add_op->type != OperatorType::kAdd || + add_op->inputs.size() != 2 || + add_op->fused_activation_function != FusedActivationFunctionType::kNone) { + return false; + } + + const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]); + if (relu_input_op == nullptr || relu_input_op->type != OperatorType::kRelu || + relu_input_op->inputs.size() != 1 || + relu_input_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + return false; + } + + // TODO(ycling): Both Add and Mul are commutative. Support the case where + // the position of operands are exchanged. + const auto* mul_op = GetOpWithOutput(*model, add_op->inputs[1]); + if (mul_op == nullptr || mul_op->type != OperatorType::kMul || + mul_op->inputs.size() != 2 || + mul_op->fused_activation_function != FusedActivationFunctionType::kNone) { + return false; + } + + const auto neg_alpha_tensor_name = mul_op->inputs[0]; + + const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]); + + if (relu_neg_input_op == nullptr || + relu_neg_input_op->type != OperatorType::kNeg || + relu_neg_input_op->fused_activation_function != + FusedActivationFunctionType::kRelu || + relu_neg_input_op->inputs.size() != 1) { + return false; + } + + if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) { + return false; + } + + const auto input_tensor_name = relu_input_op->inputs[0]; + const auto output_tensor_name = add_op->outputs[0]; + + // Construct a tensor for positive alpha (double negative). + const auto alpha_tensor_name = + AvailableArrayName(*model, neg_alpha_tensor_name + "_neg"); + model->GetOrCreateArray(alpha_tensor_name); + + auto* neg_neg_alpha_op = new NegOperator; + neg_neg_alpha_op->inputs = {neg_alpha_tensor_name}; + neg_neg_alpha_op->outputs = {alpha_tensor_name}; + model->operators.emplace(add_op_it, neg_neg_alpha_op); + + auto* prelu_op = new PReluOperator; + prelu_op->inputs = {input_tensor_name, alpha_tensor_name}; + prelu_op->outputs = {output_tensor_name}; + model->operators.emplace(add_op_it, prelu_op); + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op)); + + DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model); + DeleteArrayIfUsedOnce(add_op->inputs[0], model); + DeleteArrayIfUsedOnce(add_op->inputs[1], model); + DeleteArrayIfUsedOnce(mul_op->inputs[1], model); + // Remove the existing Add op that outputs the final result. If the other + // intermediate tensors aren't used by other ops, those will be removed by + // other graph transformation rules. + model->operators.erase(FindOp(*model, add_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h index 881c2d4dc892625d4640cac867a2f49c24b638f5..4a9974ed4e0ebec4381b86798156f4f51bb154a0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ + #include #include #include @@ -100,3 +103,5 @@ bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, string* rnn_array); } // namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index d83603e9a2c59ae74a5e5fda5b11178740336bfb..935da9f966ca63095faa17476be3a559d1a0193a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -85,8 +85,8 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op, auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax(); dequantized_input_minmax = input_minmax; auto& input_qparams = input_array.GetOrCreateQuantizationParams(); - GetQuantizationParamsFromMinMax( - model->flags, input_minmax, &input_qparams); + GetQuantizationParamsFromMinMax(input_minmax, + &input_qparams); transformation->AddMessageF( "Created %s" diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf17c49b1098d02468935aa72d1d1e73b4addbe1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool PropagateActivationFunctionIntoConstants::Run(Model* model, + std::size_t op_index) { + const auto ac_it = model->operators.begin() + op_index; + const auto* ac_op = ac_it->get(); + if (ac_op->type != OperatorType::kRelu6 && + ac_op->type != OperatorType::kRelu1 && + ac_op->type != OperatorType::kRelu) { + return false; + } + + // Find the op producing the array passed to this activation function. + auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]); + if (!src_op) { + return false; + } + + // Ensure the src_op is not used without the activation function applied. + if (CountTrueOutputs(*model, *src_op) > 1) { + AddMessageF( + "Not propagating activation function %s into %s because it has more " + "than one consumed output", + LogName(*ac_op), LogName(*src_op)); + } + + // Filter to the list of supported ops. + string src_op_input; + switch (src_op->type) { + case OperatorType::kGather: + src_op_input = src_op->inputs[0]; + break; + default: + return false; + } + CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]); + + // Ensure the input is constant as otherwise this needs to happen at runtime. + // If we bail here, it's still possible that FuseActivationFunctions will fuse + // the activation if it's supported by the op. + if (!IsConstantParameterArray(*model, src_op_input)) { + AddMessageF( + "Not propagating activation function %s into %s:%s because it is not " + "constant", + LogName(*ac_op), LogName(*src_op), src_op_input); + return false; + } + + // Get the array we'll be working with and ensure it's a compatible type. + auto& const_array = model->GetArray(src_op_input); + if (const_array.data_type != ArrayDataType::kFloat) { + AddMessageF( + "Not propagating activation function %s into %s:%s because it is " + "non-float data", + LogName(*ac_op), LogName(*src_op), src_op_input); + return false; + } + auto& const_array_data = + const_array.GetMutableBuffer().data; + + // Perform the activation function directly into the constant data array. + for (size_t i = 0; i < const_array_data.size(); ++i) { + const float value = const_array_data[i]; + float new_value = value; + switch (ac_op->type) { + case OperatorType::kRelu: { + static constexpr float kLower = 0; + new_value = value < kLower ? kLower : value; + break; + } + case OperatorType::kRelu1: { + static constexpr float kUpper = 1; + static constexpr float kLower = -1; + new_value = value > kUpper ? kUpper : value < kLower ? kLower : value; + break; + } + case OperatorType::kRelu6: { + static constexpr float kUpper = 6; + static constexpr float kLower = 0; + new_value = value > kUpper ? kUpper : value < kLower ? kLower : value; + break; + } + default: + LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op); + return false; + } + const_array_data[i] = new_value; + } + + AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op), + LogName(*src_op), src_op_input); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index f0d107232b4517115aa3f64b39b825dbaffb83ce..778da39bf13563cbbdbe54f1140595b057253ae3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -71,6 +71,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { CHECK_GE(op->inputs.size(), 2); const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kTransposeConv) { + // These operators produce an output with the same type as their 3rd input + CHECK_GE(op->inputs.size(), 3); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kCast) { // Data type of the Cast op is specified. CHECK_EQ(op->outputs.size(), 1); @@ -97,10 +102,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kTensorFlowUnsupported) { auto* unsupported_op = static_cast(op); - if (unsupported_op->output_data_types.size() != op->outputs.size()) { + // Some output tensors from the op could be eliminated by optimization. + // This can make unsupported_op->output_data_types have more elements than + // op->outputs. + if (unsupported_op->output_data_types.size() < op->outputs.size()) { return false; } - for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { + for (int i = 0; i < op->outputs.size(); ++i) { auto output = op->outputs[i]; auto data_type = unsupported_op->output_data_types[i]; model->GetArray(output).data_type = data_type; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 0e2e5ecf30053103492337685d85a2aacf832caf..676736cfc523c03c9f4d99c404eb2b5209209945 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -190,6 +190,116 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { } } +void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { + // TransposeConv is unique in that it is specifically given the output shape + // as a 1D array on it's 1st input. Theoretically then, resolving the output + // shape is as easy as waiting for this input to be resolved. However, we also + // have to calculate the padding which requires the weights shape. So, we + // might as well calculate the output shape and ensure it matches the + // specified one + + // Check if we have already run. + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + return; + } + + // SPECIFIED OUTPUT SHAPE + // The below is the specified, or prescribed output shape, _given_ to the + // operator as an input. + auto& specified_output_shape_array = + model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]); + if (!specified_output_shape_array.has_shape() || + !specified_output_shape_array.buffer) { + // Yield until the specified output shape is resolved as a constant + return; + } + + CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32) + << "TransposeConv input_dims must be int32"; + + CHECK(specified_output_shape_array.shape().dimensions_count() == 1 && + specified_output_shape_array.shape().dims(0) == 4) + << "TransposeConv requires a 1D, 4 element array on it's 0th input " + "specifying the output shape. \"" + << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape " + << toco::ShapeToString(specified_output_shape_array.shape()); + + // COMPUTE PADDING + // We require the weights shape to calculate padding. + const auto& weights_array = + model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]); + if (!weights_array.has_shape()) { + // Yield until weights dims have been resolved. + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 4) + << "TransposeConv weights must have 4 input dimensions. Input weights \"" + << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " + << toco::ShapeToString(weights_shape) << "."; + + CHECK(weights_shape.dims(0) == 1 && weights_shape.dims(3) == 1) + << "TransposeConv weights dimensions must begin and end with 1. Input " + "weights \"" + << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " + << toco::ShapeToString(weights_shape) << "."; + + // Compute padding + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + op->padding.GetOrCreateFixedPadding(); + if (op->padding.type == PaddingType::kValid) { + op->padding.fixed->height = 0; + op->padding.fixed->width = 0; + } else if (op->padding.type == PaddingType::kSame) { + op->padding.fixed->height = (kheight - 1) / 2; + op->padding.fixed->width = (kwidth - 1) / 2; + } else { + LOG(FATAL) << "TransposeConv only supports SAME or VALID padding"; + } + + // VALIDATE OUTPUT SHAPE + // Compute the output shape from the input and weights shapes to verify it + // agrees with the specified output shape. + const auto& input_array = + model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4) + << "TransposeConv input shape must have 4 dimensions. Input \"" + << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " + << toco::ShapeToString(weights_shape) << "."; + + // Compute output shape + const int input_width = input_shape.dims(2); + const int input_height = input_shape.dims(1); + int output_height = op->stride_height * (input_height - 1); + int output_width = op->stride_width * (input_width - 1); + if (op->padding.type == PaddingType::kValid) { + output_height += kheight; + output_width += kwidth; + } else if (op->padding.type == PaddingType::kSame) { + output_height += 1; + output_width += 1; + } + + CHECK(specified_output_shape_array.GetBuffer().data == + std::vector({input_shape.dims(0), output_height, output_width, + weights_shape.dims(3)})) + << "Specified output shape: " << ShapeToString(output_array.shape()) + << ", does not agree with shape computed from input data and weights: [" + << input_shape.dims(0) << ", " << output_height << ", " << output_width + << ", " << weights_shape.dims(3) << "]."; + + // SUCCESS: Set the op's output shape according to the specified output shape. + *(output_array.mutable_shape()->mutable_dims()) = + specified_output_shape_array.GetBuffer().data; +} + void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { if (!EnsureBiasVectorShape(model, op)) { return; @@ -1300,7 +1410,7 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { std::vector const& perm = perm_array.GetBuffer().data; CHECK_EQ(perm.size(), input_shape.dimensions_count()) - << "Transpose permutation input " << op->inputs[0] + << "Transpose permutation input " << op->inputs[1] << " must be same length as input dimensions"; std::vector* output_dims = output_array.mutable_shape()->mutable_dims(); for (int i = 0; i < perm.size(); i++) { @@ -1357,6 +1467,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRelu: case OperatorType::kRelu1: case OperatorType::kRelu6: + case OperatorType::kPRelu: case OperatorType::kSoftmax: case OperatorType::kLogSoftmax: case OperatorType::kLogistic: @@ -1402,8 +1513,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessConvOperator(model, static_cast(op)); break; case OperatorType::kTransposeConv: - // Unimplemented, hopefully another graph transformation will drop it or - // rewrite it. + ProcessTransposeConvOperator(model, + static_cast(op)); break; case OperatorType::kDepthwiseConv: ProcessDepthwiseConvOperator(model, @@ -1542,6 +1653,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTranspose: ProcessTransposeOperator(model, static_cast(op)); break; + case OperatorType::kDynamicPartition: + case OperatorType::kDynamicStitch: + // DynamicPartition/DynamicStitch are currently only supported for + // transforms that remove them, so we avoid propagating shapes through + // them and let things settle once they've been removed. + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 77316751bc2642a0c974d16f694aeebe1cd53a9f..9679ea0a776f9049699b087fd34f6a9088257c06 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -49,7 +49,10 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowReshape || type == OperatorType::kTanh || type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || - type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell; + type == OperatorType::kStridedSlice || + type == OperatorType::kDepthToSpace || + type == OperatorType::kLstmCell || type == OperatorType::kGather || + type == OperatorType::kTranspose; } template @@ -62,8 +65,6 @@ std::unique_ptr QuantizeBuffer( static_cast&>(buffer); auto* quantized_buffer = new Buffer; quantized_buffer->data.resize(float_buffer.data.size()); - const auto qmin = static_cast(std::numeric_limits>::min()); - const auto qmax = static_cast(std::numeric_limits>::max()); for (std::size_t i = 0; i < float_buffer.data.size(); i++) { const float src_val = float_buffer.data[i]; double scaled_val; // Astonishingly, using 'float' degrades accuracy just @@ -75,9 +76,8 @@ std::unique_ptr QuantizeBuffer( } else { scaled_val = quantization_params.zero_point + inverse_scale * src_val; } - const auto rounded_val = static_cast(std::round(scaled_val)); - const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val)); - quantized_buffer->data[i] = static_cast>(clamped_val); + quantized_buffer->data[i] = + tflite::SafeCast>(std::round(scaled_val)); } return std::unique_ptr(quantized_buffer); } @@ -222,7 +222,49 @@ ArrayDataType GetQuantizedDataType(const Array& array, default: LOG(FATAL) << "Unhandled final quantization type " << static_cast(array.final_data_type); - return default_type; + } +} + +void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax, + QuantizationParams* quantization_params) { + switch (data_type) { + case ArrayDataType::kInt8: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kUint8: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kInt16: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kUint16: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kInt32: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kUint32: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kInt64: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kUint64: + GetQuantizationParamsFromMinMax( + minmax, quantization_params); + break; + case ArrayDataType::kFloat: + case ArrayDataType::kNone: + default: + LOG(FATAL) << "Unhandled final quantization type " + << static_cast(data_type); } } @@ -284,16 +326,14 @@ bool ChooseQuantizationForOperatorInput( if (op.type == OperatorType::kLstmCell) { if (input_index == LstmCellOperator::PREV_STATE_INPUT) { - GetQuantizationParamsFromMinMax( - model->flags, minmax, quantization_params); *quantized_data_type = ArrayDataType::kInt16; + GetQuantizationParams(*quantized_data_type, minmax, quantization_params); return true; } } - GetQuantizationParamsFromMinMax(model->flags, minmax, - quantization_params); *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); + GetQuantizationParams(*quantized_data_type, minmax, quantization_params); transformation->AddMessageF( "For input array %s with min=%g" ", max=%g" @@ -416,15 +456,13 @@ bool ChooseQuantizationForOperatorOutput( if (op.type == OperatorType::kLstmCell) { if (output_index == LstmCellOperator::STATE_OUTPUT || output_index == LstmCellOperator::ACTIV_TEMP) { - GetQuantizationParamsFromMinMax( - model->flags, minmax, quantization_params); *quantized_data_type = ArrayDataType::kInt16; + GetQuantizationParams(*quantized_data_type, minmax, quantization_params); return true; } } - GetQuantizationParamsFromMinMax(model->flags, minmax, - quantization_params); *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); + GetQuantizationParams(*quantized_data_type, minmax, quantization_params); transformation->AddMessageF( "For output array %s with min=%g, max=%g" ", chose to quantize as %s with zero_point=%d" @@ -472,9 +510,11 @@ bool Quantize::Run(Model* model, std::size_t op_index) { // // Let us just guard this assumption by the following assertion: for (const auto& input : op.inputs) { - if (IsInputArray(*model, input)) { - const auto& input_array = model->GetArray(input); - CHECK(input_array.quantization_params); + const auto& input_array = model->GetArray(input); + if (IsInputArray(*model, input) && + input_array.data_type == ArrayDataType::kFloat) { + CHECK(input_array.quantization_params) + << "Input array " << input << " is missing quantization_params"; } } if (!SupportsQuantization(op)) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index 587f171bbf823408a45083c36d52f1d38c300123..aa93ace03af300f9cbd3f9c6620a6a58b9329aa4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -60,7 +60,9 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, for (int i = 0; i < passthru_op->inputs.size(); i++) { if (!model->GetArray(passthru_op->inputs[i]).buffer) { count_nonconstant_input_arrays++; - main_input_array_index = i; + if (count_nonconstant_input_arrays == 1) { + main_input_array_index = i; + } } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc index 28f76c9d36d6f68c8997fa0cf620c8aec4273619..9b65feaa6443cd32ac1bef961600ff225d52d4b2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -30,6 +31,7 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, const auto it = model->operators.begin() + op_index; auto* op = it->get(); if (op->fused_activation_function != FusedActivationFunctionType::kRelu && + op->fused_activation_function != FusedActivationFunctionType::kRelu1 && op->fused_activation_function != FusedActivationFunctionType::kRelu6) { return false; } @@ -42,33 +44,49 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, } const auto& quantization_params = output_array.GetQuantizationParams(); + double clamp_min; + double clamp_max; + switch (op->fused_activation_function) { + case FusedActivationFunctionType::kRelu: + clamp_min = 0.0; + clamp_max = std::numeric_limits::infinity(); + break; + case FusedActivationFunctionType::kRelu1: + clamp_min = -1.0; + clamp_max = 1.0; + break; + case FusedActivationFunctionType::kRelu6: + clamp_min = 0.0; + clamp_max = 6.0; + break; + default: + LOG(FATAL) << "Unsupported fused activation type: " + << static_cast(op->fused_activation_function); + return false; + } + bool has_nontrivial_min_bound = false; bool has_nontrivial_max_bound = false; - if (op->fused_activation_function == FusedActivationFunctionType::kRelu || - op->fused_activation_function == FusedActivationFunctionType::kRelu6) { - double lowest_representable_output = - (0. - quantization_params.zero_point) * quantization_params.scale; - if (lowest_representable_output < 0.) { - has_nontrivial_min_bound = true; - AddMessageF( - "Quantized activation function is not trivial: " - "the lowest representable output value %g" - " less than the clamp min bound.", - lowest_representable_output); - } + double lowest_representable_output = + (0. - quantization_params.zero_point) * quantization_params.scale; + if (lowest_representable_output < clamp_min) { + has_nontrivial_min_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the lowest representable output value %g" + " less than the clamp min bound %g.", + lowest_representable_output, clamp_min); } - if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) { - double highest_representable_output = - (255. - quantization_params.zero_point) * quantization_params.scale; - if (highest_representable_output > 6.) { - has_nontrivial_max_bound = true; - AddMessageF( - "Quantized activation function is not trivial: " - "the highest representable output value %g" - " is greater than the clamp max bound.", - highest_representable_output); - } + double highest_representable_output = + (255. - quantization_params.zero_point) * quantization_params.scale; + if (highest_representable_output > clamp_max) { + has_nontrivial_max_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the highest representable output value %g" + " is greater than the clamp max bound %g.", + highest_representable_output, clamp_max); } if (has_nontrivial_min_bound || has_nontrivial_max_bound) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index 90f9381ec154f145cda826ff9730ff332cd96701..61477d59aea2f11c6347b84d8863763a86c43558 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -61,8 +61,8 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, if (next_op->type == OperatorType::kTensorFlowReshape) { transformation->AddMessageF( "%s is trivial because its output is only consumed by another " - "Reshape op", - LogName(op)); + "Reshape op %s", + LogName(op), LogName(*next_op)); return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc index 30a005c789bb12e880e8e4534088d99ebacba84a..9852c86c21b9a0714bc728e60b5d9dfe61ff52d1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc @@ -42,14 +42,22 @@ bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { std::unique_ptr& exchange_op = *exchange_it; DCHECK(exchange_op); - if (exchange_op->type != OperatorType::kTensorFlowReshape) { - return false; + // Allow activation functions to move up over any operator that does not + // change the values. + switch (exchange_op->type) { + case OperatorType::kExpandDims: + case OperatorType::kSqueeze: + case OperatorType::kTensorFlowReshape: + case OperatorType::kTranspose: + break; + default: + return false; } DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]); - const auto& exchange_op_input = exchange_op->inputs[0]; - const auto& intermediate_array = exchange_op->outputs[0]; - const auto& ac_op_output = ac_op->outputs[0]; + const auto exchange_op_input = exchange_op->inputs[0]; + const auto intermediate_array = exchange_op->outputs[0]; + const auto ac_op_output = ac_op->outputs[0]; int count_ops_consuming_output = CountOpsWithInput(*model, intermediate_array); @@ -62,32 +70,58 @@ bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { return false; } - // If the ac_op was originally producing an output_array we can't reorder as - // otherwise the output array would change. It'd be nice to still be able to - // reorder but if code is relying on the fetch names instead of array indices - // this won't work. - for (int i = 0; i < model->flags.output_arrays_size(); ++i) { - if (model->flags.output_arrays(i) == ac_op->outputs[0]) { - AddMessageF( - "Not exchanging activation function with %s to preserve output array " - "name %s", - LogName(*exchange_op), ac_op->outputs[0]); - return false; - } - } - - // Rewire by changing inputs, including all consumers. - Operator* consumer = GetFirstOpWithInput(*model, ac_op_output); - while (consumer) { - for (int i = 0; i < consumer->inputs.size(); ++i) { - if (consumer->inputs[i] == ac_op_output) { - consumer->inputs[i] = intermediate_array; + // If the ac_op was originally producing an output_array we can't trivially + // reorder as otherwise the output array name would change and break + // downstream assumptions. To work around that we perform some renaming below + // in that case at the cost of a bit more confusing array names in this rare + // case. + bool is_ac_op_output = + std::find(model->flags.output_arrays().begin(), + model->flags.output_arrays().end(), + ac_op_output) != model->flags.output_arrays().end(); + if (is_ac_op_output) { + // To preserve the output array name of the activation function we need to + // create a temporary to use to pass between ac->ex. + // + // Original: + // (a) -> EX -> (b) -> AC -> (c) + // Now: + // (a) -> AC -> (c') -> EX -> (c) + AddMessageF( + "Exchanging activation function %s with %s but renaming to preserve " + "output array %s", + LogName(*ac_op), LogName(*exchange_op), ac_op->outputs[0]); + + auto renamed_ac_op_output = + AvailableArrayName(*model, ac_op_output + "_exchange"); + ac_op->inputs[0] = exchange_op_input; + ac_op->outputs[0] = renamed_ac_op_output; + model->EraseArray(exchange_op->outputs[0]); + exchange_op->inputs[0] = renamed_ac_op_output; + exchange_op->outputs[0] = ac_op_output; + } else { + // Simply swap the order and update consumers to use the exchange_op output + // array (b). + // + // Original: + // (a) -> EX -> (b) -> AC -> (c) + // Now: + // (a) -> AC -> (c) -> EX -> (b) + AddMessageF("Exchanging activation function %s with %s", LogName(*ac_op), + LogName(*exchange_op)); + + Operator* consumer = GetFirstOpWithInput(*model, ac_op_output); + while (consumer) { + for (int i = 0; i < consumer->inputs.size(); ++i) { + if (consumer->inputs[i] == ac_op_output) { + consumer->inputs[i] = intermediate_array; + } } + consumer = GetFirstOpWithInput(*model, ac_op_output); } - consumer = GetFirstOpWithInput(*model, ac_op_output); + ac_op->inputs[0] = exchange_op_input; + exchange_op->inputs[0] = ac_op_output; } - ac_op->inputs[0] = exchange_op_input; - exchange_op->inputs[0] = ac_op_output; // Clear shapes; this will allow shape propagation to fix the sizes for us. model->GetOrCreateArray(ac_op->outputs[0]).clear_shape(); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 944901ece77430708013ea4ca340a30511ba0174..625d90205a801ad7c3fc1026c9cedc9b509f920d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -55,8 +55,8 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const int size = input_buffer.data.size(); output_buffer.data.resize(size); QuantizationParams qparams; - GetQuantizationParamsFromMinMax( - model->flags, *fakequant_op->minmax, &qparams); + GetQuantizationParamsFromMinMax(*fakequant_op->minmax, + &qparams); for (int i = 0; i < size; i++) { const double src_val = input_buffer.data[i]; const double unclamped_quantized_val = diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc new file mode 100644 index 0000000000000000000000000000000000000000..d999c2df9483e096f333c6af83e1d9fee873d4d6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Gathers data from axis 0. +template +inline void Gather(const Array& input_array, int input_rank, + const Array& coords_array, Array* output_array) { + const Shape& input_shape = input_array.shape(); + const std::vector>& input_data = + input_array.GetBuffer().data; + const Shape& coords_shape = coords_array.shape(); + const std::vector& coords_data = + coords_array.GetBuffer().data; + + const Shape& output_shape = output_array->shape(); + std::vector>& output_data = + output_array->GetMutableBuffer().data; + output_data.resize(RequiredBufferSizeForShape(output_shape)); + + int rev_input_rank = input_shape.dimensions_count() - 1 - (input_rank - 1); + CHECK_EQ(coords_shape.dims(0), output_array->shape().dims(rev_input_rank)); + + int stride = 1; + for (int i = input_shape.dimensions_count() - 1; i >= input_rank - 1; --i) { + stride *= input_shape.dims(i); + } + + for (int i = 0; i < coords_shape.dims(0); ++i) { + DCHECK_GE(coords_data[i], 0); + DCHECK_LT(coords_data[i], input_shape.dims(rev_input_rank)); + DataType* out = output_data.data() + i * stride; + const DataType* in = input_data.data() + coords_data[i] * stride; + memcpy(out, in, sizeof(DataType) * stride); + } +} + +} // namespace + +// Resolves a constant Gather operation. +// This simply performs the gather and produces the output array with the +// appropriate values. +bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kGather) { + return false; + } + const auto* op = static_cast(base_op); + + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // Only handling axis=0 for now. + if (op->axis != 0) { + AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), + op->axis); + return false; + } + + // We require constant inputs. + if (!IsConstantParameterArray(*model, op->inputs[0]) || + !IsConstantParameterArray(*model, op->inputs[1])) { + return false; + } + const Array& input_array = model->GetArray(op->inputs[0]); + const Array& coords_array = model->GetArray(op->inputs[1]); + CHECK(coords_array.data_type == ArrayDataType::kInt32) + << "Only int32 indices are supported"; + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + Gather(input_array, op->input_rank, coords_array, + &output_array); + break; + case ArrayDataType::kUint8: + Gather(input_array, op->input_rank, coords_array, + &output_array); + break; + case ArrayDataType::kInt32: + Gather(input_array, op->input_rank, coords_array, + &output_array); + break; + case ArrayDataType::kInt64: + Gather(input_array, op->input_rank, coords_array, + &output_array); + break; + default: + LOG(FATAL) << "Unsupported data type given to Gather op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used after we remove the op. + DeleteArrayIfUsedOnce(op->inputs[0], model); + DeleteArrayIfUsedOnce(op->inputs[1], model); + + // Erase the operator. + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index f227554bc505efe6a758fdd9894fee43f2500641..d4db6f1c009cd19515655fb31974a2e97cfa42e8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -28,21 +28,45 @@ limitations under the License. namespace toco { +bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) { + auto& output_array = model->GetArray(op.outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op.inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min; + output_minmax.max = input_minmax.max; + return true; +} + bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { const auto unary_it = model->operators.begin() + op_index; const auto* unary_op = unary_it->get(); - // Test for unary ops of types that we know how to resolve - if (unary_op->type != OperatorType::kCast && - unary_op->type != OperatorType::kNeg && - unary_op->type != OperatorType::kTensorFlowRsqrt && - unary_op->type != OperatorType::kTensorFlowSqrt && - unary_op->type != OperatorType::kTensorFlowSquare && - unary_op->type != OperatorType::kTensorFlowSum && - unary_op->type != OperatorType::kTensorFlowMin && - unary_op->type != OperatorType::kTensorFlowMax && - unary_op->type != OperatorType::kTensorFlowReshape) { - return false; + // Test for unary ops of types that we know how to resolve. + switch (unary_op->type) { + case OperatorType::kCast: + case OperatorType::kNeg: + case OperatorType::kTensorFlowRsqrt: + case OperatorType::kTensorFlowSqrt: + case OperatorType::kTensorFlowSquare: + case OperatorType::kTensorFlowSum: + case OperatorType::kTensorFlowMin: + case OperatorType::kTensorFlowMax: + case OperatorType::kTensorFlowReshape: + case OperatorType::kRelu6: + case OperatorType::kRelu1: + case OperatorType::kRelu: + break; + default: + return false; } + // Check if the input is a constant parameter. if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { return false; @@ -76,6 +100,12 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { return false; } + // The min-max is only copied for ops that copy data without arithmetic. + // In future trivial transpose, etc, can be handled here. + if (unary_op->type == OperatorType::kTensorFlowReshape) { + CopyMinMaxFromFirstInput(*unary_op, model); + } + const auto& input_array = model->GetArray(unary_op->inputs[0]); // We have already tested above for existence of buffers (synonymous to being // a constant param). @@ -135,15 +165,34 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } } else if (unary_op->type == OperatorType::kTensorFlowReshape) { CHECK(input_buffer_size == output_buffer_size); - memcpy(output_float_data.data(), (*input_float_data).data(), - output_buffer_size * sizeof(output_float_data[0])); + output_float_data = *input_float_data; } else if (unary_op->type == OperatorType::kTensorFlowSum) { - // At the moment only full reduction across all dimensions is supported. - float sum = 0.f; - for (int i = 0; i < input_buffer_size; i++) { - sum += (*input_float_data)[i]; + CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; + if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { + AddMessageF("Axis input is non-constant"); + return false; } - for (int i = 0; i < output_buffer_size; ++i) { + auto& axis_array = model->GetArray(unary_op->inputs[1]); + CHECK(axis_array.data_type == ArrayDataType::kInt32); + int axis = axis_array.GetBuffer().data[0]; + CHECK_LT(axis, input_shape.dimensions_count()) << "Axis out of bounds"; + + // We currently only handle reduction on axis 0. + CHECK_EQ(axis, 0) << "Only reduction along axis 0 is supported"; + // We currently only handle 1-D and 2-D input tensors. + CHECK_LE(input_shape.dimensions_count(), 2) << "Rank >2 not yet supported"; + // We only support keep_dims=true; shape prop will need to change otherwise. + auto sum_op = static_cast(unary_op); + CHECK(sum_op->keep_dims) << "Only keep_dims=true is supported"; + + std::vector indices(input_shape.dimensions_count()); + for (int i = 0; i < input_shape.dims(1); ++i) { + indices[1] = i; + float sum = 0.f; + for (int j = 0; j < input_shape.dims(0); ++j) { + indices[0] = j; + sum += (*input_float_data)[Offset(input_shape, indices)]; + } output_float_data[i] = sum; } } else if (unary_op->type == OperatorType::kTensorFlowMin) { @@ -193,6 +242,37 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = outval; } + } else if (unary_op->type == OperatorType::kRelu6 && + unary_op->type == OperatorType::kRelu1 && + unary_op->type == OperatorType::kRelu) { + for (size_t i = 0; i < output_buffer_size; ++i) { + const float value = (*input_float_data)[i]; + float new_value = 0.0f; + switch (unary_op->type) { + case OperatorType::kRelu: { + static constexpr float kLower = 0; + new_value = value < kLower ? kLower : value; + break; + } + case OperatorType::kRelu1: { + static constexpr float kUpper = 1; + static constexpr float kLower = -1; + new_value = value > kUpper ? kUpper : value < kLower ? kLower : value; + break; + } + case OperatorType::kRelu6: { + static constexpr float kUpper = 6; + static constexpr float kLower = 0; + new_value = value > kUpper ? kUpper : value < kLower ? kLower : value; + break; + } + default: + LOG(FATAL) << "Unsupported activation function " + << LogName(*unary_op); + return false; + } + output_float_data[i] = new_value; + } } else { LOG(FATAL) << "should not get here."; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc new file mode 100644 index 0000000000000000000000000000000000000000..48c326651f3201b4f7a31ac2440b171841e8ed7b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -0,0 +1,240 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { + // Collapses a partitioned tf.nn.embedding_lookup back into a single Gather. + // https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup + // This transform attempts to identify the len(params) > 1 case and collapse + // it to the len(params) = 1 case by concatenating the original params and + // reversing the partitioning. + // + // If len(params) to the tf.nn.embedding_lookup == 1, the whole op becomes + // simply a gather: + // https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/python/ops/embedding_ops.py#L150 + // + // Notes on this implementation: + // - only supports partition_strategy='mod' + // + // A rough graph of a partitioned embedding_lookup looks like: + // (ids)--+-->FloorDiv--+-->DynamicPartition-->[[Gather]]--\ + // \-->FloorMod--/ | + // V | + // Range-->DynamicPartition-------->DynamicStitch<---------/ + // (const) V + // (embeddings) + + // First look for the final DynamicStitch. + auto op_it = model->operators.begin() + op_index; + if (op_it->get()->type != OperatorType::kDynamicStitch) { + return false; + } + auto* stitch_op = static_cast(op_it->get()); + + // Split up the DynamicStitch inputs into the indices and data. + std::vector stitch_indices_inputs; + std::vector stitch_data_inputs; + for (size_t i = 0; i < stitch_op->num_partitions; ++i) { + stitch_indices_inputs.push_back(stitch_op->inputs[i]); + } + for (size_t i = stitch_op->num_partitions; i < stitch_op->num_partitions * 2; + ++i) { + stitch_data_inputs.push_back(stitch_op->inputs[i]); + } + + // Validate all indices come from the same DynamicPartition. + DynamicPartitionOperator* indices_partition_op = nullptr; + for (const string& indices_partition_output_name : stitch_indices_inputs) { + auto* op = GetOpWithOutput(*model, indices_partition_output_name); + CHECK(op) << "Source of " << indices_partition_output_name << " not found"; + if (op->type != OperatorType::kDynamicPartition) { + AddMessageF( + "Skipping because indices input %s into " + "%s is unexpected", + LogName(*op), LogName(*stitch_op)); + return false; + } + if (!indices_partition_op) { + indices_partition_op = static_cast(op); + } else { + // Ensure this is the same op as previous ones. + if (op != indices_partition_op) { + AddMessageF( + "Skipping because indices input %s into " + "%s is from a different source op than others", + LogName(*op), LogName(*stitch_op)); + return false; + } + } + } + CHECK(indices_partition_op) << "No indices inputs"; + + // The data for the indices must be a constant range of the array shape. + if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) { + AddMessageF("Skipping because indices partition data is non-constant"); + return false; + } + auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]); + if (indices_data_array.data_type == ArrayDataType::kNone) { + // Yield until data types are propagated. + return false; + } + CHECK(indices_data_array.data_type == ArrayDataType::kInt32) + << "Indices partition inputs must be int32"; + const auto& indices_data_buffer = + indices_data_array.GetBuffer().data; + for (size_t i = 0; i < indices_data_buffer.size(); ++i) { + CHECK_EQ(indices_data_buffer[i], i) << "Indices range must be identity"; + } + + // Find all of the gathers used for the data inputs. + std::vector gather_ops; + for (const string& gather_output_name : stitch_data_inputs) { + auto* op = GetOpWithOutput(*model, gather_output_name); + CHECK(op) << "Source of " << gather_output_name << " not found"; + if (op->type != OperatorType::kGather) { + AddMessageF( + "Skipping because data input %s into %s " + "is unexpected", + LogName(*op), LogName(*stitch_op)); + return false; + } + gather_ops.push_back(static_cast(op)); + } + + // Validate all gathers come from the same DynamicPartition. + DynamicPartitionOperator* data_partition_op = nullptr; + for (auto* gather_op : gather_ops) { + auto* op = GetOpWithOutput(*model, gather_op->inputs[1]); + CHECK(op) << "Source of " << gather_op->inputs[1] << " not found"; + if (op->type != OperatorType::kDynamicPartition) { + AddMessageF( + "Skipping because data input %s into " + "%s is unexpected", + LogName(*op), LogName(*gather_op)); + return false; + } + if (!data_partition_op) { + data_partition_op = static_cast(op); + } else { + // Ensure this is the same op as previous ones. + if (op != data_partition_op) { + AddMessageF( + "Skipping because data input %s into " + "%s is from a different source op than others", + LogName(*op), LogName(*gather_op)); + return false; + } + } + } + CHECK(data_partition_op) << "No data inputs"; + + // Validate the partition ops have the same sizes. + CHECK_EQ(indices_partition_op->num_partitions, + data_partition_op->num_partitions) + << "Indices and data partition ops have differing dimensions"; + int num_partitions = indices_partition_op->num_partitions; + + // Partition strategy of 'mod' gives us a FloorMod and FloorDiv. + // The gather partition uses the FloorDiv as the data and FloorMod as the + // partitions and the indices use the FloorMod as their partitions. + Operator* div_op = GetOpWithOutput(*model, data_partition_op->inputs[0]); + Operator* mod_op = GetOpWithOutput(*model, data_partition_op->inputs[1]); + CHECK(div_op && div_op->type == OperatorType::kFloorDiv) + << "Unsupported partition strategy"; + CHECK(mod_op && mod_op->type == OperatorType::kFloorMod) + << "Unsupported partition strategy"; + CHECK_EQ(mod_op, GetOpWithOutput(*model, indices_partition_op->inputs[1])) + << "Indices and data parition ops require the same partition strategy " + "and inputs"; + + // Glob together all of the gather data. This is not yet in the correct order. + auto* gather_params_concat_op = new ConcatenationOperator; + for (const auto& gather_op : gather_ops) { + gather_params_concat_op->inputs.push_back(gather_op->inputs[0]); + } + gather_params_concat_op->outputs.push_back( + AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_unpartitioned")); + op_it = model->operators.emplace(op_it, gather_params_concat_op) + 1; + model->GetOrCreateArray(gather_params_concat_op->outputs[0]); + + // Permute the gather params to undo the partitioning that was originally + // done. + auto* gather_params_permute_op = new GatherOperator; + gather_params_permute_op->inputs.push_back( + gather_params_concat_op->outputs[0]); + gather_params_permute_op->inputs.push_back( + AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm")); + gather_params_permute_op->outputs.push_back( + AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted")); + op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1; + model->GetOrCreateArray(gather_params_permute_op->outputs[0]); + const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]); + const auto& partition_array_dims = partition_array.shape().dims(); + gather_params_permute_op->input_rank = + partition_array.shape().dimensions_count(); + auto& perm_array = + model->GetOrCreateArray(gather_params_permute_op->inputs[1]); + perm_array.data_type = ArrayDataType::kInt32; + perm_array.mutable_shape()->ReplaceDims( + {num_partitions * partition_array_dims[0]}); + auto& perm_data = perm_array.GetMutableBuffer().data; + perm_data.resize(RequiredBufferSizeForShape(perm_array.shape())); + // NOTE: this is what relies on the partition_strategy. + for (int i = 0; i < num_partitions * partition_array_dims[0]; ++i) { + int p = i % num_partitions; + perm_data[i] = p * partition_array_dims[0] + i / num_partitions; + } + + // Insert the new unpartitioned gather op. + auto* merged_gather_op = new GatherOperator; + merged_gather_op->inputs = {gather_params_permute_op->outputs[0], + mod_op->inputs[0]}; + merged_gather_op->outputs = {stitch_op->outputs[0]}; + merged_gather_op->input_rank = partition_array.shape().dimensions_count(); + model->operators.emplace(op_it, merged_gather_op); + + AddMessageF( + "Replacing suspected partitioned tf.nn.embedding_lookup (starting at %s " + "+ %s and ending at %s) with a single unpartitioned gather %s", + LogName(*div_op), LogName(*mod_op), LogName(*stitch_op), + LogName(*merged_gather_op)); + + // Ensure the stitch output array is dead, as we don't want whatever was in it + // previously now that we've redefined it. It'll be recreated when needed. + model->EraseArray(stitch_op->outputs[0]); + model->GetOrCreateArray(merged_gather_op->outputs[0]); + + // Erase all the original ops. + DeleteOpAndArraysIfUnused(model, div_op); + DeleteOpAndArraysIfUnused(model, mod_op); + for (auto* gather_op : gather_ops) { + DeleteOpAndArraysIfUnused(model, gather_op); + } + DeleteOpAndArraysIfUnused(model, indices_partition_op); + DeleteOpAndArraysIfUnused(model, data_partition_op); + DeleteOpAndArraysIfUnused(model, stitch_op); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 27d2f33a8d278156262753e6572c10ff967bda4c..b844e0b9484f55ffaad63e55956ff789036f05e3 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -272,6 +272,39 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { } } +void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_BOOL); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_bool_data = + output_array->GetMutableBuffer().data; + output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), + false); + if (input_tensor.bool_val_size()) { + for (int i = 0; i < input_tensor.bool_val_size(); i++) { + output_bool_data[i] = input_tensor.bool_val(i); + } + } else if (input_tensor.tensor_content().size() == input_flat_size) { + std::vector buf(input_tensor.tensor_content().size()); + toco::port::CopyToBuffer(input_tensor.tensor_content(), buf.data()); + for (int i = 0; i < input_tensor.tensor_content().size(); i++) { + output_bool_data[i] = static_cast(buf[i]); + } + } else { + // Some graphs have bool const nodes without actual value... + // assuming that 'false' is implied. + // So far only encountered that in an array with 1 entry, let's + // require that until we encounter a graph where that's not the case. + CHECK_EQ(output_bool_data.size(), 1); + output_bool_data[0] = false; + } +} + void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); @@ -318,6 +351,18 @@ void CheckInputsCount(const NodeDef& node, << " input(s) other than control dependencies: " << node.DebugString(); } +template +string CreateConstArray(Model* model, string const& name, + std::vector > const& data) { + // Utility function to create a const 1D array, useful for input parameters. + string array_name = toco::AvailableArrayName(*model, name); + auto& array = model->GetOrCreateArray(array_name); + array.data_type = T; + array.mutable_shape()->mutable_dims()->emplace_back(data.size()); + array.GetMutableBuffer().data = data; + return array_name; +} + void ConvertConstOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -347,6 +392,10 @@ void ConvertConstOperator(const NodeDef& node, array.data_type = ArrayDataType::kString; ImportStringArray(tensor, &array); break; + case DT_BOOL: + array.data_type = ArrayDataType::kBool; + ImportBoolArray(tensor, &array); + break; default: array.data_type = ArrayDataType::kNone; // do nothing, silently ignore the Const data. @@ -678,9 +727,12 @@ void ConvertSqueezeOperator(const NodeDef& node, op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); - const auto& squeeze_dims = GetListAttr(node, "squeeze_dims"); - for (int i = 0; i < squeeze_dims.i_size(); ++i) { - op->squeeze_dims.push_back(squeeze_dims.i(i)); + // When omitted we are to squeeze all dimensions == 1. + if (HasAttr(node, "squeeze_dims")) { + const auto& squeeze_dims = GetListAttr(node, "squeeze_dims"); + for (int i = 0; i < squeeze_dims.i_size(); ++i) { + op->squeeze_dims.push_back(squeeze_dims.i(i)); + } } model->operators.emplace_back(op); @@ -1399,12 +1451,8 @@ void ConvertFusedBatchNormOperator(const NodeDef& node, const string& moving_variance_input = node.input(4); // Create an array holding the epsilon value (typically, 0.001). - const string epsilon_array_name = node.name() + "_epsilon_array"; - auto& epsilon_array = model->GetOrCreateArray(epsilon_array_name); - epsilon_array.data_type = ArrayDataType::kFloat; - *epsilon_array.mutable_shape()->mutable_dims() = {1}; - epsilon_array.GetMutableBuffer().data.push_back( - GetFloatAttr(node, "epsilon")); + const string epsilon_array_name = CreateConstArray( + model, node.name() + "_epsilon_array", {GetFloatAttr(node, "epsilon")}); // Add epsilon to the moving variance. const string epsilon_add_op_name = node.name() + "_epsilon"; @@ -1493,7 +1541,9 @@ void ConvertMeanOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { + if (HasAttr(node, "keepdims")) { + op->keep_dims = GetBoolAttr(node, "keepdims"); + } else if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } } @@ -1532,16 +1582,56 @@ void ConvertTransposeConvOperator(const NodeDef& node, CHECK_EQ(node.op(), "Conv2DBackpropInput"); CheckInputsCount(node, tf_import_flags, 3); auto* op = new TransposeConvOperator; - op->inputs.push_back(node.input(2)); - op->inputs.push_back(node.input(1)); op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); const auto& strides = GetListAttr(node, "strides"); - CHECK_EQ(strides.i_size(), 4); - CHECK_EQ(strides.i(0), 1); op->stride_height = strides.i(1); op->stride_width = strides.i(2); - CHECK_EQ(strides.i(3), 1); + CHECK_EQ(strides.i_size(), 4) + << "Can only import TransposeConv ops with 4D strides. TensorFlow op \"" + << node.name() << "\" has " << strides.i_size() << "D strides."; + CHECK((strides.i(0) == 1) && (strides.i(3) == 1)) + << "Can only import TransposeConv ops with striding along the height " + "(1st) or width (2nd) axis. TensorFlow op \"" + << node.name() << "\" had strides:[ " << strides.i(0) << ", " + << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "]."; + op->stride_height = strides.i(1); + op->stride_width = strides.i(2); + if (HasAttr(node, "dilations")) { + const auto& dilations = GetListAttr(node, "dilations"); + CHECK_EQ(dilations.i_size(), 4) + << "Dilation unsupported in TransposeConv. TensorFlow op \"" + << node.name() << "\" had dilations"; + CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) && + (dilations.i(1) == 1) && (dilations.i(3) == 1)) + << "Dilation unsupported in TransposeConv. TensorFlow op \"" + << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " + << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) + << "]."; + } + + const string& weights_name = node.input(TransposeConvOperator::WEIGHTS); + const string& transposed_weights_name = weights_name + "_transposed"; + // Check if a TransposeOperator was already created for these weights + // (can happen when multiple layers share the same weights). + const Operator* existing_transpose = + GetOpWithOutput(*model, transposed_weights_name); + if (existing_transpose) { + CHECK(existing_transpose->type == OperatorType::kTranspose); + } else { + // Transpose weights from HWIO order to OHWI order, which is more efficient + // for computation + TransposeOperator* transpose = new TransposeOperator; + string perm_array = CreateConstArray( + model, node.name() + "_transpose_perm", {3, 0, 1, 2}); + transpose->inputs = {weights_name, perm_array}; + transpose->outputs = {transposed_weights_name}; + model->operators.emplace_back(transpose); + } + op->inputs[1] = transposed_weights_name; + auto const& padding = GetStringAttr(node, "padding"); if (padding == "SAME") { op->padding.type = PaddingType::kSame; @@ -1837,19 +1927,9 @@ void ConvertTopKV2Operator(const NodeDef& node, op->inputs.push_back(node.input(0)); // K can be encoded as attr (TopK) convert it to a const. if (HasAttr(node, "k")) { - // Convert attribute into const tensor. - const string array_name = node.name() + "k"; - auto& array = model->GetOrCreateArray(array_name); - array.data_type = ArrayDataType::kInt32; - // Size of array is always 1. - array.mutable_shape()->mutable_dims()->emplace_back(1); - - auto& output_int_data = - array.GetMutableBuffer().data; - output_int_data.resize(1); - output_int_data[0] = GetIntAttr(node, "k"); - op->inputs.push_back(array_name); - + string k_array = CreateConstArray( + model, node.name() + "k", {GetIntAttr(node, "k")}); + op->inputs.push_back(k_array); } else { CheckInputsCount(node, tf_import_flags, 2); op->inputs.push_back(node.input(1)); @@ -1859,6 +1939,42 @@ void ConvertTopKV2Operator(const NodeDef& node, op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op.release()); } + +void ConvertDynamicPartitionOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + auto op = absl::make_unique(); + CHECK(HasAttr(node, "num_partitions")); + op->num_partitions = GetIntAttr(node, "num_partitions"); + CheckInputsCount(node, tf_import_flags, 2); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + CHECK_GT(op->num_partitions, 1); + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 1; i < op->num_partitions; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i)); + } + model->operators.emplace_back(op.release()); +} + +void ConvertDynamicStitchOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + // The parallel and non-parallel variants are the same besides whether they + // have a parallel loop; there are no behavioral differences. + CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); + auto op = absl::make_unique(); + CHECK(HasAttr(node, "N")); + op->num_partitions = GetIntAttr(node, "N"); + // Expect all ID partitions + all value partitions. + CheckInputsCount(node, tf_import_flags, op->num_partitions * 2); + for (int i = 0; i < op->num_partitions * 2; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op.release()); +} + } // namespace std::unique_ptr ImportTensorFlowGraphDef( @@ -2044,6 +2160,11 @@ std::unique_ptr ImportTensorFlowGraphDef( ConvertExpOperator(node, tf_import_flags, model); } else if (node.op() == "TopK" || node.op() == "TopKV2") { ConvertTopKV2Operator(node, tf_import_flags, model); + } else if (node.op() == "DynamicPartition") { + ConvertDynamicPartitionOperator(node, tf_import_flags, model); + } else if (node.op() == "DynamicStitch" || + node.op() == "ParallelDynamicStitch") { + ConvertDynamicStitchOperator(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 346859ab392d257355b21411a1b3691c8dda5421..5199e292e19c2ac59dcfc2efd9947cc788b0299d 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -29,6 +29,8 @@ limitations under the License. namespace toco { +using tflite::QuantizationParams; + enum class OperatorType { kNone, // General-purpose neural network operators. @@ -63,6 +65,7 @@ enum class OperatorType { kRelu, kRelu1, kRelu6, + kPRelu, kSoftmax, kLogSoftmax, kSub, @@ -115,6 +118,8 @@ enum class OperatorType { kTensorFlowTile, kTranspose, kTopK_V2, + kDynamicPartition, + kDynamicStitch, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. kTensorFlowUnsupported, @@ -244,6 +249,8 @@ struct GenericBuffer { // in containers and have the containers call the right subclass destructor. virtual ~GenericBuffer() {} + virtual int Length() const = 0; + const ArrayDataType type; protected: @@ -256,6 +263,8 @@ template struct Buffer : GenericBuffer { Buffer() : GenericBuffer(A) {} + int Length() const override { return data.size(); } + std::vector> data; }; @@ -558,6 +567,18 @@ struct Relu6Operator : Operator { Relu6Operator() : Operator(OperatorType::kRelu6) {} }; +// PRelu +// f(x) = alpha * x for x < 0, f(x) = x for x >= 0. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the alpha array +// +// Equivalent to keras.layers.PReLU. +struct PReluOperator : Operator { + PReluOperator() : Operator(OperatorType::kPRelu) {} +}; + // Element-wise Logistic operator: // x -> Logistic(x) = 1 / (1 + exp(-x)) // @@ -840,19 +861,29 @@ struct SqueezeOperator : Operator { }; // Inputs: -// inputs[0]: required: the input activations array -// inputs[1]: required: the Conv weights -// channel. +// inputs[0]: required: the output shape +// inputs[1]: required: the weights +// inputs[2]: required: the input activations array +// NOTE: The input activations is NOT the first input. +// // // Outputs: // outputs[0]: required: the output activations array // // TensorFlow equivalent: Conv2DBackpropInput struct TransposeConvOperator : Operator { + enum Inputs { + OUTPUT_SHAPE = 0, + WEIGHTS = 1, + DATA_INPUT = 2, + }; + TransposeConvOperator() : Operator(OperatorType::kTransposeConv) {} Padding padding; int stride_width = 0; int stride_height = 0; + // Dilation is possible with transpose convolution, but Tensorflow does not + // currently support it, so we omit it. }; // Given a tensor input, this operation calculates element-wise exponential @@ -1410,6 +1441,30 @@ struct TopKV2Operator : Operator { TopKV2Operator() : Operator(OperatorType::kTopK_V2) {} }; +// DynamicPartition operator: +// +// Inputs: +// inputs[0]: required: data. +// inputs[1]: required: partitions. +// +// TensorFlow equivalent: DynamicPartition +struct DynamicPartitionOperator : Operator { + DynamicPartitionOperator() : Operator(OperatorType::kDynamicPartition) {} + int num_partitions; +}; + +// DynamicStitch operator: +// +// Inputs: +// inputs[0,N): required: indices. +// inputs[N,2N): required: data. +// +// TensorFlow equivalent: DynamicStitch/ParallelDynamicStitch +struct DynamicStitchOperator : Operator { + DynamicStitchOperator() : Operator(OperatorType::kDynamicStitch) {} + int num_partitions; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are @@ -1423,22 +1478,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) { return a.start < b.start; } -// Quantization parameters, determining the mapping of quantized values -// to real values (i.e. determining how quantized values are mathematically -// interpreted). -// -// The correspondence is as follows: -// -// real_value = scale * (quantized_value - zero_point); -// -// In other words, zero_point designates which quantized value corresponds to -// the real 0 value, and scale designates the difference between the real values -// corresponding to consecutive quantized values differing by 1. -struct QuantizationParams { - int32 zero_point = 0; - double scale = 0.; -}; - class Shape { public: // For Shape, we stick to half-way encapsulation for now: diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4e2dec15a534607ef9207149a2e6061069eabcb1..4264f21c76e6f4a26d1be710874c0edb96a6ca6d 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -72,6 +72,12 @@ bool ParseModelFlagsFromCommandLineFlags( "Shapes corresponding to --input_arrays, colon-separated. For " "many models each shape takes the form batch size, input array " "height, input array width, input array depth."), + Flag("batch_size", parsed_flags.batch_size.bind(), + parsed_flags.batch_size.default_value(), + "Batch size for the model. Replaces the first dimension of an " + "input size array if undefined. Use only with SavedModels when " + "--input_shapes flag is not specified. Always use --input_shapes " + "flag with frozen graphs."), Flag("input_data_type", parsed_flags.input_data_type.bind(), parsed_flags.input_data_type.default_value(), "Deprecated: use --input_data_types instead. Input array type, if " diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index 867b86f31d16b502a7aeb92cb3d8c96117630cd2..42e0f54826dd809a801a8ac1bfd0a5a7660382a8 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -96,11 +96,13 @@ message RnnState { // model that does not already contain such MinMax information. message ArraysExtraInfo { message Entry { - // Next ID to use: 5. + // Next ID to use: 7. optional string name = 1; optional float min = 2; optional float max = 3; optional IODataType data_type = 4; + optional InputArrayShape shape = 5; + optional float constant_float_value = 6; } repeated Entry entries = 1; } diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc index fddf6cc83686632033f31496ec42b33e2ea15f20..5e421ba944cccd9746c66bc33e986b4406dd3bf5 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc @@ -144,7 +144,9 @@ std::unique_ptr MaybeReplaceCompositeSubgraph( MaybeResolveClusters(tf_graph, cluster_factories); // Copy function definitions - *(pruned_graph->mutable_library()) = tf_graph.library(); + if (pruned_graph) { + *(pruned_graph->mutable_library()) = tf_graph.library(); + } return pruned_graph; } diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index a2b8145a67278c3ac0065f9551da6ffd1de60772..9d3e1daf1258c6bc076dac566129174430bb761d 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -115,9 +115,11 @@ cc_library( deps = [ ":operator", ":types", + "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/toco:model", "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/contrib/lite/tools:verifier", "@flatbuffers", ], ) diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index 5b1ab514b23248cd98e66847185d0e8b9fe2d6aa..c0e7ab2ef57ed8edf1b7cda08c64f6ae66172af3 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/tflite/import.h" #include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" #include "tensorflow/contrib/lite/toco/tflite/types.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/contrib/lite/tools/verifier.h" namespace toco { @@ -64,6 +66,9 @@ void ImportTensors(const ::tflite::Model& input_model, Model* model) { auto shape = input_tensor->shape(); if (shape) { + // If the shape is 0-dimensional, make sure to record it as such, + // as oppose to leaving the array without a shape. + array.mutable_shape()->mutable_dims()->clear(); for (int i = 0; i < shape->Length(); ++i) { auto d = shape->Get(i); array.mutable_shape()->mutable_dims()->push_back(d); @@ -159,16 +164,28 @@ void ImportIOTensors(const ::tflite::Model& input_model, } } +namespace { +bool Verify(const void* buf, size_t len) { + ::flatbuffers::Verifier verifier(static_cast(buf), len); + return ::tflite::VerifyModelBuffer(verifier); +} +} // namespace + std::unique_ptr Import(const ModelFlags& model_flags, const string& input_file_contents) { + ::tflite::AlwaysTrueResolver r; + if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(), + r, ::tflite::DefaultErrorReporter())) { + LOG(FATAL) << "Invalid flatbuffer."; + } const ::tflite::Model* input_model = ::tflite::GetModel(input_file_contents.data()); // Full list of all known operators. const auto ops_by_name = BuildOperatorByNameMap(); - if (input_model->subgraphs()->size() != 1) { - LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now."; + if (!input_model->subgraphs() || input_model->subgraphs()->size() != 1) { + LOG(FATAL) << "Number of subgraphs in tflite should be exactly 1."; } std::unique_ptr model; model.reset(new Model); diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc index aad6e780d5eb5c3dbc880906df5053ad231ffd54..edd22f783f03b1fbd34039cd7b00f08d34ca9fc6 100644 --- a/tensorflow/contrib/lite/toco/tflite/import_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc @@ -27,60 +27,110 @@ namespace { using ::testing::ElementsAre; +using flatbuffers::Offset; +using flatbuffers::Vector; class ImportTest : public ::testing::Test { protected: template - flatbuffers::Offset> CreateDataVector( - const std::vector& data) { + Offset> CreateDataVector(const std::vector& data) { return builder_.CreateVector(reinterpret_cast(data.data()), sizeof(T) * data.size()); } - // This is a very simplistic model. We are not interested in testing all the - // details here, since tf.mini's testing framework will be exercising all the - // conversions multiple times, and the conversion of operators is tested by - // separate unittests. - void BuildTestModel() { - // The tensors + + Offset>> BuildBuffers() { + auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector({})); + auto buf1 = ::tflite::CreateBuffer( + builder_, CreateDataVector({1.0f, 2.0f, 3.0f, 4.0f})); + auto buf2 = + ::tflite::CreateBuffer(builder_, CreateDataVector({3.0f, 4.0f})); + return builder_.CreateVector( + std::vector>({buf0, buf1, buf2})); + } + + Offset>> BuildTensors() { auto q = ::tflite::CreateQuantizationParameters( builder_, /*min=*/builder_.CreateVector({0.1f}), /*max=*/builder_.CreateVector({0.2f}), /*scale=*/builder_.CreateVector({0.3f}), /*zero_point=*/builder_.CreateVector({100ll})); - auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector({})); - auto buf1 = - ::tflite::CreateBuffer(builder_, CreateDataVector({1.0f, 2.0f})); - auto buf2 = - ::tflite::CreateBuffer(builder_, CreateDataVector({3.0f})); - auto buffers = builder_.CreateVector( - std::vector>({buf0, buf1, buf2})); - auto t1 = ::tflite::CreateTensor(builder_, - builder_.CreateVector({1, 2, 3, 4}), - ::tflite::TensorType_FLOAT32, 1, - builder_.CreateString("tensor_one"), q); + auto t1 = + ::tflite::CreateTensor(builder_, builder_.CreateVector({1, 2, 2}), + ::tflite::TensorType_FLOAT32, 1, + builder_.CreateString("tensor_one"), q); auto t2 = ::tflite::CreateTensor(builder_, builder_.CreateVector({2, 1}), ::tflite::TensorType_FLOAT32, 2, builder_.CreateString("tensor_two"), q); - auto tensors = builder_.CreateVector( - std::vector>({t1, t2})); - - // The operator codes. - auto c1 = - ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM, - builder_.CreateString("custom_op_one")); - auto c2 = ::tflite::CreateOperatorCode( - builder_, ::tflite::BuiltinOperator_CONV_2D, 0); - auto opcodes = builder_.CreateVector( - std::vector>({c1, c2})); - - auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0); - std::vector> subgraph_vector( - {subgraph}); - auto subgraphs = builder_.CreateVector(subgraph_vector); + return builder_.CreateVector( + std::vector>({t1, t2})); + } + + Offset>> BuildOpCodes( + std::initializer_list<::tflite::BuiltinOperator> op_codes) { + std::vector> op_codes_vector; + for (auto op : op_codes) { + op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0)); + } + return builder_.CreateVector(op_codes_vector); + } + + Offset>> BuildOpCodes() { + return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D, + ::tflite::BuiltinOperator_CONV_2D}); + } + + Offset>> BuildOperators( + std::initializer_list inputs, std::initializer_list outputs) { + auto is = builder_.CreateVector(inputs); + if (inputs.size() == 0) is = 0; + auto os = builder_.CreateVector(outputs); + if (outputs.size() == 0) os = 0; + auto op = ::tflite::CreateOperator( + builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions, + ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1, + ::tflite::ActivationFunctionType_NONE) + .Union(), + /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS); + + return builder_.CreateVector(std::vector>({op})); + } + + Offset>> BuildOperators() { + return BuildOperators({0}, {1}); + } + + Offset>> BuildSubGraphs( + Offset>> tensors, + Offset>> operators, + int num_sub_graphs = 1) { + std::vector inputs = {0}; + std::vector outputs = {1}; + std::vector> v; + for (int i = 0; i < num_sub_graphs; ++i) { + v.push_back(::tflite::CreateSubGraph( + builder_, tensors, builder_.CreateVector(inputs), + builder_.CreateVector(outputs), operators, + builder_.CreateString("subgraph"))); + } + return builder_.CreateVector(v); + } + + // This is a very simplistic model. We are not interested in testing all the + // details here, since tf.mini's testing framework will be exercising all the + // conversions multiple times, and the conversion of operators is tested by + // separate unittests. + void BuildTestModel() { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators(); + auto subgraphs = BuildSubGraphs(tensors, operators); auto s = builder_.CreateString(""); - builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, - opcodes, subgraphs, s, buffers)); + + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, + opcodes, subgraphs, s, buffers)); input_model_ = ::tflite::GetModel(builder_.GetBufferPointer()); } @@ -89,7 +139,6 @@ class ImportTest : public ::testing::Test { builder_.GetSize()); } flatbuffers::FlatBufferBuilder builder_; - // const uint8_t* buffer_ = nullptr; const ::tflite::Model* input_model_ = nullptr; }; @@ -106,7 +155,7 @@ TEST_F(ImportTest, LoadOperatorsTable) { details::OperatorsTable operators; details::LoadOperatorsTable(*input_model_, &operators); - EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D")); + EXPECT_THAT(operators, ElementsAre("MAX_POOL_2D", "CONV_2D")); } TEST_F(ImportTest, Tensors) { @@ -118,9 +167,9 @@ TEST_F(ImportTest, Tensors) { Array& a1 = model->GetArray("tensor_one"); EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); EXPECT_THAT(a1.GetBuffer().data, - ElementsAre(1.0f, 2.0f)); + ElementsAre(1.0f, 2.0f, 3.0f, 4.0f)); ASSERT_TRUE(a1.has_shape()); - EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2)); const auto& mm = a1.minmax; ASSERT_TRUE(mm.get()); @@ -133,6 +182,80 @@ TEST_F(ImportTest, Tensors) { EXPECT_EQ(100, q->zero_point); } +TEST_F(ImportTest, NoBuffers) { + auto buffers = 0; + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators(); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Missing 'buffers' section."); +} + +TEST_F(ImportTest, NoInputs) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators({}, {1}); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Missing 'inputs' for operator."); +} + +TEST_F(ImportTest, NoOutputs) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators({0}, {}); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Missing 'outputs' for operator."); +} + +TEST_F(ImportTest, InvalidOpCode) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes({static_cast<::tflite::BuiltinOperator>(-1), + ::tflite::BuiltinOperator_CONV_2D}); + auto operators = BuildOperators(); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Operator id '-1' is out of range."); +} + +TEST_F(ImportTest, MultipleSubGraphs) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators(); + auto subgraphs = BuildSubGraphs(tensors, operators, 2); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + + input_model_ = ::tflite::GetModel(builder_.GetBufferPointer()); + + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Number of subgraphs in tflite should be exactly 1."); +} + // TODO(ahentz): still need tests for Operators and IOTensors. } // namespace diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index f2cc4ef71f71902e363ac4cddd3695446af30c7d..0989bfe5a3de9a7c0f62b272b0be84df1f4ddcb0 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -854,6 +854,8 @@ std::vector> BuildOperatorList() { new SimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1)); ops.emplace_back( new SimpleOperator("RELU6", OperatorType::kRelu6)); + ops.emplace_back( + new SimpleOperator("PRELU", OperatorType::kPRelu)); ops.emplace_back(new SimpleOperator( "LOGISTIC", OperatorType::kLogistic)); ops.emplace_back( @@ -861,6 +863,8 @@ std::vector> BuildOperatorList() { ops.emplace_back(new SimpleOperator("EXP", OperatorType::kExp)); ops.emplace_back(new SimpleOperator( "LOG_SOFTMAX", OperatorType::kLogSoftmax)); + ops.emplace_back(new SimpleOperator( + "MAXIMUM", OperatorType::kTensorFlowMaximum)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 9c19f8d4649acf40fdd85b78874f7b18798533f2..f7a213ecfc539e009f78e7c0e424d36a38b3486c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -109,6 +109,8 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("EXP", OperatorType::kExp); CheckSimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax); + CheckSimpleOperator( + "MAXIMUM", OperatorType::kTensorFlowMaximum); } TEST_F(OperatorTest, BuiltinAdd) { diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index b4c2851502a40a1ca36965d4ddd2c8a15b8fe60f..0afd2f3df57caf3214dd198bfa2ee75fa7a8fd7b 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -90,6 +90,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); + case ArrayDataType::kInt64: + return CopyBuffer(array, builder); case ArrayDataType::kString: return CopyBuffer(array, builder); case ArrayDataType::kUint8: diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc index f01ec0ec6102494f36cca0265b79e90355661271..8041aa9e7fbfdaf44134395fee4b2bb01633893a 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -23,40 +23,70 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" -#ifndef CHECK_OK -#define CHECK_OK(val) CHECK_EQ((val).ok(), true) -#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true) -#endif - namespace toco { namespace { -#define QCHECK_REQUIRE_TOCO_FLAG(arg) \ - QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg; - -void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - const TocoFlags& toco_flags) { - port::CheckInitGoogleIsDone("InitGoogle is not done yet"); - - QCHECK_REQUIRE_TOCO_FLAG(input_file) - QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(), - port::file::Defaults())) - << "Specified input_file does not exist: " - << parsed_toco_flags.input_file.value(); - QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(), - port::file::Defaults())) +// Checks the permissions of the output file to ensure it is writeable. +void CheckOutputFilePermissions(const Arg& output_file) { + QCHECK(output_file.specified()) << "Missing required flag --output_file.\n"; + QCHECK(port::file::Writable(output_file.value()).ok()) + << "Specified output_file is not writable: " << output_file.value() + << ".\n"; +} + +// Checks the permissions of the frozen model file. +void CheckFrozenModelPermissions(const Arg& input_file) { + QCHECK(input_file.specified()) << "Missing required flag --input_file.\n"; + QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) + << "Specified input_file does not exist: " << input_file.value() << ".\n"; + QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) << "Specified input_file exists, but is not readable: " - << parsed_toco_flags.input_file.value(); + << input_file.value() << ".\n"; +} - QCHECK_REQUIRE_TOCO_FLAG(output_file); - QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value())) - << "parsed_toco_flags.input_file.value() output_file is not writable: " - << parsed_toco_flags.output_file.value(); +// Checks the permissions of the SavedModel directory. +void CheckSavedModelPermissions(const Arg& savedmodel_directory) { + QCHECK(savedmodel_directory.specified()) + << "Missing required flag --savedmodel_directory.\n"; + QCHECK( + port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) + .ok()) + << "Specified savedmodel_directory does not exist: " + << savedmodel_directory.value() << ".\n"; +} + +// Reads the contents of the GraphDef from either the frozen graph file or the +// SavedModel directory. If it reads the SavedModel directory, it updates the +// ModelFlags and TocoFlags accordingly. +void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); + + bool has_input_file = parsed_toco_flags.input_file.specified(); + bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); + + // Ensure either input_file or savedmodel_directory flag has been set. + QCHECK_NE(has_input_file, has_savedmodel_dir) + << "Specify either input_file or savedmodel_directory flag.\n"; + + // Checks the input file permissions and reads the contents. + if (has_input_file) { + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); + } else { + CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); + GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, + model_flags, graph_def_contents); + } } void ToolMain(const ParsedTocoFlags& parsed_toco_flags, @@ -67,21 +97,20 @@ void ToolMain(const ParsedTocoFlags& parsed_toco_flags, TocoFlags toco_flags; ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); - CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags); + string graph_def_contents; + ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags, + &model_flags, &graph_def_contents); + CheckOutputFilePermissions(parsed_toco_flags.output_file); - string input_file_contents; - CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(), - &input_file_contents, - port::file::Defaults())); std::unique_ptr model = - Import(toco_flags, model_flags, input_file_contents); + Import(toco_flags, model_flags, graph_def_contents); Transform(toco_flags, model.get()); string output_file_contents; Export(toco_flags, *model, toco_flags.allow_custom_ops(), &output_file_contents); - CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(), - output_file_contents, - port::file::Defaults())); + CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(), + output_file_contents, port::file::Defaults()) + .ok()); } } // namespace diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 0f67c2de728532b5b8101b3514811a78a3b3bc38..cc7803dd866f0282f67d1d6f227cce0fdd8c7fd6 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" +#include "absl/types/optional.h" #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/core/platform/logging.h" @@ -38,6 +39,9 @@ bool ParseTocoFlagsFromCommandLineFlags( "Input file (model of any supported format). For Protobuf " "formats, both text and binary are supported regardless of file " "extension."), + Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(), + parsed_flags.savedmodel_directory.default_value(), + "Full path to the directory containing the SavedModel."), Flag("output_file", parsed_flags.output_file.bind(), parsed_flags.output_file.default_value(), "Output file. " @@ -49,6 +53,11 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.output_format.default_value(), "Output file format. " "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), + Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(), + parsed_flags.savedmodel_tagset.default_value(), + "Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags in the tag set must be " + "specified."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " @@ -128,47 +137,72 @@ bool ParseTocoFlagsFromCommandLineFlags( } } +namespace { + +// Defines the requirements for a given flag. kUseDefault means the default +// should be used in cases where the value isn't specified by the user. +enum class FlagRequirement { + kNone, + kMustBeSpecified, + kMustNotBeSpecified, + kUseDefault, +}; + +// Enforces the FlagRequirements are met for a given flag. +template +void EnforceFlagRequirement(const T& flag, const string& flag_name, + FlagRequirement requirement) { + if (requirement == FlagRequirement::kMustBeSpecified) { + QCHECK(flag.specified()) << "Missing required flag " << flag_name; + } + if (requirement == FlagRequirement::kMustNotBeSpecified) { + QCHECK(!flag.specified()) + << "Given other flags, this flag should not have been specified: " + << flag_name; + } +} + +// Gets the value from the flag if specified. Returns default if the +// FlagRequirement is kUseDefault. +template +absl::optional GetFlagValue(const Arg& flag, + FlagRequirement requirement) { + if (flag.specified()) return flag.value(); + if (requirement == FlagRequirement::kUseDefault) return flag.default_value(); + return absl::optional(); +} + +} // namespace + void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, TocoFlags* toco_flags) { namespace port = toco::port; port::CheckInitGoogleIsDone("InitGoogle is not done yet"); - enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified }; - -#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \ - do { \ - if (requirement == FlagRequirement::kMustBeSpecified) { \ - QCHECK(parsed_toco_flags.name.specified()) \ - << "Missing required flag: " << #name; \ - } \ - if (requirement == FlagRequirement::kMustNotBeSpecified) { \ - QCHECK(!parsed_toco_flags.name.specified()) \ - << "Given other flags, this flag should not have been specified: " \ - << #name; \ - } \ - } while (false) -#define READ_TOCO_FLAG(name, requirement) \ - ENFORCE_FLAG_REQUIREMENT(name, requirement); \ - do { \ - if (parsed_toco_flags.name.specified()) { \ - toco_flags->set_##name(parsed_toco_flags.name.value()); \ - } \ +#define READ_TOCO_FLAG(name, requirement) \ + do { \ + EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ + auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ + if (flag_value.has_value()) { \ + toco_flags->set_##name(flag_value.value()); \ + } \ } while (false) -#define PARSE_TOCO_FLAG(Type, name, requirement) \ - ENFORCE_FLAG_REQUIREMENT(name, requirement); \ - do { \ - if (parsed_toco_flags.name.specified()) { \ - Type x; \ - QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \ - << "Unrecognized " << #Type << " value " \ - << parsed_toco_flags.name.value(); \ - toco_flags->set_##name(x); \ - } \ +#define PARSE_TOCO_FLAG(Type, name, requirement) \ + do { \ + EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ + auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ + if (flag_value.has_value()) { \ + Type x; \ + QCHECK(Type##_Parse(flag_value.value(), &x)) \ + << "Unrecognized " << #Type << " value " \ + << parsed_toco_flags.name.value(); \ + toco_flags->set_##name(x); \ + } \ } while (false) - PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified); - PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified); + PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault); + PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault); PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone); PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone); READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..91a742b9e0d3c7ba5b5b955a3da27d7bf3d48871 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model.cc @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/strings/numbers.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { +namespace { + +// Loads a SavedModel from the directory specified in parsed_toco_flags. +// Returns a SavedModelBundle with the requested MetaGraphDef. +const tensorflow::SavedModelBundle* LoadSavedModel( + const ParsedTocoFlags& parsed_toco_flags) { + const string model_path = parsed_toco_flags.savedmodel_directory.value(); + QCHECK(tensorflow::MaybeSavedModelDirectory(model_path)) + << "Model is not saved in the supported SavedModel format.\n"; + + // Gets the tags identifying the MetaGraphDef from the command line arguments. + QCHECK(parsed_toco_flags.savedmodel_tagset.specified()) + << "Missing required flag --savedmodel_tagset.\n"; + const string tags_str = parsed_toco_flags.savedmodel_tagset.value(); + auto tags = absl::StrSplit(tags_str, ','); + + // Loads MetaGraphDef. + auto* bundle = new tensorflow::SavedModelBundle; + TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(), + tensorflow::RunOptions(), model_path, + tags, bundle)) + << "Failed to load exported model from " << model_path + << ". Ensure the model contains the required tags '" << tags_str + << "'.\n"; + return bundle; +} + +// Returns the array name without the postfix. +// +// e.g. reduces "input:0" to "input". +string GetArrayName(const string& name) { + const std::vector& names = absl::StrSplit(name, ':'); + return names[0]; +} + +// Returns the list of array names without the postfix sorted alphabetically. +std::set GetSortedNames(const std::unordered_set& names) { + std::vector final_names; + final_names.reserve(names.size()); + for (const auto& name : names) { + final_names.push_back(GetArrayName(name)); + } + return std::set(final_names.begin(), final_names.end()); +} + +// Gets the final shape after replacing the first dimension with batch size, if +// it is undefined (containing the value -1). Returns whether the shape is +// valid. +bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape, + int batch_size, + tensorflow::TensorShapeProto* final_shape) { + for (int idx = 0; idx < shape.dim().size(); ++idx) { + int64 final_dim = shape.dim()[idx].size(); + if (final_dim == -1) { + if (idx > 0) return false; + final_dim = batch_size; + } + final_shape->add_dim()->set_size(final_dim); + } + return true; +} + +// Updates the input arrays in ModelFlags to contain the shape of the array. +void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size, + ModelFlags* model_flags) { + // Build map of input array names to input arrays. + std::unordered_map input_data_map; + for (auto& input : *model_flags->mutable_input_arrays()) { + input_data_map[input.name()] = &input; + } + + // Adds shapes to the input arrays if the shape is valid. + for (const tensorflow::NodeDef& node_def : graph_def.node()) { + if (input_data_map.find(node_def.name()) != input_data_map.end()) { + const auto shape_it = node_def.attr().find("shape"); + if (shape_it != node_def.attr().end()) { + tensorflow::TensorShapeProto final_shape; + bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(), + batch_size, &final_shape); + + if (is_valid) { + auto* shape = input_data_map.at(node_def.name())->mutable_shape(); + QCHECK_EQ(shape->dims_size(), 0) + << "The shape for the input '" << node_def.name() + << "' was previously defined. For clarity please define inputs " + << "via --input_arrays and input_shapes flags.\n"; + for (const auto& dim : final_shape.dim()) { + shape->add_dims(dim.size()); + } + } + } + } + } + + // Checks all input arrays have a shape. + for (auto const& input : model_flags->input_arrays()) { + QCHECK(input.shape().dims_size() > 0) + << "A valid input shape was not found for input '" << input.name() + << "'. Please define via --input_arrays and --input_shapes flags.\n"; + } +} + +} // namespace + +void ParseMetaData(const tensorflow::GraphDef& graph_def, + const std::unordered_set& inputs, + const std::unordered_set& outputs, + const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags) { + if (!parsed_model_flags.input_arrays.specified()) { + const std::set sorted_inputs = GetSortedNames(inputs); + for (const auto& input_name : sorted_inputs) { + model_flags->add_input_arrays()->set_name(input_name); + } + } + + if (!parsed_model_flags.output_arrays.specified()) { + const std::set sorted_outputs = GetSortedNames(outputs); + for (const auto& output_name : sorted_outputs) { + model_flags->add_output_arrays(GetArrayName(output_name)); + } + } + + if (!parsed_model_flags.input_shapes.specified()) { + int batch_size = parsed_model_flags.batch_size.value(); + ProcessInputShapes(graph_def, batch_size, model_flags); + } + + if (!parsed_toco_flags.inference_type.specified()) { + toco_flags->set_inference_type(IODataType::FLOAT); + } +} + +// TODO(nupurgarg): Add top level tests. +void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + // Loads the MetaGraphDef within a SavedModelBundle. + auto bundle = LoadSavedModel(parsed_toco_flags); + + // Converts the MetaGraphDef to frozen GraphDef. + tensorflow::GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs, + &outputs)); + + // Reads the frozen GraphDef into a string. + QCHECK(frozen_graph_def.SerializeToString(graph_def_contents)) + << "Unable to generate serialized GraphDef.\n"; + + // Process inputs and outputs and metadata within GraphDef. + const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def(); + ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags, + parsed_model_flags, toco_flags, model_flags); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h new file mode 100644 index 0000000000000000000000000000000000000000..7a0fabd82d90131a3b2d28c757c08dcb0f9e3988 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ + +#include +#include + +#include "tensorflow/cc/tools/freeze_saved_model.h" +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/types.pb.h" + +namespace toco { + +// Parses metadata into `toco_flags` and `model_flags`. +// +// Stores `inputs` as input_arrays and `outputs` as output_arrays in +// `model_flags`. Infers input_shapes from the GraphDef and stores it in +// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT +// and stores it in `toco_flags`. +void ParseMetaData(const tensorflow::GraphDef& graph_def, + const std::unordered_set& inputs, + const std::unordered_set& outputs, + const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags); + +// Generates a frozen graph from the SavedModel in the directory specified in +// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses +// metadata relating to the GraphDef into `toco_flags` and `model_flags`. +void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents); + +} // namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e122afe65dc29abc85f142f4019aae5058ace51 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc @@ -0,0 +1,274 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" +#include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +#include +#include + +namespace toco { +namespace { + +using tensorflow::ops::Add; +using tensorflow::ops::Const; +using tensorflow::ops::FakeQuantWithMinMaxArgs; +using tensorflow::ops::Placeholder; + +class TocoSavedModelTest : public ::testing::Test { + protected: + // Calls functions to process cmdline arguments and calls ParseMetaData. + // ParseMetaData parses input_arrays, output_arrays, and gets metadata from + // SavedModel it is not defined in the cmdline arguments. + void ProcessGraphDefMetadata(const std::unordered_set& inputs, + const std::unordered_set& outputs, + const tensorflow::GraphDef& graph_def) { + ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_); + ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_); + ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_, + parsed_model_flags_, &toco_flags_, &model_flags_); + } + + // Gets the GraphDef from the SavedModelBundle and processes metadata. + void ProcessSavedModelMetadata(const std::unordered_set& inputs, + const std::unordered_set& outputs) { + const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def(); + ProcessGraphDefMetadata(inputs, outputs, graph_def); + } + + // Returns a GraphDef representing a simple float model with a single input. + tensorflow::GraphDef GetFloatGraphDef(const std::vector& shape) { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output input = + Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::PartialTensorShape(shape))); + tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); + tensorflow::Output add = Add(scope.WithOpName("add"), input, zero); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Returns a GraphDef representing a simple float model with two inputs. + tensorflow::GraphDef GetComplexFloatGraphDef() { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output inputA = + Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output inputB = + Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Returns a GraphDef representing a simple quantized model. + tensorflow::GraphDef GetQuantizedGraphDef() { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output input = + Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); + tensorflow::Output fake_quant = + FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero); + tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Gets the values in the input_arrays flag. + std::vector GetInputArrays() { + std::vector actual; + for (const auto& input : model_flags_.input_arrays()) { + actual.push_back(input.name()); + } + return actual; + } + + // Gets the values in the output_arrays flag. + std::vector GetOutputArrays() { + std::vector actual(model_flags_.output_arrays().begin(), + model_flags_.output_arrays().end()); + return actual; + } + + // Gets the shape of the given input array. + string GetInputShape(const string& input_array) { + for (const auto& input : model_flags_.input_arrays()) { + if (input.name() == input_array) { + std::vector dims; + for (int idx = 0; idx < input.shape().dims_size(); ++idx) { + dims.push_back(std::to_string(input.shape().dims(idx))); + } + return absl::StrJoin(dims, ","); + } + } + return ""; + } + + tensorflow::SavedModelBundle bundle_; + ParsedTocoFlags parsed_toco_flags_; + ParsedModelFlags parsed_model_flags_; + TocoFlags toco_flags_; + ModelFlags model_flags_; +}; + +// Tests if input_arrays, output_arrays, inference_type, and output_arrays are +// added to ModelFlags if they are not specified in cmdline arguments. +// Tests if the default batch size replaces a -1 in the first dimension. +TEST_F(TocoSavedModelTest, NoCmdLine) { + tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); + EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the order of input_arrays and output_arrays is deterministic when +// they are taken from the SavedModel. +TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) { + tensorflow::GraphDef graph_def = GetComplexFloatGraphDef(); + + // Note: The model does not have two outputs. However, the function does not + // need an accurate output_array list. This is only meant to test order. + ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"add", "invalid"})); + EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1"); + EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if input_shapes is inferred when input_arrays is passed in via cmdline +// arguments. +TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) { + parsed_model_flags_.input_arrays.bind()("input"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1}); + + ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); + EXPECT_EQ(GetInputShape("input"), "2,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Ensures a failure occurs when input_shapes is defined without input_arrays. +TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) { + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); + + EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), + "failed: input_shapes.size\\(\\) == " + "model_flags->input_arrays_size\\(\\)"); +} + +// Tests if the cmdline values of input_arrays, input_shapes are used when +// specified with an empty GraphDef. +TEST_F(TocoSavedModelTest, InputArraysCmdLine) { + parsed_model_flags_.input_arrays.bind()("inputA,inputB"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + + ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); + EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"output0", "output1"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(GetInputShape("inputB"), "9,12"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the cmdline values of input_arrays, input_shapes are used when +// specified even if values exist within the GraphDef. +TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) { + parsed_model_flags_.input_arrays.bind()("inputA"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); + + ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector({"inputA"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the cmdline values of input_arrays, input_shapes, inference_type, +// and output_arrays are used when specified with an empty GraphDef. +TEST_F(TocoSavedModelTest, AllParamsCmdLine) { + parsed_model_flags_.input_arrays.bind()("inputA,inputB"); + parsed_model_flags_.output_arrays.bind()("outputA,outputB"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + parsed_toco_flags_.inference_type.bind()("FLOAT"); + + ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); + EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"outputA", "outputB"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(GetInputShape("inputB"), "9,12"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if a quantized graph gives the correct values assuming type is passed +// in via command line. +TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) { + parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8"); + tensorflow::GraphDef graph_def = GetQuantizedGraphDef(); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); + EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8); +} + +// Tests if the provided batch size replaces a -1 in the first dimension of +// input shape. +TEST_F(TocoSavedModelTest, MissingShapeParameterValid) { + parsed_model_flags_.batch_size.bind()(3); + tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); + EXPECT_EQ(GetInputShape("input"), "3,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Ensures a failure occurs if there is a -1 in a dimension aside from the first +// position of input shape. +TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) { + parsed_model_flags_.batch_size.bind()(3); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1}); + + EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), + "A valid input shape was not found for input 'input'."); +} + +} // namespace +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index a09a3c4ef56edc6ba7fd19eb1ff45a2e41cf3dd2..30dd6fab9ebbad9c2add7f830f9b58a73f41714b 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -52,12 +52,14 @@ void MakeGeneralGraphTransformationsSet( GraphTransformationsSet* transformations) { CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); + transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialStackToReshape); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); transformations->Add(new ResolveTransposeAttributes); + transformations->Add(new PropagateActivationFunctionIntoConstants); transformations->Add(new PropagateArrayDataTypes); transformations->Add(new PropagateFixedSizes); transformations->Add(new RemoveTensorFlowAssert); @@ -76,6 +78,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); + transformations->Add(new ResolveConstantGather); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); @@ -91,6 +94,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new IdentifyL2Normalization); transformations->Add(new IdentifyL2Pool); transformations->Add(new IdentifyRelu1); + transformations->Add(new IdentifyPRelu); transformations->Add(new RemoveTrivialBinaryOperator); transformations->Add(new ReadFakeQuantMinMax); transformations->Add(new ResolveSpaceToBatchNDAttributes); @@ -102,6 +106,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new ResolveConstantFakeQuant); + transformations->Add(new UnpartitionEmbeddingLookup); } bool SupportsQuantization(FileFormat format) { @@ -285,6 +290,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) { EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model); } + // Fix any issues with IO edges. This must happen after any transform that + // may modify the structure of the edges. + FixEdgeArrays(model); + LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 9e725822383b06985bbb5cffdc19a759bc6d5cf3..f3f50487ff74904bf3708fa4c86f522997b55ca0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -84,6 +84,8 @@ string ArrayDataTypeName(ArrayDataType data_type) { return "Uint64"; case ArrayDataType::kString: return "String"; + case ArrayDataType::kBool: + return "Bool"; case ArrayDataType::kNone: return "None"; default: @@ -157,6 +159,15 @@ bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) { return false; } +void DeleteOpAndArraysIfUnused(Model* model, Operator* op) { + for (const string& array_name : op->inputs) { + DeleteArrayIfUsedOnce(array_name, model); + } + auto op_it = FindOp(*model, op); + CHECK(op_it != model->operators.end()); + model->operators.erase(op_it); +} + std::vector>::const_iterator FindOpWithOutput( const Model& model, const string& array_name) { for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { @@ -289,6 +300,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Relu) HANDLE_OPERATORTYPENAME_CASE(Relu1) HANDLE_OPERATORTYPENAME_CASE(Relu6) + HANDLE_OPERATORTYPENAME_CASE(PRelu) HANDLE_OPERATORTYPENAME_CASE(ReorderAxes) HANDLE_OPERATORTYPENAME_CASE(Softmax) HANDLE_OPERATORTYPENAME_CASE(LogSoftmax) @@ -345,6 +357,8 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(TopK_V2) HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) + HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) + HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -809,9 +823,15 @@ void CheckEachArray(const Model& model) { // It's OK to have a buffer or an alloc, but not both. // (Since allocs are for transient arrays without a buffer). CHECK(!array->buffer || !array->alloc); - // If there is a buffer, its type should be consistent with data_type. if (array->buffer) { + // If there is a buffer, its type should be consistent with data_type. CHECK(array->buffer->type == array->data_type); + // The presence of a fixed buffer should imply the presence of a fixed + // shape. + CHECK(array->has_shape()); + // The shape flat-size should agree with the buffer length. + CHECK_EQ(array->buffer->Length(), + RequiredBufferSizeForShape(array->shape())); } // Check name. Either "name_with_suffix_8", "name_with_port:3", but not @@ -1028,6 +1048,117 @@ void CheckModelCounts(const Model& model) { } } +void FixEdgeArrays(Model* model) { + for (const string& output_array_name : model->flags.output_arrays()) { + if (!GetOpWithOutput(*model, output_array_name)) { + // Output has no operator producing it. Change that by inserting a copy. + LOG(WARNING) << "Fixing constant output array " << output_array_name + << " by inserting a copy. This is not optimal."; + string intermediate_array_name = + AvailableArrayName(*model, output_array_name + "_copy"); + CloneArray(model, output_array_name, intermediate_array_name); + InsertCopyOperator(model, intermediate_array_name, output_array_name); + } + } +} + +void InsertCopyOperator(Model* model, const string& source_array_name, + const string& target_array_name) { + // Drop constant data from the target array as the copy will be done at + // runtime. + Array& target_array = model->GetOrCreateArray(target_array_name); + target_array.buffer.reset(); + + // Reshape to the same size. This should be a no-op. + const Array& source_array = model->GetArray(source_array_name); + std::vector shape = source_array.shape().dims(); + + // Insert copy operator. + auto* copy_op = new TensorFlowReshapeOperator; + copy_op->inputs = { + source_array_name, + CreateInt32Array(model, target_array_name + "_copy_shape", shape)}; + copy_op->outputs = {target_array_name}; + model->operators.emplace_back(copy_op); +} + +namespace { +template +void CopyArrayBuffer(const Array& source_array, Array* target_array) { + if (source_array.buffer) { + const auto& source_buffer = source_array.GetBuffer(); + auto& target_buffer = target_array->GetMutableBuffer(); + target_buffer.data = source_buffer.data; + } +} +} // namespace + +void CloneArray(Model* model, const string& source_array_name, + const string& target_array_name) { + CHECK(!model->HasArray(target_array_name)); + const Array& source_array = model->GetArray(source_array_name); + Array& target_array = model->GetOrCreateArray(target_array_name); + + switch (source_array.data_type) { + case ArrayDataType::kBool: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kFloat: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kInt8: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kUint8: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kInt16: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kUint16: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kInt32: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kUint32: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kInt64: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kUint64: + CopyArrayBuffer(source_array, &target_array); + break; + case ArrayDataType::kString: + CopyArrayBuffer(source_array, &target_array); + break; + default: + LOG(FATAL) << "Unsupported data type: " + << ArrayDataTypeName(source_array.data_type); + return; + } + + if (source_array.minmax) { + const auto& smm = source_array.GetMinMax(); + auto& tmm = target_array.GetOrCreateMinMax(); + tmm.min = smm.min; + tmm.max = smm.max; + } + + if (source_array.quantization_params) { + const auto& sqp = source_array.GetQuantizationParams(); + auto& tqp = target_array.GetOrCreateQuantizationParams(); + tqp.zero_point = sqp.zero_point; + tqp.scale = sqp.scale; + } + + target_array.data_type = source_array.data_type; + target_array.final_data_type = source_array.final_data_type; + + target_array.copy_shape(source_array.shape()); +} + void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, std::vector* out_dims) { CHECK(out_dims->empty()); @@ -1191,7 +1322,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { << "This model does not define output arrays, so a " "--output_arrays flag must be given on the command-line."; - for (const auto& input_array_proto : model->flags.input_arrays()) { + for (auto& input_array_proto : *model->flags.mutable_input_arrays()) { auto& input_array = model->GetOrCreateArray(input_array_proto.name()); if (input_array_proto.has_data_type()) { const ArrayDataType specified_type = @@ -1235,6 +1366,11 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { for (int i = 0; i < input_array_dims.size(); i++) { CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i)); } + } else { + for (int i = 0; i < input_array.shape().dimensions_count(); i++) { + input_array_proto.mutable_shape()->add_dims( + input_array.shape().dims(i)); + } } } @@ -1330,6 +1466,8 @@ void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, int ElementSize(ArrayDataType data_type) { switch (data_type) { + case ArrayDataType::kBool: + return sizeof(bool); case ArrayDataType::kFloat: return 4; case ArrayDataType::kInt8: @@ -1355,7 +1493,7 @@ int ElementSize(ArrayDataType data_type) { LOG(FATAL) << "Transient arrays with strings are not supported yet"; return 0; default: - LOG(FATAL) << "Should not get here."; + LOG(FATAL) << "Unknown data_type = " << static_cast(data_type); return 0; } } @@ -1785,7 +1923,10 @@ bool IsDiscardableArray(const Model& model, const string& array_name) { void CheckFinalDataTypesSatisfied(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = *array_entry.second; - if (array.final_data_type != ArrayDataType::kNone) { + // If the final data type is int16, the data type may be float, for example + // after dequantization. + if (array.final_data_type != ArrayDataType::kNone && + array.final_data_type != ArrayDataType::kInt16) { CHECK(array.final_data_type == array.data_type) << "Array \"" << array_entry.first << "\" has mis-matching actual and final data types (" @@ -1831,9 +1972,9 @@ void FinishBuildingRNNStates(Model* model) { void UseArraysExtraInfo(Model* model) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { - QCHECK(model->HasArray(entry.name())) - << "ArraysExtraInfo refers to non-existent array name: " - << entry.name(); + if (!model->HasArray(entry.name())) { + continue; + } auto& array = model->GetArray(entry.name()); auto& minmax = array.GetOrCreateMinMax(); if (entry.has_min() || entry.has_max()) { @@ -1845,6 +1986,24 @@ void UseArraysExtraInfo(Model* model) { array.final_data_type = ConvertIODataTypeToArrayDataType(entry.data_type()); } + if (entry.has_shape()) { + array.clear_shape(); + // Make sure to create the shape even if there are no dims, to + // correctly record 0-D shapes. + array.mutable_shape(); + for (int dim : entry.shape().dims()) { + array.mutable_shape()->mutable_dims()->push_back(dim); + } + } + if (entry.has_constant_float_value()) { + CHECK(array.has_shape()); + CHECK(array.data_type == ArrayDataType::kFloat); + auto& data = array.GetMutableBuffer().data; + data.resize(RequiredBufferSizeForShape(array.shape())); + for (float& f : data) { + f = entry.constant_float_value(); + } + } } } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 11208ed667212d56f9ef45e4f394e0bbf5000cbc..d3b7224fe3a773e389ad8fc9a40f0a0fad4debe5 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -28,6 +28,7 @@ limitations under the License. #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/src/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" @@ -64,6 +65,10 @@ int CountOpsWithInput(const Model& model, const string& array_name); bool DeleteArrayIfUnused(const string& array_name, Model* model); bool DeleteArrayIfUsedOnce(const string& array_name, Model* model); +// Deletes the op and any of its input and output arrays if they are unused +// after the op has been deleted. +void DeleteOpAndArraysIfUnused(Model* model, Operator* op); + std::vector>::const_iterator FindOpWithOutput( const Model& model, const string& array_name); Operator* GetOpWithOutput(const Model& model, const string& array_name); @@ -71,8 +76,6 @@ Operator* GetOpWithOutput(const Model& model, const string& array_name); std::vector>::iterator FindOpWithOutput( Model& model, const string& array_name); -Operator* GetOpWithOutput(const Model& model, const string& array_name); - std::vector>::const_iterator FindOpWithInput( const Model& model, const string& array_name); @@ -141,78 +144,29 @@ void FixOperatorOrdering(Model* model); void FixNoMissingArray(Model* model); void FixNoOrphanedArray(Model* model); +// Fixes input/output arrays that may have issues during export or inference. +void FixEdgeArrays(Model* model); + +// Inserts a no-op reshape operator between the source array and the target +// array. This effectively just copies the data. +void InsertCopyOperator(Model* model, const string& source_array_name, + const string& target_array_name); + +// Clones an array with all data and parameters. +void CloneArray(Model* model, const string& source_array_name, + const string& target_array_name); + void ResolveModelFlags(const ModelFlags& model_flags, Model* model); template -void GetQuantizationParamsFromMinMax(const ModelFlags& model_flags, - const MinMax& minmax, +void GetQuantizationParamsFromMinMax(const MinMax& minmax, QuantizationParams* quantization_params) { using Integer = DataType; - const Integer qmin = std::numeric_limits::min(); - const Integer qmax = std::numeric_limits::max(); - const double qmin_double = qmin; - const double qmax_double = qmax; const double rmin = minmax.min; const double rmax = minmax.max; - // 0 should always be a representable value. Let's assume that the initial - // min,max range contains 0. - CHECK_LE(rmin, 0.); - CHECK_GE(rmax, 0.); - if (rmin == rmax) { - // Special case where the min,max range is a point. Should be {0}. - CHECK_EQ(rmin, 0.); - CHECK_EQ(rmax, 0.); - quantization_params->zero_point = 0; - quantization_params->scale = 0.; - return; - } - // General case. - // - // First determine the scale. - const double scale = (rmax - rmin) / (qmax_double - qmin_double); - - // Zero-point computation. - // First the initial floating-point computation. The zero-point can be - // determined from solving an affine equation for any known pair - // (real value, corresponding quantized value). - // We know two such pairs: (rmin, qmin) and (rmax, qmax). - // The arithmetic error on the zero point computed from either pair - // will be roughly machine_epsilon * (sum of absolute values of terms) - // so we want to use the variant that adds the smaller terms. - const double zero_point_from_min = qmin_double - rmin / scale; - const double zero_point_from_max = qmax_double - rmax / scale; - const double zero_point_from_min_error = - std::abs(qmin_double) + std::abs(rmin / scale); - const double zero_point_from_max_error = - std::abs(qmax_double) + std::abs(rmax / scale); - - const double zero_point_double = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Now we need to nudge the zero point to be an integer - // (our zero points are integer, and this is motivated by the requirement - // to be able to represent the real value "0" exactly as a quantized value, - // which is required in multiple places, for example in Im2col with SAME - // padding). - Integer nudged_zero_point = 0; - if (zero_point_double < qmin_double) { - nudged_zero_point = qmin; - } else if (zero_point_double > qmax_double) { - nudged_zero_point = qmax; - } else { - nudged_zero_point = static_cast(std::round(zero_point_double)); - } - // The zero point should always be in the range of quantized value, - // [qmin, qmax]. - CHECK_GE(nudged_zero_point, qmin); - CHECK_LE(nudged_zero_point, qmax); - - // Finally, store the result nudged quantization params. - quantization_params->zero_point = nudged_zero_point; - quantization_params->scale = scale; + *quantization_params = + ::tflite::ChooseQuantizationParams(rmin, rmax); } void CheckIsReadyForQuantization(const Model& model); diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 999ccf2ebc009b6b7c50a9a2d1667d69a3f690e7..b5abbc0712599814e078d19bc015bc7bf1812f95 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -4,6 +4,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") py_binary( @@ -45,7 +46,15 @@ tf_cc_binary( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + }), ) cc_library( @@ -111,6 +120,9 @@ cc_test( name = "verifier_test", size = "small", srcs = ["verifier_test.cc"], + tags = [ + "tflite_not_portable", + ], deps = [ ":mutable_op_resolver", ":verifier", @@ -124,3 +136,5 @@ cc_test( "@flatbuffers", ], ) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc index 6ae3ab57294a92162b15f326630ac202a9ba2a82..93c80e0f5e021f76bff6858b0ea3370724393d6d 100644 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark_model.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,36 +25,89 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" #ifdef TFLITE_CUSTOM_OPS_HEADER void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); #endif -#define LOG(x) std::cerr +namespace tflite { -#define CHECK(x) \ - if (!(x)) { \ - LOG(ERROR) << #x << "failed"; \ - exit(1); \ +using ::tensorflow::Env; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsFloats; +using ::tensorflow::str_util::SplitAndParseAsInts; + +struct InputLayerInfo { + string name; + TfLiteType data_type; + std::vector shape; + // Note that initialization_values is currently unused. + std::vector initialization_values; +}; + +template +void FillRandomValue(T* ptr, const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + *ptr++ = random_func(); } +} -namespace tensorflow { -namespace benchmark_tflite_model { +void FillRandomString(tflite::DynamicBuffer* buffer, + const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} -std::unique_ptr model; -std::unique_ptr interpreter; +TfLiteType TfLiteTypeFromString(const string& input_layer_type) { + if (input_layer_type == "string") + return kTfLiteString; + else if (input_layer_type == "float") + return kTfLiteFloat32; + else if (input_layer_type == "uint8") + return kTfLiteUInt8; + else if (input_layer_type == "int32") + return kTfLiteInt32; + else if (input_layer_type == "int64") + return kTfLiteInt64; + else + return kTfLiteNoType; +} -void InitImpl(const std::string& graph, const std::vector& sizes, - const std::string& input_layer_type, int num_threads) { - CHECK(graph.c_str()); +std::vector ShapeFromTfLiteTensor(TfLiteTensor* t) { + std::vector result; + result.reserve(t->dims->size); + for (int i = 0; i < t->dims->size; ++i) { + result.push_back(t->dims->data[i]); + } + CHECK(!result.empty()) << "Found no shapes in model"; + return result; +} - model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); +bool CreateInterpreter(const string& graph, + std::unique_ptr* model, + std::unique_ptr* interpreter) { + *model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); if (!model) { - LOG(FATAL) << "Failed to mmap model " << graph; + std::cerr << "Failed to load model " << graph << std::endl; + return false; } - LOG(INFO) << "Loaded model " << graph; - model->error_reporter(); - LOG(INFO) << "resolved reporter"; #ifdef TFLITE_CUSTOM_OPS_HEADER tflite::MutableOpResolver resolver; @@ -63,34 +116,360 @@ void InitImpl(const std::string& graph, const std::vector& sizes, tflite::ops::builtin::BuiltinOpResolver resolver; #endif - tflite::InterpreterBuilder(*model, resolver)(&interpreter); - if (!interpreter) { - LOG(FATAL) << "Failed to construct interpreter"; + tflite::InterpreterBuilder(*(model->get()), resolver)(interpreter); + if (!(*interpreter)) { + std::cerr << "Failed to construct interpreter" << std::endl; + return false; } + return true; +} + +bool PrepareInterpreter(const std::vector inputs, + int num_threads, bool use_nnapi, + Interpreter* interpreter) { if (num_threads != -1) { interpreter->SetNumThreads(num_threads); } - int input = interpreter->inputs()[0]; + interpreter->UseNNAPI(use_nnapi); - if (input_layer_type != "string") { - interpreter->ResizeInputTensor(input, sizes); + // Check that all names and types match + for (const InputLayerInfo& input : inputs) { + for (int i : interpreter->inputs()) { + TfLiteTensor* t = interpreter->tensor(i); + CHECK_EQ(t->name, input.name) + << "Tensor # " << i << " is named " << t->name + << " but flags call it " << input.name; + CHECK_EQ(t->type, input.data_type) + << "Could not match the type of input tensor " << t->name; + } + } + + // Resize all non-string tensors. + for (const InputLayerInfo& input : inputs) { + for (int i : interpreter->inputs()) { + TfLiteTensor* t = interpreter->tensor(i); + if (t->type != kTfLiteString) { + interpreter->ResizeInputTensor(i, input.shape); + } + } } if (interpreter->AllocateTensors() != kTfLiteOk) { - LOG(FATAL) << "Failed to allocate tensors!"; + std::cerr << "Failed to allocate tensors!" << std::endl; + return false; + } + + // Set the values of the input tensors. + for (int i : interpreter->inputs()) { + TfLiteTensor* t = interpreter->tensor(i); + std::vector sizes = ShapeFromTfLiteTensor(t); + + // TODO(ahentz): below we ignore the O-th dimension (number of batches). + if (t->type == kTfLiteFloat32) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); + } else if (t->type == kTfLiteUInt8) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) % 255; }); + } else if (t->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, sizes, []() { + return "we're have some friends over saturday to hang out in the yard"; + }); + buffer.WriteToTensor(interpreter->tensor(i)); + } else { + std::cerr << "Don't know how to populate tensor " << t->name + << " of type " << t->type << std::endl; + return false; + } + } + return true; +} + +bool PopulateInputLayerInfo(const string& names_string, + const string& shapes_string, + const string& types_string, + const string& values_string, + std::vector* info) { + std::vector names = Split(names_string, ','); + std::vector shapes = Split(shapes_string, ':'); + std::vector types = Split(types_string, ','); + std::vector values = Split(values_string, ':'); + + if (names.size() != shapes.size()) { + LOG(ERROR) << "The number of items in" + << " --input_layer_shape (" << shapes_string << ", with " + << shapes.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)." + << " For example --input_layer=input1,input2" + << " --input_layer_shape=1,224,224,4:1,20"; + return false; + } + if (names.size() != types.size()) { + LOG(ERROR) << "The number of items in" + << " --input_layer_type (" << types_string << ", with " + << types.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)." + << " For example --input_layer=input1,input2" + << " --input_layer_type=float,int"; + return false; + } + + for (int i = 0; i < names.size(); ++i) { + info->push_back(InputLayerInfo()); + InputLayerInfo& input = info->back(); + + input.name = names[i]; + + input.data_type = TfLiteTypeFromString(types[i]); + CHECK(input.data_type != kTfLiteNoType) + << types[i] << " was an invalid type"; + + CHECK(SplitAndParseAsInts(shapes[i], ',', &input.shape)) + << "Incorrect size string specified: " << shapes[i]; + for (int dim : input.shape) { + if (dim == -1) { + LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" + << " with the size you want to benchmark with."; + return false; + } + } + + if (i < values.size()) { + CHECK(SplitAndParseAsFloats(values[i], ',', &input.initialization_values)) + << "Incorrect initialization values string specified: " << values[i]; + } + } + + return true; +} + +bool RunBenchmark(Interpreter* interpreter, int64_t* inference_time_us) { + const int64_t start_time = Env::Default()->NowMicros(); + + if (interpreter->Invoke() != kTfLiteOk) { + std::cerr << "Failed to invoke!"; + return false; } + + const int64_t end_time = Env::Default()->NowMicros(); + *inference_time_us = end_time - start_time; + return true; +} + +class Latencies { + public: + void AddMeasurement(int64_t time_us) { + max_ = std::max(time_us, max_); + min_ = std::min(time_us, min_); + ++count_; + sum_ += time_us; + squared_sum_ += static_cast(time_us) * time_us; + } + + double avg() const { + if (count_ == 0) return std::numeric_limits::quiet_NaN(); + return static_cast(sum_) / count_; + } + + int64_t std_deviation() const { + if (count_ == 0 || min_ == max_) return 0; + return sqrt(squared_sum_ / count_ - avg() * avg()); + } + + void OutputToStream(std::ostream* stream) const { + *stream << "count=" << count_; + if (count_ == 0) return; + *stream << " min=" << min_ << " max=" << max_; + *stream << " avg=" << avg() << " std=" << std_deviation(); + } + + private: + int64_t count_ = 0; + int64_t min_ = std::numeric_limits::max(); + int64_t max_ = std::numeric_limits::min(); + int64_t sum_ = 0; + double squared_sum_ = 0; +}; + +bool TimeMultipleRuns(Interpreter* interpreter, double sleep_seconds, + int num_runs, int64* total_time_us) { + // Convert the run_delay string into a timespec. + timespec req; + req.tv_sec = static_cast(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + + *total_time_us = 0; + + std::cout << "Running benchmark for " << num_runs + << " iterations: " << std::endl; + + Latencies latencies; + for (int i = 0; i < num_runs; ++i) { + int64_t time_us; + bool run_status = RunBenchmark(interpreter, &time_us); + latencies.AddMeasurement(time_us); + *total_time_us += time_us; + if (!run_status) { + std::cout << "Failed on run " << i << std::endl; + return false; + } + + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. + if (sleep_seconds > 0.0) { +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else + nanosleep(&req, nullptr); +#endif + } + } + latencies.OutputToStream(&std::cout); + std::cout << std::endl; + + return true; } int Main(int argc, char** argv) { - InitImpl("", {}, "", 1); + using tensorflow::Flag; + using tensorflow::Flags; + + string graph; // e.g.: /data/local/tmp/tfl_inception-v1_model.fb + string input_layer_string; // e.g.: input + string input_layer_shape_string; // e.g.: 1,224,224,3 + string input_layer_type_string; // e.g.: float + string input_layer_values_string; + string output_layer_string; // e.g.: output + int num_runs = 50; + string run_delay = "-1.0"; + int num_threads = -1; + string benchmark_name = ""; + string output_prefix = ""; + int warmup_runs = 1; + bool use_nnapi = false; + + std::vector flag_list = { + Flag("graph", &graph, "graph file name"), + // All the following flags are optional, but can be used in order + // to benchmark different input shapes. + Flag("input_layer", &input_layer_string, "input layer names"), + Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), + Flag("input_layer_type", &input_layer_type_string, "input layer type"), + Flag("input_layer_values", &input_layer_values_string, + "values to initialize the inputs with"), + Flag("output_layer", &output_layer_string, "output layer name"), + Flag("num_runs", &num_runs, "number of runs"), + Flag("run_delay", &run_delay, "delay between runs in seconds"), + Flag("num_threads", &num_threads, "number of threads"), + Flag("benchmark_name", &benchmark_name, "benchmark name"), + Flag("output_prefix", &output_prefix, "benchmark output prefix"), + Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), + Flag("use_nnapi", &use_nnapi, "use nnapi api"), + }; + string usage = Flags::Usage(argv[0], flag_list); + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(argv[0], &argc, &argv); + + if (!parse_result) { + std::cerr << usage << std::endl; + return -1; + } + + std::cout << "Graph: [" << graph << "]" << std::endl; + if (!input_layer_string.empty()) { + std::cout << "Input layers: [" << input_layer_string << "]" << std::endl; + std::cout << "Input shapes: [" << input_layer_shape_string << "]" + << std::endl; + std::cout << "Input types: [" << input_layer_type_string << "]" + << std::endl; + } + if (!output_layer_string.empty()) { + std::cout << "Output layers: [" << output_layer_string << "]" << std::endl; + } + std::cout << "Num runs: [" << num_runs << "]" << std::endl; + std::cout << "Inter-run delay (seconds): [" << run_delay << "]" << std::endl; + std::cout << "Num threads: [" << num_threads << "]" << std::endl; + if (!benchmark_name.empty()) { + std::cout << "Benchmark name: [" << benchmark_name << "]" << std::endl; + std::cout << "Output prefix: [" << output_prefix << "]" << std::endl; + } + std::cout << "Warmup runs: [" << warmup_runs << "]" << std::endl; + std::cout << "Use nnapi : [" << use_nnapi << "]" << std::endl; + + if (graph.empty()) { + std::cout + << "Please specify the name of your TF Lite input file with --graph" + << std::endl; + return -1; + } + + std::vector inputs; + if (!PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, + input_layer_type_string, + input_layer_values_string, &inputs)) { + return -1; + } + + int64 initialization_start_us = Env::Default()->NowMicros(); + + std::unique_ptr model; + std::unique_ptr interpreter; + if (!CreateInterpreter(graph, &model, &interpreter)) { + return -1; + } + if (!PrepareInterpreter(inputs, num_threads, use_nnapi, interpreter.get())) { + return -1; + } + + int64 initialization_end_us = Env::Default()->NowMicros(); + + const double initialization_time_s = + (initialization_end_us - initialization_start_us) / 1000000.0f; + std::cout << "Initialized session in " << initialization_time_s << "s" + << std::endl; + + const double sleep_seconds = std::strtod(run_delay.c_str(), nullptr); + + // If requested, run through the graph first to preinitialize everything + // before the benchmarking runs. + int64 warmup_time_us = 0; + if (warmup_runs > 0) { + if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, warmup_runs, + &warmup_time_us)) { + std::cerr << "Warmup failed" << std::endl; + return -1; + } + } + + // Capture overall inference time without stat logging overhead. This is the + // timing data that can be compared to other libaries. + int64 no_stat_time_us = 0; + if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, num_runs, + &no_stat_time_us)) { + std::cerr << "Timing failed." << std::endl; + return -1; + } + + std::cout << "Average inference timings in us: " << no_stat_time_us / num_runs + << " , Warmup: " + << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", " + << std::endl; + return 0; } -} // namespace benchmark_tflite_model -} // namespace tensorflow +} // namespace tflite -int main(int argc, char** argv) { - return tensorflow::benchmark_tflite_model::Main(argc, argv); -} +int main(int argc, char** argv) { return ::tflite::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc index 59c74205f0a311ec12ff87f46622041605fb493b..8818a7dc85d9ffdc1da450fb389d5ed11139bc31 100644 --- a/tensorflow/contrib/lite/tools/verifier.cc +++ b/tensorflow/contrib/lite/tools/verifier.cc @@ -148,11 +148,52 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, // TODO(yichengfan): verify quantized tensors. } +using flatbuffers::Offset; +using flatbuffers::Vector; + +bool VerifyOperators(const Vector>& operators, + ErrorReporter* error_reporter) { + for (const auto& op : operators) { + if (!op->inputs()) { + ReportError(error_reporter, "Missing 'inputs' for operator."); + return false; + } + if (!op->outputs()) { + ReportError(error_reporter, "Missing 'outputs' for operator."); + return false; + } + } + return true; +} + +bool VerifySubGraphs(const Model& model, ErrorReporter* error_reporter) { + if (!model.subgraphs()) { + ReportError(error_reporter, "Missing 'subgraphs' section."); + return false; + } + for (const auto& subgraph : *model.subgraphs()) { + if (!subgraph->operators()) { + ReportError(error_reporter, "Missing 'operators' section in subgraph."); + return false; + } + + if (!VerifyOperators(*subgraph->operators(), error_reporter)) { + return false; + } + } + return true; +} + // Verifies tensors have valid properties and legit buffer if set. bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { if (!model.subgraphs()) { return true; } + if (!model.buffers()) { + ReportError(error_reporter, "Missing 'buffers' section."); + return false; + } + for (const auto& subgraph : *model.subgraphs()) { if (!subgraph->tensors()) { continue; @@ -167,19 +208,23 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { return false; } auto* buffer = model.buffers()->Get(tensor->buffer()); - if (!buffer || !buffer->data()) { + if (!buffer) { ReportError(error_reporter, "Tensor buffer %d not set", tensor->buffer()); return false; } - if (tensor->type() == TensorType_STRING) { - if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { - return false; - } - } else { - if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { - return false; + // Many transient tensors don't have data in the flatbuffer. Their + // buffers will be allocated by the interpreter at run-time. + if (buffer->data()) { + if (tensor->type() == TensorType_STRING) { + if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { + return false; + } + } else { + if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { + return false; + } } } } @@ -193,6 +238,13 @@ bool VerifyOps(const Model& model, const OpResolver& resolver, return true; } for (const auto& opcode : *model.operator_codes()) { + if (opcode->builtin_code() < BuiltinOperator_MIN || + opcode->builtin_code() > BuiltinOperator_MAX) { + ReportError(error_reporter, "Operator id '%d' is out of range.", + opcode->builtin_code()); + return false; + } + if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { if (!resolver.FindOp(opcode->custom_code()->c_str())) { ReportError(error_reporter, "Unsupported custom op: %s", @@ -223,6 +275,9 @@ bool Verify(const void* buf, size_t len, const OpResolver& resolver, ReportError(error_reporter, "Invalid model version %d", model->version()); return false; } + if (!VerifySubGraphs(*model, error_reporter)) { + return false; + } if (!VerifyTensors(*model, error_reporter)) { return false; } diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h index c2ee11215c861ed7b27696a8d786bb6e2a48e930..b7ce4e830576af14002d6bd9080af1da5764b1c9 100644 --- a/tensorflow/contrib/lite/tools/verifier.h +++ b/tensorflow/contrib/lite/tools/verifier.h @@ -23,6 +23,21 @@ limitations under the License. namespace tflite { +class AlwaysTrueResolver : public OpResolver { + public: + AlwaysTrueResolver() {} + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr, + nullptr}; + return &null_registration; + } + TfLiteRegistration* FindOp(const char* op) const override { + static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr, + nullptr}; + return &null_registration; + } +}; + // Verifies the integrity of a Tensorflow Lite flatbuffer model file. // Currently, it verifies: // * The file is following a legit flatbuffer schema. diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index b3e611f999b2837efbf8876bd989db44c408b8c7..03b93afe3ed04b4bff13bc01d7c7c8e9fae9bdf3 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -113,8 +113,8 @@ TEST(VerifyModel, TestEmptyModel) { /*description=*/0, /*buffers=*/0); ::tflite::FinishModelBuffer(builder, model); - ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), - MutableOpResolver{}, DefaultErrorReporter())); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), + MutableOpResolver{}, DefaultErrorReporter())); } TEST(VerifyModel, TestSimpleModel) { diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb4af07d060cac3a6a4e01c7d625b6db5241f10d --- /dev/null +++ b/tensorflow/contrib/lite/util.cc @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { + +TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector& input) { + return ConvertArrayToTfLiteIntArray(input.size(), input.data()); +} + +TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims) { + TfLiteIntArray* output = TfLiteIntArrayCreate(rank); + for (size_t i = 0; i < rank; i++) { + output->data[i] = dims[i]; + } + return output; +} + +bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, + const int* b) { + if (!a) return false; + if (a->size != b_size) return false; + for (int i = 0; i < a->size; ++i) { + if (a->data[i] != b[i]) return false; + } + return true; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h new file mode 100644 index 0000000000000000000000000000000000000000..a34db35823104414cce028b9119397da085d05b1 --- /dev/null +++ b/tensorflow/contrib/lite/util.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file provides general C++ utility functions in TFLite. +// For example: Converting between `TfLiteIntArray`, `std::vector` and +// Flatbuffer vectors. These functions can't live in `context.h` since it's pure +// C. + +#ifndef TENSORFLOW_CONTRIB_LITE_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_UTIL_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Converts a `std::vector` to a `TfLiteIntArray`. +TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector& input); + +TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); + +// Checks whether a `TfLiteIntArray` and an int array have matching elements. +bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, + const int* b); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_ diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..04579c53aa4835c47d812c89a1554a0d2f2f30b8 --- /dev/null +++ b/tensorflow/contrib/lite/util_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { +namespace { + +TEST(ConvertVectorToTfLiteIntArray, TestWithVector) { + std::vector input = {1, 2}; + TfLiteIntArray* output = ConvertVectorToTfLiteIntArray(input); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size, 2); + EXPECT_EQ(output->data[0], 1); + EXPECT_EQ(output->data[1], 2); + TfLiteIntArrayFree(output); +} + +TEST(ConvertVectorToTfLiteIntArray, TestWithEmptyVector) { + std::vector input; + TfLiteIntArray* output = ConvertVectorToTfLiteIntArray(input); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size, 0); + TfLiteIntArrayFree(output); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index a430dac4ec43ce31f0b5aaae5e7b0b51d25c9632..a03e731be32c5964cb4aece8e8a67525883a4e7c 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -105,7 +105,7 @@ def index_table_from_tensor(mapping, ... tf.tables_initializer().run() - ids.eval() ==> [0, 1, 4, 2] + ids.eval() ==> [0, 1, 3, 2] ``` Args: @@ -341,23 +341,21 @@ class MutableHashTable(LookupInterface): # training to work correctly. Use the node name if no shared_name has been # explicitly specified. use_node_name_sharing = checkpoint and shared_name is None - # pylint: disable=protected-access if self._default_value.get_shape().ndims == 0: - self._table_ref = gen_lookup_ops._mutable_hash_table_v2( + self._table_ref = gen_lookup_ops.mutable_hash_table_v2( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, value_dtype=value_dtype, name=name) else: - self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors_v2( + self._table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, value_dtype=value_dtype, value_shape=self._default_value.get_shape(), name=name) - # pylint: enable=protected-access super(MutableHashTable, self).__init__(key_dtype, value_dtype, self._table_ref.op.name.split( "/")[-1]) @@ -378,9 +376,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: with ops.colocate_with(self._table_ref): - - # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name) + return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -406,8 +402,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: with ops.colocate_with(self._table_ref): - # pylint: disable=protected-access - values = gen_lookup_ops._lookup_table_find_v2( + values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) values.set_shape(keys.get_shape().concatenate(self._value_shape)) @@ -437,7 +432,7 @@ class MutableHashTable(LookupInterface): [self._table_ref, keys, values]) as name: with ops.colocate_with(self._table_ref): # pylint: disable=protected-access - op = gen_lookup_ops._lookup_table_insert_v2( + op = gen_lookup_ops.lookup_table_insert_v2( self._table_ref, keys, values, name=name) return op @@ -454,8 +449,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: with ops.colocate_with(self._table_ref): - # pylint: disable=protected-access - exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2( + exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( @@ -477,7 +471,7 @@ class MutableHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access with ops.colocate_with(self.op._table_ref): - return gen_lookup_ops._lookup_table_import_v2( + return gen_lookup_ops.lookup_table_import_v2( self.op._table_ref, restored_tensors[0], restored_tensors[1]) @@ -500,7 +494,7 @@ class MutableDenseHashTable(LookupInterface): value_dtype=tf.int64, default_value=-1, empty_key=0) - table.insert(keys, values) + sess.run(table.insert(keys, values)) out = table.lookup(query_keys) print(out.eval()) ``` @@ -551,8 +545,7 @@ class MutableDenseHashTable(LookupInterface): # explicitly specified. use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) - # pylint: disable=protected-access - self._table_ref = gen_lookup_ops._mutable_dense_hash_table_v2( + self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, @@ -560,7 +553,6 @@ class MutableDenseHashTable(LookupInterface): value_shape=self._value_shape, initial_num_buckets=initial_num_buckets, name=name) - # pylint: enable=protected-access super(MutableDenseHashTable, self).__init__( key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1]) @@ -580,8 +572,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: with ops.colocate_with(self._table_ref): - # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=name) + return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -607,8 +598,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: with ops.colocate_with(self._table_ref): - # pylint: disable=protected-access - values = gen_lookup_ops._lookup_table_find_v2( + values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: @@ -640,8 +630,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: with ops.colocate_with(self._table_ref): - # pylint: disable=protected-access - op = gen_lookup_ops._lookup_table_insert_v2( + op = gen_lookup_ops.lookup_table_insert_v2( self._table_ref, keys, values, name=name) return op @@ -658,8 +647,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: with ops.colocate_with(self._table_ref): - # pylint: disable=protected-access - exported_keys, exported_values = gen_lookup_ops._lookup_table_export_v2( + exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( @@ -681,5 +669,5 @@ class MutableDenseHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access with ops.colocate_with(self.op._table_ref): - return gen_lookup_ops._lookup_table_import_v2( + return gen_lookup_ops.lookup_table_import_v2( self.op._table_ref, restored_tensors[0], restored_tensors[1]) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 81327407d44b4317b7aecb964a689a35aa35c163..05e8d9064bea748c935859f5f9b4c7e646f504cf 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -677,6 +677,7 @@ endif # TEGRA TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # Add in any extra files that don't fit the patterns easily TF_CC_SRCS += tensorflow/contrib/makefile/downloads/fft2d/fftsg.c +TF_CC_SRCS += tensorflow/core/common_runtime/gpu/gpu_id_manager.cc # Also include the op and kernel definitions. TF_CC_SRCS += $(shell cat $(MAKEFILE_DIR)/tf_op_files.txt) PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt) diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 995230dfa848532dc2a50b85f58d19ba264f293e..6c3b02e12b3082be8bfcc316c4c6122931eb5f76 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -194,6 +194,8 @@ with: srcs = glob(["libs/arm64-v8a/*.so"]), ``` +If you are building for Android TV (Shield TV devices), replace "portrait" with "landscape" for android:screenOrientation in all four activities in tensorflow/examples/android/AndroidManifest.xml + Then run: ```bash # Create dir for native libs diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index 2d9979183975e6a17527b40ef5ee1795ced44a7b..0a458a27b3ac9b1a24b0f42de2f0166d515e8cd9 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -80,10 +80,9 @@ if [[ ! -z "${OPTIMIZE_FOR_GRAPH}" ]]; then fi else echo "${PRNT_SLCTV_BIN} found. Using it" - ${PRNT_SLCTV_BIN} --graphs=${OPTIMIZE_FOR_GRAPH} > ${TOP_SRCDIR}/tensorflow/core/framework/ops_to_register.h - fi + ${PRNT_SLCTV_BIN} --graphs=${OPTIMIZE_FOR_GRAPH} > ${TOP_SRCDIR}/tensorflow/core/framework/ops_to_register.h fi if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then @@ -111,7 +110,7 @@ if [[ -z "${BUILD_ARCH}" ]]; then TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios` else # arch specified so build just that - TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios -a ${BUILD_ARCH}` + TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios -a "${BUILD_ARCH}"` fi export HOST_NSYNC_LIB TARGET_NSYNC_LIB diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index 7927997678f077a716d81749561068f259d9744f..e8c6edd7ba9aa6a45d956d1d5655b2809d8d2309 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -109,17 +109,18 @@ for arch in $archs; do linux) makefile=' CC=${CC_PREFIX} g++ PLATFORM_CPPFLAGS=-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11 \ + -I../../platform/c++11.futex \ -I../../platform/c++11 -I../../platform/gcc \ -I../../platform/posix -pthread PLATFORM_CFLAGS=-std=c++11 -Werror -Wall -Wextra -pedantic PLATFORM_LDFLAGS=-pthread MKDEP=${CC} -M -std=c++11 - PLATFORM_C=../../platform/c++11/src/nsync_semaphore_mutex.cc \ + PLATFORM_C=../../platform/linux/src/nsync_semaphore_futex.c \ ../../platform/c++11/src/per_thread_waiter.cc \ ../../platform/c++11/src/yield.cc \ ../../platform/c++11/src/time_rep_timespec.cc \ ../../platform/c++11/src/nsync_panic.cc - PLATFORM_OBJS=nsync_semaphore_mutex.o per_thread_waiter.o yield.o \ + PLATFORM_OBJS=nsync_semaphore_futex.o per_thread_waiter.o yield.o \ time_rep_timespec.o nsync_panic.o TEST_PLATFORM_C=../../platform/c++11/src/start_thread.cc TEST_PLATFORM_OBJS=start_thread.o diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 4ae18b2cef28335a90bbc967529c0cf76b0a5da2..8b415e6527f85a5a7844b9d4156fd39ecb1b637a 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -34,7 +34,7 @@ PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/. RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" -CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index d56e388477db6239cfb577f7e2754321ff33bd82..77c936d8c5b99033ff5c5e149a6ce6613b603132 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -17,6 +17,7 @@ tensorflow/core/platform/env_time.cc tensorflow/core/platform/setround.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/default/tracing.cc +tensorflow/core/platform/default/mutex.cc tensorflow/core/platform/default/logging.cc tensorflow/core/platform/cpu_info.cc tensorflow/core/lib/wav/wav_io.cc diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 5a812af4e95fe7a05b9c2634b0cc1d860fb7f619..7a7683c95369aa929d93591e6bf78fd945ce36bc 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -258,6 +258,7 @@ tensorflow/core/kernels/requantize.cc tensorflow/core/kernels/remote_fused_graph_execute_op.cc tensorflow/core/kernels/remote_fused_graph_execute_utils.cc tensorflow/core/kernels/batch_matmul_op_real.cc +tensorflow/core/kernels/random_op.cc tensorflow/core/ops/training_ops.cc tensorflow/core/ops/string_ops.cc tensorflow/core/ops/state_ops.cc diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 31e274c5fd7c670458b1b40a4f58c668a23776c7..81f05e7ce587ed1da67a17efbbeb809dbe7fc0b3 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1263,7 +1263,7 @@ def _compute_placement_auc(labels, predictions, weights, alpha, weights_for_true = ordered_weights * float_labels_for_true weights_for_false = ordered_weights * float_labels_for_false - # For each set of weights with the same segmented indices, we add up the + # For each set of weights with the same segmented indices, we add up the # weight values. Note that for each label, we deliberately rely on weights # for the opposite label. weight_totals_for_true = math_ops.segment_sum(weights_for_false, @@ -3646,8 +3646,8 @@ def cohen_kappa(labels, `updates_collections` are not a list or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): - raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported' + if context.executing_eagerly(): + raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported ' 'when eager execution is enabled.') if num_classes < 2: raise ValueError('`num_classes` must be >= 2.' diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index b387f26c0195432fb972dac450d2919bdaa702a1..33eb655fb660f0ecdfe1c5ab870d7f17690ae3ff 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1802,9 +1802,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.54166603, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.54166603, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1816,9 +1816,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1830,9 +1830,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1865,9 +1865,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.49999976, sess.run(update_op), 6) + self.assertAlmostEqual(1, sess.run(update_op), 6) - self.assertAlmostEqual(0.49999976, auc.eval(), 6) + self.assertAlmostEqual(1, auc.eval(), 6) def testWithMultipleUpdates(self): num_samples = 1000 @@ -6888,8 +6888,7 @@ class CohenKappaTest(test.TestCase): # [[0, 25, 0], # [0, 0, 25], # [25, 0, 0]] - # Calculated by v0.19: sklearn.metrics.cohen_kappa_score( - # labels, predictions) + # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions) expect = -0.333333333333 with self.test_session() as sess: @@ -6948,8 +6947,7 @@ class CohenKappaTest(test.TestCase): weights_t: weights[batch_start:batch_end] }) # Calculated by v0.19: sklearn.metrics.cohen_kappa_score( - # labels_np, predictions_np, - # sample_weight=weights_np) + # labels_np, predictions_np, sample_weight=weights_np) expect = 0.289965397924 self.assertAlmostEqual(expect, kappa.eval(), 5) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index d286750c257e9a78a82c95c1fc872b3ca6972203..52b659c69fdfc507e6259e928d79c65471f2f025 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -134,7 +134,7 @@ $ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once ### Block Sparsity -For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is supported for weight tensors with rank 2 only. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). +For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). The convolution layer tensors are always pruned used block dimensions of [1,1]. ## References diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py index 988748ad75bdf72f1da3f4e1c6e85aabb04a5954..466daf204a1ae86a7f37107342046305ea7249fc 100644 --- a/tensorflow/contrib/model_pruning/python/layers/layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -214,7 +214,7 @@ def masked_convolution(inputs, elif data_format == 'NCHW': df = 'channels_first' else: - raise ValueError('Unsupported data fromat', data_format) + raise ValueError('Unsupported data format', data_format) layer = layer_class( filters=num_outputs, diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index d16af9da19816211ee22f6ea48a347f0b9a4e612..5146a4a2de7806041991c04958de378b2d3dc810 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -216,7 +216,7 @@ def _partitioned_variable_assign(partitioned_var, new_value): """Assign op for partitioned variables. Args: - partitioned_var: A partitioned tensotflow variable + partitioned_var: A partitioned tensorflow variable new_value: Value to be assigned to the variable var Returns: @@ -523,7 +523,8 @@ class Pruning(object): """Performs block-granular masking of the weights. Block pruning occurs only if the block_height or block_width is > 1 and - if the weight tensor has ndims = 2. Otherwise, elementwise pruning occurs. + if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise + pruning occurs. Args: weights: The weight tensor that needs to be masked. threshold: The current threshold value. The function will compute a new @@ -540,7 +541,8 @@ class Pruning(object): Raises: ValueError: if block pooling function is not AVG or MAX """ - if weights.get_shape().ndims != 2 or self._block_dim == [1, 1]: + squeezed_weights = array_ops.squeeze(weights) + if squeezed_weights.get_shape().ndims != 2 or self._block_dim == [1, 1]: return self._update_mask(weights, threshold) if self._block_pooling_function not in ['AVG', 'MAX']: @@ -549,9 +551,11 @@ class Pruning(object): with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs( - array_ops.reshape( - weights, [1, weights.get_shape()[0], - weights.get_shape()[1], 1])) + array_ops.reshape(weights, [ + 1, + squeezed_weights.get_shape()[0], + squeezed_weights.get_shape()[1], 1 + ])) pool_window = [self._block_dim[0], self._block_dim[1]] pooled_weights = nn_ops.pool( abs_weights, @@ -572,9 +576,10 @@ class Pruning(object): array_ops.ones(self._block_dim)) sliced_mask = array_ops.slice( updated_mask, [0, 0], - [weights.get_shape()[0], - weights.get_shape()[1]]) - return smoothed_threshold, sliced_mask + [squeezed_weights.get_shape()[0], + squeezed_weights.get_shape()[1]]) + return smoothed_threshold, array_ops.reshape(sliced_mask, + array_ops.shape(weights)) def _get_mask_assign_ops(self): # Make sure the assignment ops have not already been added to the list diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 1767b4bb94a9bb56bc6a4933423ad27d8cf3ed35..89e65713197afc6ed37346cb67a6e9be3fa9290f 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -140,6 +140,23 @@ class PruningTest(test.TestCase): [0.0, -0.3, 0.0, -0.4]]) expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]] + self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, + expected_mask) + self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg, + expected_mask) + + def testBlockMaskingWithHigherDimensions(self): + param_list = ["block_height=2", "block_width=2", "threshold_decay=0"] + + # Weights as in testBlockMasking, but with one extra dimension. + weights_avg = constant_op.constant( + [[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], + [0.3, 0.3, 0.4, 0.4]]]) + weights_max = constant_op.constant( + [[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0], + [0.0, -0.3, 0.0, -0.4]]]) + expected_mask = [[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]] + self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, expected_mask) self._blockMasking(param_list + ["block_pooling_function=AVG"], diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h index fa297c28cb47d43ba927ab941854bd472d90b465..df055ff56731140b3bd09704c70e65f81362f763 100644 --- a/tensorflow/contrib/mpi/mpi_utils.h +++ b/tensorflow/contrib/mpi/mpi_utils.h @@ -24,6 +24,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" +// Skip MPI C++ bindings support, this matches the usage in other places +#define OMPI_SKIP_MPICXX #include "third_party/mpi/mpi.h" #define MPI_CHECK(cmd) \ do { \ diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 5ac96007df7ee08b1e32aacd28f83768859810a9..94d01efee1546feca89a7e88acedf915b1dfb3a4 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -52,6 +52,7 @@ tf_cuda_cc_test( "manual", "multi_gpu", "no_oss", + "noguitar", "notap", ], deps = @@ -136,6 +137,7 @@ cuda_py_test( "manual", "multi_gpu", "no_oss", + "noguitar", "notap", ], ) diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 8dc038b9ac992de7db8b762e3697c6693099e192..794372a1f4b0dcc41bcf0da611f5bc2ec9301973 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -267,5 +267,5 @@ def _check_device(tensor, expected=None): def _check_graph_mode(): - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError('Nccl ops are not supported in eager mode') diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus.py b/tensorflow/contrib/nn/python/ops/scaled_softplus.py index fcbfbc239ca5b8a1d4b17b403f99b7eb05db47b0..7184ef2b66ec4662af3a37def070ab151d6e7c15 100644 --- a/tensorflow/contrib/nn/python/ops/scaled_softplus.py +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus.py @@ -30,9 +30,7 @@ def _reduce_and_reshape_grad(g, t): """Returns the gradient, sum-reduced and reshaped to `t`'s shape.""" shape = array_ops.shape(t) g_shape = array_ops.shape(g) - # pylint: disable=protected-access - bcast_dims, _ = gen_array_ops._broadcast_gradient_args(shape, g_shape) - # pylint: enable=protected-access + bcast_dims, _ = gen_array_ops.broadcast_gradient_args(shape, g_shape) return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index bc374d66c31922aa54542825ea4e04a444d96c5d..bacf15bbd6140caf647552f0dca02209634ae56b 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -52,6 +52,9 @@ py_test( name = "external_optimizer_test", srcs = ["python/training/external_optimizer_test.py"], srcs_version = "PY2AND3", + tags = [ + "no-internal-py3", + ], deps = [ ":opt_py", "//tensorflow/python:array_ops", @@ -70,9 +73,6 @@ py_test( srcs = ["python/training/moving_average_optimizer_test.py"], srcs_version = "PY2AND3", tags = [ - "manual", - "no_oss", # b/73507407 - "notap", "notsan", # b/31055119 ], deps = [ diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py index bd19ee3e7ac514448c6d79272abb86a154f55e9a..08d45ed73f3ae4b580d7078272e79fef22ef67c5 100644 --- a/tensorflow/contrib/opt/python/training/addsign_test.py +++ b/tensorflow/contrib/opt/python/training/addsign_test.py @@ -97,7 +97,7 @@ class AddSignTest(test.TestCase): global_step=global_step) neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), global_step=global_step) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -108,13 +108,13 @@ class AddSignTest(test.TestCase): # last 3 steps with negative gradient (sign(gm) should be -1) for t in range(1, 8): if t < 5: - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(update) elif t > 1: opt.apply_gradients(zip([grads0, grads1], [var0, var1]), global_step=global_step) else: - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(neg_update) elif t > 1: opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py index cb6c77a86feedde3285d75092511c8eb1e63b2a5..9076cc9d128552e37c09852ab2f24aa0c9977892 100644 --- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py +++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py @@ -22,6 +22,7 @@ import types import six from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops @@ -40,8 +41,10 @@ def _get_wrapper(fn, opt): def wrapper(self, grad, *args, **kwargs): # pylint: disable=unused-argument all_zeros = _is_all_zeros(grad) - return control_flow_ops.cond(all_zeros, control_flow_ops.no_op, - lambda: fn(grad, *args, **kwargs)) + def call_fn(): + with ops.control_dependencies([fn(grad, *args, **kwargs)]): + return control_flow_ops.no_op() + return control_flow_ops.cond(all_zeros, control_flow_ops.no_op, call_fn) wrapper = types.MethodType(wrapper, opt) return wrapper diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py index ff7b1a72d47d8ef54980905323bcaf358c988a82..5214082dd66f00eadadad71d50f7e00b178b8c10 100644 --- a/tensorflow/contrib/opt/python/training/powersign_test.py +++ b/tensorflow/contrib/opt/python/training/powersign_test.py @@ -99,7 +99,7 @@ class PowerSignTest(test.TestCase): neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), global_step=global_step) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -110,13 +110,13 @@ class PowerSignTest(test.TestCase): # last 3 steps with negative gradient (sign(gm) should be -1) for t in range(1, 8): if t < 5: - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(update) elif t > 1: opt.apply_gradients(zip([grads0, grads1], [var0, var1]), global_step=global_step) else: - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(neg_update) elif t > 1: opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index 04b5d5bdf158dc6a478d7a24b538c75d1dca8d45..6e77e934fe19851eea9ed0b74eb7aecc76f6237a 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -53,7 +53,7 @@ def from_contrib_estimator(estimator, `Estimator`. """ if isinstance(estimator, core_estimator.Estimator): - raise TypeError('Espected estimator to be of type ' + raise TypeError('Expected estimator to be of type ' 'tf.contrib.learn.Estimator, but got type ' 'tf.python.estimator.Estimator. You likely want to call ' 'from_estimator.') @@ -88,7 +88,7 @@ def from_estimator(estimator, `Estimator`. """ if isinstance(estimator, contrib_estimator.Estimator): - raise TypeError('Espected estimator to be of type ' + raise TypeError('Expected estimator to be of type ' 'tf.python.estimator.Estimator, but got type ' 'tf.contrib.learn.Estimator. You likely want to call ' 'from_contrib_estimator.') diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions.py b/tensorflow/contrib/py2tf/converters/logical_expressions.py deleted file mode 100644 index df980d41c9c57e325bee9a1fa870d9c95f46ea41..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/converters/logical_expressions.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Converter for logical expressions. - -e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.py2tf.pyct import parser - - -class LogicalExpressionTransformer(gast.NodeTransformer): - """Converts logical expressions to corresponding TF calls.""" - - def __init__(self): - # TODO(mdan): Look into replacing with bitwise operators instead. - self.op_mapping = { - gast.And: 'tf.logical_and', - gast.Or: 'tf.logical_or', - gast.Not: 'tf.logical_not', - gast.Eq: 'tf.equal', - } - - def visit_Compare(self, node): - node = self.generic_visit(node) - if len(node.ops) > 1: - raise NotImplementedError() - cmp_type = type(node.ops[0]) - if cmp_type in self.op_mapping: - tf_function = parser.parse_str(self.op_mapping[cmp_type]).body[0].value - return gast.Call( - func=tf_function, args=[node.left, node.comparators[0]], keywords=[]) - return node - - def visit_UnaryOp(self, node): - node = self.generic_visit(node) - if isinstance(node.op, gast.Not): - tf_function = parser.parse_str(self.op_mapping[type( - node.op)]).body[0].value - node = gast.Call(func=tf_function, args=[node.operand], keywords=[]) - return node - - def visit_BoolOp(self, node): - # TODO(mdan): A normalizer may be useful here. Use ANF? - node = self.generic_visit(node) - tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value - left = node.values[0] - for i in range(1, len(node.values)): - left = gast.Call( - func=tf_function, args=[left, node.values[i]], keywords=[]) - return left - - -def transform(node): - transformer = LogicalExpressionTransformer() - node = transformer.visit(node) - return node diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils.py b/tensorflow/contrib/py2tf/pyct/inspect_utils.py deleted file mode 100644 index 86cf52afd59f995284d036080e65eb749dfbca04..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/pyct/inspect_utils.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Live entity inspection utilities. - -This module contains whatever inspect doesn't offer out of the box. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import six - -from tensorflow.python.util import tf_inspect - - -def getcallargs(c, *args, **kwargs): - """Extension of getcallargs to non-function callables.""" - if tf_inspect.isfunction(c): - # The traditional getcallargs - return tf_inspect.getcallargs(c, *args, **kwargs) - - if tf_inspect.isclass(c): - # Constructors: pass a fake None for self, then remove it. - arg_map = tf_inspect.getcallargs(c.__init__, None, *args, **kwargs) - assert 'self' in arg_map, 'no "self" argument, is this not a constructor?' - del arg_map['self'] - return arg_map - - if hasattr(c, '__call__'): - # Callable objects: map self to the object itself - return tf_inspect.getcallargs(c.__call__, *args, **kwargs) - - raise NotImplementedError('unknown callable "%s"' % type(c)) - - -def getmethodclass(m, namespace): - """Resolves a function's owner, e.g. a method's class.""" - - # Instance method and class methods: should be bound to a non-null "self". - # If self is a class, then it's a class method. - if hasattr(m, '__self__'): - if m.__self__: - if tf_inspect.isclass(m.__self__): - return m.__self__ - return type(m.__self__) - - # Class and static methods: platform specific. - if hasattr(m, 'im_class'): # Python 2 - return m.im_class - - if hasattr(m, '__qualname__'): # Python 3 - qn = m.__qualname__.split('.') - if len(qn) < 2: - return None - owner_name, func_name = qn[-2:] - assert func_name == m.__name__, ( - 'inconsistent names detected ' - '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % (func_name, - m.__name__, m)) - if owner_name == '': - return None - if owner_name not in namespace: - raise ValueError( - 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % - (owner_name, m, namespace)) - return namespace[owner_name] - - if six.PY2: - # In Python 2 it's impossible, to our knowledge, to detect the class of a - # static function. So we're forced to walk all the objects in the - # namespace and see if they own it. If any reader finds a better solution, - # please let us know. - for _, v in namespace.items(): - if hasattr(v, m.__name__) and getattr(v, m.__name__) is m: - return v - - return None diff --git a/tensorflow/contrib/py2tf/pyct/qual_names.py b/tensorflow/contrib/py2tf/pyct/qual_names.py deleted file mode 100644 index 8717ee6cff198ff31f6cbdb7213e5a8dd3df1149..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/pyct/qual_names.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for manipulating qualified names. - -A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite -(e.g. 'foo.bar') syntactic symbols. - -This is *not* related to the __qualname__ attribute used by inspect, which -refers to scopes. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.py2tf.pyct import anno - - -class QN(object): - """Represents a qualified name.""" - - def __init__(self, base, attr=None): - if attr: - if not isinstance(base, QN): - raise ValueError('For attribute QNs, base must be a QN.') - self._parent = base - self.qn = base.qn + (attr,) - else: - if isinstance(base, QN): - if base.is_composite(): - self._parent = base.parent - else: - self._parent = None - self.qn = base.qn - else: - self._parent = None - self.qn = tuple(base.split('.')) - - def is_composite(self): - return len(self.qn) > 1 - - @property - def parent(self): - if self._parent is None: - raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) - return self._parent - - def __hash__(self): - return hash(self.qn) - - def __eq__(self, other): - return self.qn == other.qn - - def __str__(self): - return '.'.join(self.qn) - - def __repr__(self): - return str(self) - - def ssf(self): - """Simple symbol form.""" - return '_'.join(self.qn) - - def ast(self): - # The caller must adjust the context appropriately. - if self.is_composite(): - return gast.Attribute(self.parent.ast(), self.qn[-1], None) - return gast.Name(self.qn[0], None, None) - - -class QnResolver(gast.NodeTransformer): - """Annotates nodes with QN information. - - Note: Not using NodeAnnos to avoid circular dependencies. - """ - - def visit_Name(self, node): - self.generic_visit(node) - anno.setanno(node, anno.Basic.QN, QN(node.id)) - return node - - def visit_Attribute(self, node): - self.generic_visit(node) - anno.setanno(node, anno.Basic.QN, - QN(anno.getanno(node.value, anno.Basic.QN), node.attr)) - return node - - -def resolve(node): - return QnResolver().visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/qual_names_test.py b/tensorflow/contrib/py2tf/pyct/qual_names_test.py deleted file mode 100644 index 1b1eee2deca18bb0540c17d6ee85d421602aa2b7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/pyct/qual_names_test.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for qual_names module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import textwrap - -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.python.platform import test - - -class QNTest(test.TestCase): - - def test_basic(self): - a = qual_names.QN('a') - self.assertEqual(a.qn, ('a',)) - self.assertEqual(str(a), 'a') - self.assertEqual(a.ssf(), 'a') - self.assertEqual(a.ast().id, 'a') - self.assertFalse(a.is_composite()) - with self.assertRaises(ValueError): - _ = a.parent - - a_b = qual_names.QN(a, 'b') - self.assertEqual(a_b.qn, ('a', 'b')) - self.assertEqual(str(a_b), 'a.b') - self.assertEqual(a_b.ssf(), 'a_b') - self.assertEqual(a_b.ast().value.id, 'a') - self.assertEqual(a_b.ast().attr, 'b') - self.assertTrue(a_b.is_composite()) - self.assertEqual(a_b.parent.qn, ('a',)) - - a2 = qual_names.QN(a) - self.assertEqual(a2.qn, ('a',)) - with self.assertRaises(ValueError): - _ = a.parent - - a_b2 = qual_names.QN(a_b) - self.assertEqual(a_b2.qn, ('a', 'b')) - self.assertEqual(a_b2.parent.qn, ('a',)) - - self.assertTrue(a2 == a) - self.assertFalse(a2 is a) - - self.assertTrue(a_b.parent == a) - self.assertTrue(a_b2.parent == a) - - self.assertTrue(a_b2 == a_b) - self.assertFalse(a_b2 is a_b) - self.assertFalse(a_b2 == a) - - with self.assertRaises(ValueError): - qual_names.QN('a', 'b') - - def test_hashable(self): - d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'} - - self.assertEqual(d[qual_names.QN('a')], 'a') - self.assertEqual(d[qual_names.QN('b')], 'b') - self.assertTrue(qual_names.QN('c') not in d) - - -class QNResolverTest(test.TestCase): - - def assertQNStringIs(self, node, qn_str): - self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) - - def test_resolve(self): - samples = """ - a - a.b - (c, d.e) - [f, (g.h.i)] - j(k, l) - """ - nodes = qual_names.resolve(parser.parse_str(textwrap.dedent(samples))) - nodes = tuple(n.value for n in nodes.body) - - self.assertQNStringIs(nodes[0], 'a') - self.assertQNStringIs(nodes[1], 'a.b') - self.assertQNStringIs(nodes[2].elts[0], 'c') - self.assertQNStringIs(nodes[2].elts[1], 'd.e') - self.assertQNStringIs(nodes[3].elts[0], 'f') - self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') - self.assertQNStringIs(nodes[4].func, 'j') - self.assertQNStringIs(nodes[4].args[0], 'k') - self.assertQNStringIs(nodes[4].args[1], 'l') - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py deleted file mode 100644 index 8ccfde8573724741b0bbe4eacb3c54beb381ee7e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/pyct/templates_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for templates module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.python.platform import test - - -class TemplatesTest(test.TestCase): - - def test_replace_tuple(self): - template = """ - def test_fn(a, c): - return b, - """ - - node = templates.replace(template, b=('a', 'c'))[0] - result, _ = compiler.ast_to_object(node) - - self.assertEquals((2, 3), result.test_fn(2, 3)) - - def test_replace_variable(self): - template = """ - def test_fn(a): - a += 1 - a = 2 * a + 1 - return b - """ - - node = templates.replace(template, a='b')[0] - result, _ = compiler.ast_to_object(node) - self.assertEquals(7, result.test_fn(2)) - - def test_replace_function_name(self): - template = """ - def fname(a): - a += 1 - a = 2 * a + 1 - return a - """ - - node = templates.replace(template, fname='test_fn')[0] - result, _ = compiler.ast_to_object(node) - self.assertEquals(7, result.test_fn(2)) - - def test_code_block(self): - template = """ - def test_fn(a): - block - return a - """ - - node = templates.replace( - template, - block=[ - gast.Assign([ - gast.Name('a', None, None) - ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), - ] * 2)[0] - result, _ = compiler.ast_to_object(node) - self.assertEquals(3, result.test_fn(1)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py deleted file mode 100644 index d931322bf34cc36b614e587bbf5a36f5c1a4e38c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/utils/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility module that contains APIs usable in the generated code.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns -from tensorflow.contrib.py2tf.utils.misc import alias_tensors -from tensorflow.contrib.py2tf.utils.misc import dynamic_len -from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond -from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while -from tensorflow.contrib.py2tf.utils.printing import call_print -from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func -from tensorflow.contrib.py2tf.utils.type_check import is_tensor diff --git a/tensorflow/contrib/py2tf/utils/py_func.py b/tensorflow/contrib/py2tf/utils/py_func.py deleted file mode 100644 index 838872d092a3ab07e965180eff4fec7ff6c4ccf9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/utils/py_func.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Pyfunc creation utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import script_ops - - -def wrap_py_func(f, return_dtypes, arguments, use_dummy_return=False): - """Helper that wraps a callable to py_func. - - The helper passes tensor arguments through the py_func interface. Non-tensor - arguments are allowed, and will be passed to f directly. Note that non-tensor - arguments are captured by f will not update every time the wrapper is - called (this is consistent with its argument list, which only includes - the tensor arguments). In general, it's safest not to reuse this wrapper. - - Args: - f: Callable - return_dtypes: DType, tuple, list or None, the data type for each of f's - return value. None if f has no return values or use_dummy_return is - True. - arguments: Arguments for f - use_dummy_return: If True, the function will return a dummy value of 1 - and discard its actual return value. - Returns: - The return values of f converted to tensor. - Raises: - ValueError: if the arguments are incorrect. - """ - - if return_dtypes and use_dummy_return: - raise ValueError('if use_dummy_return is True, return_dtypes must be empty') - - n = len(arguments) - arg_is_tensor = tuple(map(tensor_util.is_tensor, arguments)) - index_in_tensor_list = [0] * n - i = 0 - for j in range(n): - index_in_tensor_list[j] = i - if arg_is_tensor[j]: - i += 1 - - def f_wrapper(*tensor_args): - f_args = tuple(tensor_args[index_in_tensor_list[i]] - if arg_is_tensor[i] else arguments[i] for i in range(n)) - retval = f(*f_args) - return 1 if use_dummy_return else retval - - return script_ops.py_func( - f_wrapper, tuple(arguments[i] for i in range(n) if arg_is_tensor[i]), - dtypes.int64 if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/py2tf/utils/py_func_test.py b/tensorflow/contrib/py2tf/utils/py_func_test.py deleted file mode 100644 index 776b5309c6f027bb2008aa83d48e4155e817ed97..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/utils/py_func_test.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for wrap_py_func module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.py2tf.utils import py_func -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import test - - -class PyFuncTest(test.TestCase): - - def test_wrap_py_func_simple(self): - - def test_fn(a, b, c): - return a + b + c - - with self.test_session() as sess: - tensor_1 = constant_op.constant(1) - self.assertEqual(3, - sess.run( - py_func.wrap_py_func(test_fn, dtypes.int64, - (1, tensor_1, 1)))) - self.assertEqual(3, - sess.run( - py_func.wrap_py_func(test_fn, dtypes.int64, - (1, 1, 1)))) - self.assertEqual(3, - sess.run( - py_func.wrap_py_func(test_fn, dtypes.int64, - (tensor_1, 1, tensor_1)))) - - def test_wrap_py_func_complex_args(self): - - class TestClass(object): - - def __init__(self): - self.foo = 5 - - def test_fn(a, b): - return a * b.foo - - with self.test_session() as sess: - self.assertEqual(35, - sess.run( - py_func.wrap_py_func(test_fn, dtypes.int64, - (7, TestClass())))) - self.assertEqual( - 35, - sess.run( - py_func.wrap_py_func(test_fn, dtypes.int64, - (constant_op.constant(7), TestClass())))) - - def test_wrap_py_func_dummy_return(self): - - side_counter = [0] - - def test_fn(_): - side_counter[0] += 1 - - with self.test_session() as sess: - self.assertEqual(1, - sess.run( - py_func.wrap_py_func(test_fn, None, (5,), True))) - self.assertEqual([1], side_counter) - self.assertEqual(1, - sess.run( - py_func.wrap_py_func(test_fn, None, - (constant_op.constant(5),), - True))) - self.assertEqual([2], side_counter) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index aec9f47ccb20349c08bbe2fd813ee24a807f9fe3..0b7629620418340d803753be0df1f04c342dc490 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -24,6 +24,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python:variable_scope", ], ) diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 8b0e7bb68f5a11f5d1942f7cf048e96768da259e..348c824a4072c3329ac4a3441c19c71598bc9c03 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -3,8 +3,7 @@ tf.contrib.quantize provides tools for transforming graphs to include ops to model quantization of weights, biases and activations during both training and inference. This is done using the -[fake quantization op] -(https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). +[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). Recent literature has shown that fixed point networks provide comparable performance to floating point networks [1]. This is achieved by modeling the diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index 3a1fa61e43986af1a1315d5a9e6f010e802ea157..bf648e158ec15e1bfa962ba7dbe0567263c89c9b 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -23,6 +23,7 @@ import re from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -101,7 +102,7 @@ def CreateOrGetQuantizationStep(): Quantization step Tensor. """ quantization_step_name = 'fake_quantization_step' - quantization_step_tensor_name = quantization_step_name + '/AssignAdd:0' + quantization_step_tensor_name = quantization_step_name + '/Identity:0' g = ops.get_default_graph() try: return g.get_tensor_by_name(quantization_step_tensor_name) @@ -118,5 +119,15 @@ def CreateOrGetQuantizationStep(): with g.name_scope(quantization_step_tensor.op.name + '/'): # We return the incremented variable tensor. Since this is used in conds # for quant_delay and freeze_bn_delay, it will run once per graph - # execution. - return state_ops.assign_add(quantization_step_tensor, 1) + # execution. We return an identity to force resource variables and + # normal variables to return a tensor of the same name. + return array_ops.identity( + state_ops.assign_add(quantization_step_tensor, 1)) + + +def DropStringPrefix(s, prefix): + """If the string starts with this prefix, drops it.""" + if s.startswith(prefix): + return s[len(prefix):] + else: + return s diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py index d6237fe5e38d905bf262d7be3746b9ee6046da47..06c62f2d265503bf42d46fb682a398ce1f4d15fb 100644 --- a/tensorflow/contrib/quantize/python/common_test.py +++ b/tensorflow/contrib/quantize/python/common_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.quantize.python import common from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -29,8 +30,15 @@ from tensorflow.python.platform import googletest class CommonTest(test_util.TensorFlowTestCase): def testCreateOrGetQuantizationStep(self): + self._TestCreateOrGetQuantizationStep(False) + + def testCreateOrGetQuantizationStepResourceVar(self): + self._TestCreateOrGetQuantizationStep(True) + + def _TestCreateOrGetQuantizationStep(self, use_resource): g = ops.Graph() with session.Session(graph=g) as sess: + variable_scope.get_variable_scope().set_use_resource(use_resource) quantization_step_tensor = common.CreateOrGetQuantizationStep() # Check that operations are added to the graph. diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 75d9eb0e58d96e4bb2946684febd250e2e1a6b4a..5750be6f4cbd501ec85656a66b9002a470b1a863 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.util import compat @@ -194,7 +195,7 @@ def _FindFusedBatchNorms(graph): layer_op = match_result.get_op(layer_pattern) layer_tensor = match_result.get_tensor(layer_pattern) bn_op = match_result.get_op(batch_norm_pattern) - batch_epsilon_tensor = bn_op.get_attr('epsilon') + batch_epsilon = bn_op.get_attr('epsilon') # In the MatMul case, the output of batch norm is reshaped back into a # 2D tensor, so the output_tensor is the output of the Reshape op. @@ -207,6 +208,11 @@ def _FindFusedBatchNorms(graph): continue output_tensor = output_reshape_op.outputs[0] + # Ensure that the output tensor has consumers, otherwise this is a dangling + # node and not a match. + if not output_tensor.consumers(): + continue + input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -231,7 +237,7 @@ def _FindFusedBatchNorms(graph): # The batch variance used during forward and backward prop is biased, # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average # calculation, the variance is corrected by the term N/N-1 (Bessel's - # correction). The variance tensor read from FuseBatchNorm has bessel's + # correction). The variance tensor read from FuseBatchNorm has Bessel's # correction applied, so we undo it here. scope, sep, _ = bn_op.name.rpartition('/') g = ops.get_default_graph() @@ -270,7 +276,7 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon_tensor=batch_epsilon_tensor) + batch_epsilon=batch_epsilon) def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, @@ -300,7 +306,7 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, Args: context: The scope under which we look for batch norm params - match: Object containg required batch norm tensors for correction + match: Object containing required batch norm tensors for correction computation. freeze_batch_norm_delay: Delay in steps at which computation switches from regular batch norm to frozen mean and variance. @@ -311,11 +317,11 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, """ g = ops.get_default_graph() - with g.name_scope(context + '/batch_norm_correction'): + prefix = '' if not context else context + '/' + with g.name_scope(prefix + 'batch_norm_correction'): recip_sigma_mv = math_ops.rsqrt( - match.moving_variance_tensor + match.batch_epsilon_tensor) - recip_sigma = math_ops.rsqrt( - match.variance_tensor + match.batch_epsilon_tensor) + match.moving_variance_tensor + match.batch_epsilon) + recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon) correction_scale = math_ops.divide( recip_sigma_mv, recip_sigma, name='scale_compute') correction_scale = array_ops.identity( @@ -434,6 +440,9 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): for bn in common.BatchNormGroups(graph): has_scaling = _HasScaling(graph, input_to_ops_map, bn) + if not _IsValidUnfusedBatchNorm(graph, bn): + continue + # The mangling code intimately depends on BatchNorm node's internals. original_op, folded_op = _CreateFoldedOp( graph, @@ -462,6 +471,15 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) +def _IsValidUnfusedBatchNorm(graph, context): + """Checks that the output of the unfused batch norm has consumers.""" + add_shift = graph.get_operation_by_name( + context + '/BatchNorm/batchnorm/add_1') + # Ensure that the output tensor of batch norm has consumers, otherwise this + # is a dangling node and not a match. + return bool(add_shift.outputs[0].consumers()) + + def _GetBatchNormParams(graph, context, has_scaling): """Extracts relevant tensors for folding batch norms. @@ -478,7 +496,7 @@ def _GetBatchNormParams(graph, context, has_scaling): batch_variance_tensor = None moving_mean_tensor = None moving_variance_tensor = None - batch_epsilon_tensor = None + batch_epsilon = None bn_decay_mean_tensor = None bn_decay_var_tensor = None @@ -486,15 +504,23 @@ def _GetBatchNormParams(graph, context, has_scaling): base_context = split_context[-1] oplist = graph.get_operations() - op_suffix_gamma = base_context + '/BatchNorm/gamma' op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze' op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1' - op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read' - op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read' op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y' op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay' op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay' + if variable_scope.get_variable_scope().use_resource: + op_suffix_gamma = base_context + '/BatchNorm/gamma/Read/ReadVariableOp' + op_suffix_moving_variance = ( + base_context + '/BatchNorm/moving_variance/Read/ReadVariableOp') + op_suffix_moving_mean = ( + base_context + '/BatchNorm/moving_mean/Read/ReadVariableOp') + else: + op_suffix_gamma = base_context + '/BatchNorm/gamma' + op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read' + op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read' + # Parse through list of ops to find relevant ops for op in oplist: if op.name.endswith(op_suffix_mean): @@ -509,7 +535,7 @@ def _GetBatchNormParams(graph, context, has_scaling): if op.name.endswith(op_suffix_moving_variance): moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0') if op.name.endswith(op_suffix_epsilon): - batch_epsilon_tensor = graph.get_tensor_by_name(op.name + ':0') + batch_epsilon = graph.get_tensor_by_name(op.name + ':0') if op.name.endswith(op_suffix_bn_decay_mean): bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0') if op.name.endswith(op_suffix_bn_decay_var): @@ -535,7 +561,7 @@ def _GetBatchNormParams(graph, context, has_scaling): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon_tensor=batch_epsilon_tensor) + batch_epsilon=batch_epsilon) def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, @@ -816,7 +842,7 @@ class _BatchNormMatch(object): def __init__(self, layer_op, bn_op, output_tensor, input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, variance_tensor, moving_mean_tensor, moving_variance_tensor, - bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor): + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon): self._layer_op = layer_op self._bn_op = bn_op self._output_tensor = output_tensor @@ -830,7 +856,7 @@ class _BatchNormMatch(object): self._moving_variance_tensor = moving_variance_tensor self._bn_decay_mean_tensor = bn_decay_mean_tensor self._bn_decay_var_tensor = bn_decay_var_tensor - self._batch_epsilon_tensor = batch_epsilon_tensor + self._batch_epsilon = batch_epsilon @property def layer_op(self): @@ -877,8 +903,8 @@ class _BatchNormMatch(object): return self._moving_variance_tensor @property - def batch_epsilon_tensor(self): - return self._batch_epsilon_tensor + def batch_epsilon(self): + return self._batch_epsilon @property def bn_decay_mean_tensor(self): diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index c90a18ab0357f1bcbc5d8ccd48edf894d7baf5f9..af31467476b1536adef2bb74308fd1093f7bea7a 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -128,6 +128,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + def testFoldConv2d(self): self._RunTestOverParameters(self._TestFoldConv2d) @@ -196,6 +199,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + def testFoldConv2dUnknownShape(self): self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) @@ -260,6 +266,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + def testFoldFullyConnectedLayer(self): self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) @@ -337,6 +346,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py index b458f039df0523b5b8b07cff7d14643154124b95..bacc707a3abb5539b3b119c1ebc17bd7b30efc5b 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher.py +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -103,7 +103,7 @@ class OneofPattern(Pattern): class MatchResult(object): r"""Encapsulates the result of a match done by GraphMatcher. - MatchResult contains a map from OpTypePattern to the matching op and tensor. + MatchResult contains a map from Pattern to the matching op and tensor. When the matching op has multiple output tensors, the matching tensor is the output tensor used by the matching op of the parent pattern. E.g., when we match graph @@ -138,7 +138,7 @@ class MatchResult(object): self._name_to_pattern[pattern.name] = pattern def _to_pattern(self, pattern_or_name): - if isinstance(pattern_or_name, OpTypePattern): + if isinstance(pattern_or_name, Pattern): return pattern_or_name if isinstance(pattern_or_name, str): @@ -146,8 +146,8 @@ class MatchResult(object): return None return self._name_to_pattern[pattern_or_name] - raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' - % type(pattern_or_name)) + raise ValueError('pattern_or_name has type %s. Expect Pattern or str.' % + type(pattern_or_name)) def _get_op_tensor(self, pattern_or_name): pattern = self._to_pattern(pattern_or_name) diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 0a8e35080cb08f71dc28e33c6138a12656e5a5ea..a4f7b1b22139588be29171126d43b872d6658168 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -282,8 +282,8 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits, Args: inputs: a tensor containing values to be quantized. min_var: a variable containing quantization range lower end(s). - max_var: a variable containing quantization range lupper end(s). - per_channel: a boolean specifying whether to use per-channel quantizatioh. + max_var: a variable containing quantization range upper end(s). + per_channel: a boolean specifying whether to use per-channel quantization. num_bits: Number of bits to use for quantization, must be between 2 and 8. narrow_range: Whether to use the narrow quantization range [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 5fd806d195dce671d079386ea4b6c89042e26cf6..019d123a68602fb15c1ae914f3d5621290deeb00 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -34,10 +34,6 @@ _QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} # Activations that are supported by the quantization rewrite. _ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'} -# Weight types that are supported by the quantization rewrite. -# TODO(suharshs): Add support for ResourceVariable. -_WEIGHT_TYPES = {'Variable', 'VariableV2'} - def Quantize(graph, is_training, @@ -45,7 +41,7 @@ def Quantize(graph, activation_bits=8, ema_decay=0.999, quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES): + vars_collection=ops.GraphKeys.GLOBAL_VARIABLES): """Updates graph with quantization operations. Args: @@ -124,21 +120,61 @@ def Quantize(graph, vars_collection=vars_collection, bits=activation_bits) + # Quantize bypass ops that occur after the activation. + if layer_match.post_activation_bypass_op is not None: + post_activation_bypass_context = re.search( + r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1) + _InsertQuantOp( + post_activation_bypass_context, + 'post_activation_bypass_quant', + layer_match.post_activation_bypass_op, + input_to_ops_map.ConsumerOperations( + layer_match.post_activation_bypass_op), + is_training, + moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits) + def _FindLayersToQuantize(graph): """Matches layers in graph to quantize. + The following patterns get matched. Nodes surrounded by [] will be + optionally matched: + + weight|folded_weight + / + conv|fc + | + [post_conv_correction] + | + biasadd|folded_bias + | + [bypass] + | + activation + | + [post_activation_bypass] + + Match replacements: + If weight|folded_weight is found, FakeQuant is added afterwards. + If bypass is found, FakeQuant is added before and after. + If activation is found, FakeQuant is added afterwards. + If post_activation_bypass is found, FakeQuant is added afterwards. + Args: graph: Graph to perform match on. - Yields: - _LayerMatches. + Returns: + list of _LayerMatches. """ input_pattern = graph_matcher.OpTypePattern('*') - weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES)) - weight_pattern = graph_matcher.OpTypePattern( + weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2') + weight_identity_pattern = graph_matcher.OpTypePattern( 'Identity', inputs=[weight_var_pattern]) - + weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp') folded_weight_pattern = graph_matcher.OpTypePattern('Mul') # The weights inputs to the layer operation can either be from the Variable or @@ -147,7 +183,10 @@ def _FindLayersToQuantize(graph): '|'.join(_QUANTIZABLE_TYPES), inputs=[ input_pattern, - graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern]) + graph_matcher.OneofPattern([ + weight_identity_pattern, weight_resource_var_pattern, + folded_weight_pattern + ]) ]) folded_bias_mul_pattern = graph_matcher.OpTypePattern( @@ -180,7 +219,7 @@ def _FindLayersToQuantize(graph): [bias_add_pattern, folded_bias_add_pattern]) ]) - # The input to the activation can come from bias add, fold bias add or the + # The input to the activation can come from bias add, fold bias add, the # bypasses. activation_pattern = graph_matcher.OpTypePattern( '|'.join(_ACTIVATION_TYPES), @@ -191,10 +230,62 @@ def _FindLayersToQuantize(graph): ]) ]) + post_activation_bypass_pattern_a = graph_matcher.OpTypePattern( + 'Add', inputs=['*', activation_pattern]) + post_activation_bypass_pattern_b = graph_matcher.OpTypePattern( + 'Add', inputs=[activation_pattern, '*']) + + # The order of the following matching blocks is very important. Since matches + # aren't guaranteed to be disjoint, we structure matches from largest to + # smallest to guarantee that the largest match always wins. Additionally, we + # ensure that we don't match layers multiple times. + + layer_matches = [] + # We use matched_layer_set to ensure that layers aren't matched multiple + # times. + matched_layer_set = set() + + # First, we match layers that have a post activation bypass. We do this first + # to ensure we don't match only the first part of this layer, missing the + # post activation bypass node. + post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher( + graph_matcher.OneofPattern([ + post_activation_bypass_pattern_a, + post_activation_bypass_pattern_b, + ])) + for match_result in post_activation_bypass_layer_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_identity_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(weight_resource_var_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(folded_weight_pattern) + activation_op = match_result.get_op(activation_pattern) + bias_add_op = match_result.get_op(bias_add_pattern) + if bias_add_op is None: + bias_add_op = match_result.get_op(folded_bias_add_pattern) + bypass_op = match_result.get_op(bypass_pattern_a) + if bypass_op is None: + bypass_op = match_result.get_op(bypass_pattern_b) + post_activation_bypass_op = match_result.get_op( + post_activation_bypass_pattern_a) + if post_activation_bypass_op is None: + post_activation_bypass_op = match_result.get_op( + post_activation_bypass_pattern_b) + if layer_op not in matched_layer_set: + matched_layer_set.add(layer_op) + layer_matches.append( + _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, + post_activation_bypass_op, bias_add_op)) + + # Now, we match the basic layer ending at an activation. We may get duplicate + # matches from above, but we don't add them to layer_matches. layer_matcher = graph_matcher.GraphMatcher(activation_pattern) for match_result in layer_matcher.match_graph(graph): layer_op = match_result.get_op(layer_pattern) - weight_tensor = match_result.get_tensor(weight_pattern) + weight_tensor = match_result.get_tensor(weight_identity_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(weight_resource_var_pattern) if weight_tensor is None: weight_tensor = match_result.get_tensor(folded_weight_pattern) activation_op = match_result.get_op(activation_pattern) @@ -204,31 +295,53 @@ def _FindLayersToQuantize(graph): bypass_op = match_result.get_op(bypass_pattern_a) if bypass_op is None: bypass_op = match_result.get_op(bypass_pattern_b) - yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, - bias_add_op) - - # Match the final layer, where there will not be an activation and instead - # the output of the final BiasAdd must be quantized, so we treat it as the - # 'activation_op' in the _LayerMatch. - # TODO(suharshs): Figure out how to quantize this final layer across many - # models. + if layer_op not in matched_layer_set: + matched_layer_set.add(layer_op) + layer_matches.append( + _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, None, + bias_add_op)) + + # Match the final layer, where there may not be an activation and instead + # the output of the final BiasAdd must be quantized. So we treat the BiasAdd + # as the 'activation_op' in the _LayerMatch, to ensure that it's output is + # quantized. final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern) for match_result in final_layer_matcher.match_graph(graph): layer_op = match_result.get_op(layer_pattern) - weight_tensor = match_result.get_tensor(weight_pattern) + weight_tensor = match_result.get_tensor(weight_identity_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(weight_resource_var_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(folded_weight_pattern) activation_op = match_result.get_op(bias_add_pattern) - yield _LayerMatch(layer_op, weight_tensor, activation_op, None, None) + if activation_op is None: + activation_op = match_result.get_op(folded_bias_add_pattern) + if layer_op not in matched_layer_set: + matched_layer_set.add(layer_op) + layer_matches.append( + _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) + + return layer_matches + + +def _HasPostActivationBypass(activation_op): + for activation_tensor in activation_op.outputs: + for output_op in activation_tensor.consumers(): + if output_op.type == 'Add': + return True + return False class _LayerMatch(object): """Contains all information related to a matched Layer.""" def __init__(self, layer_op, weight_tensor, activation_op, bypass_op, - bias_add_op): + post_activation_bypass_op, bias_add_op): self._layer_op = layer_op self._weight_tensor = weight_tensor self._activation_op = activation_op self._bypass_op = bypass_op + self._post_activation_bypass_op = post_activation_bypass_op self._bias_add_op = bias_add_op @property @@ -247,6 +360,10 @@ class _LayerMatch(object): def bypass_op(self): return self._bypass_op + @property + def post_activation_bypass_op(self): + return self._post_activation_bypass_op + @property def bias_add_op(self): return self._bias_add_op @@ -263,12 +380,12 @@ def _InsertQuantOp(context, bits=8, ema_decay=0.999, quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + vars_collection=ops.GraphKeys.GLOBAL_VARIABLES, narrow_range=False): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: - context: Context w,here producer and consumer operations are nested. + context: Context where producer and consumer operations are nested. name: Name for the new quantization op within the context. producer: Producer operation of the pairs where quantization will be inserted. @@ -294,7 +411,23 @@ def _InsertQuantOp(context, consumer operation. """ name_prefix = _AddContextToName(context, name) + # This is needed on TPU where name_scope == 'TPUReplicate/loop', and + # name_prefix starts with 'TPUReplicate/loop/'; without dropping it + # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which + # breaks things later. + name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/') + inputs = producer.outputs[0] + # Prevent ops from being quantized multiple times. Bypass ops can sometimes + # overlap between multiple matches, so we need to ensure that we don't + # add duplicate FakeQuant operations. + fake_quant_ops = set([ + 'FakeQuantWithMinMaxVars', + 'FakeQuantWithMinMaxArgs' + ]) + if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])): + return + if moving_avg: quant = ( quant_ops.MovingAvgQuantize( diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index 5a3a74cec4864ad3808d485849334c81f569d300..0b74b438ac317967bbe10ad936b451de6f69d62c 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -72,6 +72,8 @@ def _create_graph(input_graph=None, def create_training_graph(input_graph=None, quant_delay=0): """Rewrites a training input_graph in place for simulated quantization. + Variables added by the rewrite get added to the global variables collection. + The graph has fake quantization ops inserted to simulate the error introduced by quantization. Since the graph is transformed in place, the expected behavior of previously held references to nodes and tensors may @@ -97,16 +99,7 @@ def create_training_graph(input_graph=None, quant_delay=0): # TODO(raghuramank) Need to have freeze_bn_delay be a function of batch size # Currently the values below are hardcoded for mobilenetV1 on imagenet # Please use the experimental API if you need to tune these values. - if quant_delay == 0: - # Corresponds to case of restoring from a floating point checkpoint - # In this case, we can freeze the moving mean and variance early on and - # switch to using them during training. Therefore, freeze_bn_delay is set to - # 2e5. - freeze_bn_delay = int(2e5) - else: - # If training from scratch, set freeze_bn_delay to 100 epochs after quant - # delay. With a batch size of 64, this corresponds to 20000*100=2M steps. - freeze_bn_delay = quant_delay + int(2e6) + freeze_bn_delay = None _create_graph( input_graph=input_graph, @@ -118,6 +111,8 @@ def create_training_graph(input_graph=None, quant_delay=0): def create_eval_graph(input_graph=None): """Rewrites an eval input_graph in place for simulated quantization. + Variables added by the rewrite get added to the global variables collection. + The graph has fake quantization ops inserted to simulate the error introduced by quantization. Since the graph is transformed in place, the expected behavior of previously held references to nodes and tensors may @@ -138,9 +133,11 @@ def experimental_create_training_graph(input_graph=None, weight_bits=8, activation_bits=8, quant_delay=0, - freeze_bn_delay=int(2e5)): + freeze_bn_delay=None): """Rewrites a training input_graph in place for simulated quantization. + Variables added by the rewrite get added to the global variables collection. + This function has additional experimental options not (yet) available to create_training_graph. The resulting behavior may be undefined. @@ -158,7 +155,7 @@ def experimental_create_training_graph(input_graph=None, often fail. Args: - input_graph: The tf.Graph to be transformed,if None then defaults to the + input_graph: The tf.Graph to be transformed, if None then defaults to the default graph. weight_bits: Number of bits to use for quantizing weights. activation_bits: Number of bits to use for quantizing activations. @@ -188,6 +185,8 @@ def experimental_create_eval_graph(input_graph=None, activation_bits=8): """Rewrites an eval input_graph in place for simulated quantization. + Variables added by the rewrite get added to the global variables collection. + This function has additional experimental options not (yet) available to create_eval_graph. The resulting behavior may be undefined. diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 639a7454a92aebd7289c59498cebff82cc003f75..db745aa56212af6a9c20e06ee9e4e5d6e27cf3c3 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import googletest batch_norm = layers.batch_norm @@ -56,52 +57,46 @@ class QuantizeTest(test_util.TensorFlowTestCase): (array_ops.identity, 'Identity', True, 5000), ] for params in parameters_list: - test_fn(params[0], params[1], params[2], params[3]) + # Test everything with resource variables and normal variables. + test_fn(params[0], params[1], params[2], params[3], False) + test_fn(params[0], params[1], params[2], params[3], True) - def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, - with_bypass, delay): - """Tests quantization: inputs -> Conv2d no batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - """ - graph = ops.Graph() - with graph.as_default(): - batch_size, height, width, depth = 5, 128, 128, 3 - inputs = array_ops.zeros((batch_size, height, width, depth)) - stride = 1 if with_bypass else 2 - out_depth = 3 if with_bypass else 32 - activation_fn = None if with_bypass else activation - scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - node = activation(node, name='test/' + activation_op_name) - update_barrier = control_flow_ops.no_op(name='update_barrier') - with ops.control_dependencies([update_barrier]): - array_ops.identity(node, name='control_dependency') - - quantize.Quantize(graph, True, quant_delay=delay) + def _AssertCorrectQuantizedGraphWithoutBatchNorm( + self, graph, scope, layer, activation_op_name, with_bypass, delay, + use_resource): quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/weights_quant/AssignMinLast', - scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' - ] + + # Assemble the expected inputs. + if use_resource: + expected_inputs = [ + scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp', + scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1', + ] + if layer == 'DepthwiseConv2dNative': + expected_inputs.append(scope + '/depthwise/ReadVariableOp') + else: + expected_inputs.append(scope + '/' + layer + '/ReadVariableOp') + else: + expected_inputs = [ + scope + '/weights_quant/AssignMinLast', + scope + '/weights_quant/AssignMaxLast', + ] + if layer == 'DepthwiseConv2dNative': + expected_inputs.append(scope + '/depthwise_weights/read') + else: + expected_inputs.append(scope + '/weights/read') + self._AssertInputOpsAre(weights_quant, expected_inputs) if delay and delay > 0: output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' else: - output_op_name = scope + '/Conv2D' + if layer == 'DepthwiseConv2dNative': + output_op_name = scope + '/depthwise' + else: + output_op_name = scope + '/' + layer self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) @@ -109,10 +104,17 @@ class QuantizeTest(test_util.TensorFlowTestCase): conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/conv_quant/AssignMinEma', - scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' - ] + if use_resource: + expected_inputs = [ + scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp', + scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1', + scope + '/BiasAdd', + ] + else: + expected_inputs = [ + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' + ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add') @@ -121,22 +123,76 @@ class QuantizeTest(test_util.TensorFlowTestCase): act_quant = graph.get_operation_by_name('test/act_quant/' + quantization_node_name) self.assertEqual(act_quant.type, quantization_node_name) - - expected_inputs = [ - 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', - 'test/' + activation_op_name - ] + if use_resource: + expected_inputs = [ + 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp', + 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1', + 'test/' + activation_op_name, + ] + else: + expected_inputs = [ + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', + 'test/' + activation_op_name + ] self._AssertInputOpsAre(act_quant, expected_inputs) output_op_name = ('test/act_quant/delayed_quant/Switch_1' if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + self._AssertIdempotent(graph) def testQuantize_Conv2dWithoutBatchNorm(self): self._RunWithoutBatchNormTestOverParameters( self._TestQuantize_Conv2dWithoutBatchNorm) + def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, + with_bypass, delay, use_resource): + """Tests quantization: inputs -> Conv2d no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + out_depth = 3 if with_bypass else 32 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithoutBatchNorm( + graph, scope, 'Conv2D', activation_op_name, with_bypass, delay, + use_resource) + + def testQuantize_FCWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_FCWithoutBatchNorm) + def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, use_resource): """Tests quantization: inputs -> FC no batch norm -> Activation. Args: @@ -146,72 +202,40 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + use_resource: Bool, when true uses resource variables. """ graph = ops.Graph() with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) batch_size, depth = 5, 256 inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 activation_fn = None if with_bypass else activation scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, scope=scope) + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, True, quant_delay=delay) - quantization_node_name = 'FakeQuantWithMinMaxVars' - weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + - quantization_node_name) - self.assertEqual(weights_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/weights_quant/AssignMinLast', - scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' - ] - self._AssertInputOpsAre(weights_quant, expected_inputs) - if delay and delay > 0: - output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' - else: - output_op_name = scope + '/MatMul' - self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) - - if with_bypass: - conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + - quantization_node_name) - self.assertEqual(conv_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/conv_quant/AssignMinEma', - scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' - ] - self._AssertInputOpsAre(conv_quant, expected_inputs) - output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' - if delay else 'test/Add') - self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) - - act_quant = graph.get_operation_by_name('test/act_quant/' + - quantization_node_name) - self.assertEqual(act_quant.type, quantization_node_name) - expected_inputs = [ - 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', - 'test/' + activation_op_name - ] - self._AssertInputOpsAre(act_quant, expected_inputs) - output_op_name = ('test/act_quant/delayed_quant/Switch_1' - if delay else 'control_dependency') - self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + self._AssertCorrectQuantizedGraphWithoutBatchNorm( + graph, scope, 'MatMul', activation_op_name, with_bypass, delay, + use_resource) - def testQuantize_FCWithoutBatchNorm(self): + def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): self._RunWithoutBatchNormTestOverParameters( - self._TestQuantize_FCWithoutBatchNorm) + self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( - self, activation, activation_op_name, with_bypass, delay): + self, activation, activation_op_name, with_bypass, delay, use_resource): """Tests quantization: inputs -> DWConv2d no batch norm -> Activation. Args: @@ -221,71 +245,36 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + use_resource: Bool, when true uses resource variables. """ graph = ops.Graph() with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else activation scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, scope=scope) + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, True, quant_delay=delay) - quantization_node_name = 'FakeQuantWithMinMaxVars' - weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + - quantization_node_name) - self.assertEqual(weights_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/weights_quant/AssignMinLast', - scope + '/weights_quant/AssignMaxLast', - scope + '/depthwise_weights/read' - ] - self._AssertInputOpsAre(weights_quant, expected_inputs) - if delay and delay > 0: - output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' - else: - output_op_name = scope + '/depthwise' - self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) - - if with_bypass: - conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + - quantization_node_name) - self.assertEqual(conv_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/conv_quant/AssignMinEma', - scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' - ] - self._AssertInputOpsAre(conv_quant, expected_inputs) - output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' - if delay else 'test/Add') - self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) - - act_quant = graph.get_operation_by_name('test/act_quant/' + - quantization_node_name) - self.assertEqual(act_quant.type, quantization_node_name) - expected_inputs = [ - 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', - 'test/' + activation_op_name - ] - self._AssertInputOpsAre(act_quant, expected_inputs) - output_op_name = ('test/act_quant/delayed_quant/Switch_1' - if delay else 'control_dependency') - self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) - - def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): - self._RunWithoutBatchNormTestOverParameters( - self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + self._AssertCorrectQuantizedGraphWithoutBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, + delay, use_resource) def _RunBatchNormTestOverParameters(self, test_fn): # TODO(suharshs): Use parameterized test once OSS TF supports it. @@ -317,13 +306,88 @@ class QuantizeTest(test_util.TensorFlowTestCase): (array_ops.identity, 'Identity', True, 5000, True) ] for params in parameters_list: - test_fn(params[0], params[1], params[2], params[3], params[4]) + # Test everything with resource variables and normal variables. + test_fn(params[0], params[1], params[2], params[3], params[4], False) + test_fn(params[0], params[1], params[2], params[3], params[4], True) + + def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer, + activation_op_name, with_bypass, + delay, use_resource): + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name( + scope + '/weights_quant/' + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + if use_resource: + expected_inputs = [ + scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp', + scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1', + ] + else: + expected_inputs = [ + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast' + ] + expected_inputs.append(scope + '/mul_fold') + + self._AssertInputOpsAre(weights_quant, expected_inputs) + if layer == 'DepthwiseConv2dNative': + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if delay else '/depthwise_Fold') + else: + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if delay else '/' + layer + '_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name( + scope + '/conv_quant/' + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + + if use_resource: + expected_inputs = [ + scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp', + scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1', + ] + else: + expected_inputs = [ + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', + ] + expected_inputs.append(scope + '/add_fold') + + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = ( + scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name( + 'test/act_quant/' + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + + if use_resource: + expected_inputs = [ + 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp', + 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1', + ] + else: + expected_inputs = [ + 'test/act_quant/AssignMinEma', + 'test/act_quant/AssignMaxEma', + ] + expected_inputs.append('test/' + activation_op_name) + + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + self._AssertIdempotent(graph) def testQuantize_Conv2dWithBatchNorm(self): self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm): + with_bypass, delay, fused_batch_norm, + use_resource): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -334,9 +398,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. + use_resource: Bool, when true uses resource variables. """ graph = ops.Graph() with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 @@ -353,7 +419,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): normalizer_params=self._BatchNormParams(fused_batch_norm), scope=scope) - # Manually add a bypass (optionaly) and an activation. + # Manually add a bypass (optional) and an activation. if with_bypass: node = math_ops.add(inputs, node, name='test/Add') @@ -364,52 +430,18 @@ class QuantizeTest(test_util.TensorFlowTestCase): array_ops.identity(node, name='control_dependency') fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize(graph, True, quant_delay=delay) - quantization_node_name = 'FakeQuantWithMinMaxVars' - weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + - quantization_node_name) - self.assertEqual(weights_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/weights_quant/' + 'AssignMinLast', - scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' - ] - self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay else '/Conv2D_Fold') - self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) - - if with_bypass: - conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + - quantization_node_name) - self.assertEqual(conv_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/conv_quant/AssignMinEma', - scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' - ] - self._AssertInputOpsAre(conv_quant, expected_inputs) - output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' - if delay else 'test/Add') - self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) - - act_quant = graph.get_operation_by_name('test/act_quant/' + - quantization_node_name) - self.assertEqual(act_quant.type, quantization_node_name) - expected_inputs = [ - 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', - 'test/' + activation_op_name - ] - self._AssertInputOpsAre(act_quant, expected_inputs) - output_op_name = ('test/act_quant/delayed_quant/Switch_1' - if delay else 'control_dependency') - self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + self._AssertCorrectQuantizedGraphWithBatchNorm( + graph, scope, 'Conv2D', activation_op_name, with_bypass, delay, + use_resource) def testQuantize_FCWithBatchNorm(self): self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm): + with_bypass, delay, fused_batch_norm, + use_resource): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -420,9 +452,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. + use_resource: Bool, when true uses resource variables. """ graph = ops.Graph() with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) batch_size, depth = 5, 256 inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 @@ -436,7 +470,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): normalizer_params=self._BatchNormParams(fused_batch_norm), scope=scope) - # Manually add a bypass (optionaly) and an activation. + # Manually add a bypass (optional) and an activation. if with_bypass: node = math_ops.add(inputs, node, name='test/Add') @@ -450,43 +484,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantize.Quantize(graph, True, quant_delay=delay) - quantization_node_name = 'FakeQuantWithMinMaxVars' - weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + - quantization_node_name) - self.assertEqual(weights_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/weights_quant/' + 'AssignMinLast', - scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' - ] - self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay else '/MatMul_Fold') - self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) - - if with_bypass: - conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + - quantization_node_name) - self.assertEqual(conv_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/conv_quant/AssignMinEma', - scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' - ] - self._AssertInputOpsAre(conv_quant, expected_inputs) - output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' - if delay else 'test/Add') - self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) - - act_quant = graph.get_operation_by_name('test/act_quant/' + - quantization_node_name) - self.assertEqual(act_quant.type, quantization_node_name) - expected_inputs = [ - 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', - 'test/' + activation_op_name - ] - self._AssertInputOpsAre(act_quant, expected_inputs) - output_op_name = ('test/act_quant/delayed_quant/Switch_1' - if delay else 'control_dependency') - self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + self._AssertCorrectQuantizedGraphWithBatchNorm( + graph, scope, 'MatMul', activation_op_name, with_bypass, delay, + use_resource) def testQuantize_DepthwiseConv2dWithBatchNorm(self): self._RunBatchNormTestOverParameters( @@ -494,7 +494,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): def _TestQuantize_DepthwiseConv2dWithBatchNorm( self, activation, activation_op_name, with_bypass, delay, - fused_batch_norm): + fused_batch_norm, use_resource): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -505,9 +505,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. + use_resource: Bool, when true uses resource variables. """ graph = ops.Graph() with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 @@ -524,7 +526,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): normalizer_params=self._BatchNormParams(fused_batch_norm), scope=scope) - # Manually add a bypass (optionaly) and an activation. + # Manually add a bypass (optional) and an activation. if with_bypass: node = math_ops.add(inputs, node, name='test/Add') @@ -535,45 +537,21 @@ class QuantizeTest(test_util.TensorFlowTestCase): array_ops.identity(node, name='control_dependency') fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize(graph, True, quant_delay=delay) - quantization_node_name = 'FakeQuantWithMinMaxVars' - weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + - quantization_node_name) - self.assertEqual(weights_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/weights_quant/' + 'AssignMinLast', - scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' - ] - self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay else '/depthwise_Fold') - self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) - if with_bypass: - conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + - quantization_node_name) - self.assertEqual(conv_quant.type, quantization_node_name) - expected_inputs = [ - scope + '/conv_quant/AssignMinEma', - scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' - ] - self._AssertInputOpsAre(conv_quant, expected_inputs) - output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' - if delay else 'test/Add') - self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + self._AssertCorrectQuantizedGraphWithBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, + with_bypass, delay, use_resource) - act_quant = graph.get_operation_by_name('test/act_quant/' + - quantization_node_name) - self.assertEqual(act_quant.type, quantization_node_name) - expected_inputs = [ - 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', - 'test/' + activation_op_name - ] - self._AssertInputOpsAre(act_quant, expected_inputs) - output_op_name = ('test/act_quant/delayed_quant/Switch_1' - if delay else 'control_dependency') - self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + def _AssertIdempotent(self, graph): + # Ensure that calling the rewrite again doesn't change the graph. + graph_def_before = str(graph.as_graph_def()) + with graph.as_default(): + # Ensuring that calling the rewrite again doesn't add more nodes. + fold_batch_norms.FoldBatchNorms(graph, is_training=True) + quantize.Quantize(graph, True) + graph_def_after = str(graph.as_graph_def()) + self.assertEqual(graph_def_before, graph_def_after) def _BatchNormParams(self, fused=False): return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} @@ -587,7 +565,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): stddev: Standard deviation of normal variable. Returns: - An initialized that initialzes with a truncated normal variable. + An initialized that initializes with a truncated normal variable. """ return init_ops.truncated_normal_initializer(stddev=stddev) diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index ef59475167137e203db2f6ca7f43c7b8f1938060..98f05c8bfc13094aff2839b2a6aa0da5c653da2b 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -135,6 +135,118 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertTrue('FakeQuantWithMinMaxVars' in [op.type for op in bias_add_op.outputs[0].consumers()]) + def testPostActivationBypassQuantized(self): + self._RunTestOverParameters(self._TestPostActivationBypassQuantized) + + def _TestPostActivationBypassQuantized(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=array_ops.identity, + scope='test/test') + bypass_tensor = math_ops.add(conv, input2, name='test/add') + _ = array_ops.identity(bypass_tensor, name='test/output') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that the bypass node is preceded and followed by + # FakeQuantWithMinMaxVars operations. + self.assertTrue('FakeQuantWithMinMaxVars' in + [c.type for c in bypass_tensor.consumers()]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [i.op.type for i in bypass_tensor.op.inputs]) + + def testOverlappingPostActivationBypassQuantized(self): + self._RunTestOverParameters( + self._TestOverlappingPostActivationBypassQuantized) + + def _TestOverlappingPostActivationBypassQuantized(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + conv_input = array_ops.zeros((batch_size, height, width, depth)) + conv1 = conv2d( + conv_input, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=array_ops.identity, + scope='test/test1') + + # The bypass of this conv is the post activation bypass of the previous + # conv. + conv2 = conv2d( + conv_input, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test2') + + bypass_tensor = math_ops.add(conv1, conv2, name='test/add') + _ = array_ops.identity(bypass_tensor, name='test/output') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + # Ensure that the bypass node is preceded and followed by + # FakeQuantWithMinMaxVars operations. + self.assertTrue('FakeQuantWithMinMaxVars' in + [c.type for c in bypass_tensor.consumers()]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [i.op.type for i in bypass_tensor.op.inputs]) + + # Ensure that all the convs and activations are quantized. + op_names = [op.name for op in graph.get_operations()] + self.assertTrue( + 'test/test1/weights_quant/FakeQuantWithMinMaxVars' in op_names) + self.assertTrue( + 'test/test2/weights_quant/FakeQuantWithMinMaxVars' in op_names) + self.assertTrue( + 'test/test1/act_quant/FakeQuantWithMinMaxVars' in op_names) + self.assertTrue('test/act_quant/FakeQuantWithMinMaxVars' in op_names) + self.assertEqual( + 'Identity', + graph.get_operation_by_name( + 'test/test1/act_quant/FakeQuantWithMinMaxVars').inputs[0].op.type) + self.assertEqual( + 'Identity', + graph.get_operation_by_name( + 'test/act_quant/FakeQuantWithMinMaxVars').inputs[0].op.type) + + def testWithNameScope(self): + self._RunTestOverParameters(self._TestWithNameScope) + + def _TestWithNameScope(self, is_training): + graph = ops.Graph() + with graph.as_default(): + with graph.name_scope('name_scope'): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + + for op in graph.get_operations(): + self.assertTrue(not op.name.startswith('name_scope/name_scope/'), + 'Broken op: %s' % op.name) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. @@ -144,7 +256,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): stddev: Standard deviation of normal variable. Returns: - An initialized that initialzes with a truncated normal variable. + An initialized that initializes with a truncated normal variable. """ return init_ops.truncated_normal_initializer(stddev=stddev) diff --git a/tensorflow/contrib/rnn/ops/gru_ops.cc b/tensorflow/contrib/rnn/ops/gru_ops.cc index e91d1e8a80ed252e5f89e116fb0a325be67e3941..9c8e40851a0cc5bd7f37f94a62ecdef7248660c1 100644 --- a/tensorflow/contrib/rnn/ops/gru_ops.cc +++ b/tensorflow/contrib/rnn/ops/gru_ops.cc @@ -69,7 +69,7 @@ Element-wise dot product of a and b is represented by ab Element-wise dot product is represented by \circ Matrix multiplication is represented by * -Baises are initialized with : +Biases are initialized with : `b_ru` - constant_initializer(1.0) `b_c` - constant_initializer(0.0) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 0e62b315b61cb3ceeb5cfd33bf5102a71abef83b..d41fc0b3ac1cee4eacc88cb0f41df1f9ee59e7c3 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -187,6 +187,8 @@ class RNNCellTest(test.TestCase): ], state_is_tuple=False) self.assertEqual(cell.dtype, None) + self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) + self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) g, out_m = cell(x, m) # Layer infers the input type. self.assertEqual(cell.dtype, dtype.name) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 57521c6a9ba0b2d66639017b09c541e270276323..de5df912921932056526e1e6dc5dbb905735f775 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -869,7 +869,7 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) @@ -934,8 +934,7 @@ class LSTMTest(test.TestCase): if in_graph_mode: self.assertAllEqual(outputs_static, outputs_dynamic) else: - self.assertAllEqual( - array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) @test_util.run_in_graph_and_eager_modes() @@ -946,7 +945,7 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) @@ -1022,10 +1021,9 @@ class LSTMTest(test.TestCase): if in_graph_mode: self.assertAllEqual(outputs_static, outputs_dynamic) else: - self.assertAllEqual( - array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) - state_static = [s.numpy() for s in nest.flatten(state_static)] - state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)] + self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) + state_static = nest.flatten(state_static) + state_dynamic = nest.flatten(state_dynamic) self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) def _testDynamicEquivalentToStaticRNN(self, use_sequence_length): @@ -1043,7 +1041,7 @@ class LSTMTest(test.TestCase): else: sequence_length = None - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() # TODO(b/68017812): Eager ignores operation seeds, so we need to create a # single cell and reuse it across the static and dynamic RNNs. Remove this diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 7957edf68cc8a1461fccfc2de93ad5250dc9fdb5..ffd24218944e150a32b1b915288ab1df90afb45c 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -54,7 +54,7 @@ def blocks_match(sess, use_peephole): initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212) with variable_scope.variable_scope("test", initializer=initializer): - # magic naming so that the cells pick up these variables and resuse them + # magic naming so that the cells pick up these variables and reuse them if use_peephole: wci = variable_scope.get_variable( "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 7b883ebc5d7756f1bdf445f900500a4b89e6cffd..63fdd91d368d97007280871f3886e5649e6b2e86 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -455,8 +455,8 @@ class RNNCellTest(test.TestCase): self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) def testAttentionCellWrapperFailures(self): - with self.assertRaisesRegexp(TypeError, - "The parameter cell is not RNNCell."): + with self.assertRaisesRegexp( + TypeError, rnn_cell_impl.ASSERT_LIKE_RNNCELL_ERROR_REGEXP): contrib_rnn_cell.AttentionCellWrapper(None, 0) num_units = 8 @@ -878,7 +878,6 @@ class RNNCellTest(test.TestCase): shape = [2, 1] filter_size = [3] num_features = 1 - batch_size = 2 expected_state_c = np.array( [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]], dtype=np.float32) @@ -912,7 +911,6 @@ class RNNCellTest(test.TestCase): shape = [2, 2, 1] filter_size = [3, 3] num_features = 1 - batch_size = 2 expected_state_c = np.array( [[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]], [[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]] @@ -954,7 +952,6 @@ class RNNCellTest(test.TestCase): shape = [2, 2, 2, 1] filter_size = [3, 3, 3] num_features = 1 - batch_size = 2 expected_state_c = np.array( [[[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]] ], [[[1.4375670191], [1.4375670191]], [[1.4375670191], @@ -1031,57 +1028,92 @@ class RNNCellTest(test.TestCase): num_units = 4 number_of_groups = 1 - with self.test_session() as sess: - with variable_scope.variable_scope( - "root1", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.ones([batch_size, num_units]) - # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM - gcell = contrib_rnn_cell.GLSTMCell( - num_units=num_units, number_of_groups=number_of_groups) - cell = rnn_cell.LSTMCell(num_units=num_units) - self.assertTrue(isinstance(gcell.state_size, tuple)) - zero_state = gcell.zero_state( - batch_size=batch_size, dtype=dtypes.float32) - gh, gs = gcell(x, zero_state) - h, g = cell(x, zero_state) + # Try with input dimension equal to num_units or not. + for num_inputs in [num_units, num_units + number_of_groups]: + with self.test_session() as sess: + with variable_scope.variable_scope( + "root1_%d" % num_inputs, + initializer=init_ops.constant_initializer(0.5)): + x = array_ops.ones([batch_size, num_inputs]) + # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM + gcell = contrib_rnn_cell.GLSTMCell( + num_units=num_units, number_of_groups=number_of_groups) + cell = rnn_cell.LSTMCell(num_units=num_units) + self.assertTrue(isinstance(gcell.state_size, tuple)) + zero_state = gcell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) + gh, gs = gcell(x, zero_state) + h, g = cell(x, zero_state) - sess.run([variables.global_variables_initializer()]) - glstm_result = sess.run([gh, gs]) - lstm_result = sess.run([h, g]) + sess.run([variables.global_variables_initializer()]) + glstm_result = sess.run([gh, gs]) + lstm_result = sess.run([h, g]) - self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5) - self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5) + self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5) + self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5) # Test that G-LSTM subgroup act like corresponding sub-LSTMs batch_size = 2 num_units = 4 number_of_groups = 2 - with self.test_session() as sess: + # Try with num_inputs equal to or not equal to num_units. + for num_inputs in [num_units, num_units + number_of_groups]: + with self.test_session() as sess: + with variable_scope.variable_scope( + "root2_%d" % num_inputs, + initializer=init_ops.constant_initializer(0.5)): + # input for G-LSTM with 2 groups + glstm_input = array_ops.ones([batch_size, num_inputs]) + gcell = contrib_rnn_cell.GLSTMCell( + num_units=num_units, number_of_groups=number_of_groups) + gcell_zero_state = gcell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) + gh, gs = gcell(glstm_input, gcell_zero_state) + + # input for LSTM cell simulating single G-LSTM group + lstm_input = array_ops.ones( + [batch_size, num_inputs / number_of_groups]) + # note division by number_of_groups. This cell one simulates G-LSTM + # group + cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups)) + cell_zero_state = cell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) + h, g = cell(lstm_input, cell_zero_state) + + sess.run([variables.global_variables_initializer()]) + [gh_res, h_res] = sess.run([gh, h]) + self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)], + h_res, 1e-5) + self.assertAllClose(gh_res[:, int(num_units / number_of_groups):], + h_res, 1e-5) + + def testGLSTMCellFailure(self): + batch_size = 2 + num_units = 4 + number_of_groups = 2 + with self.test_session(): with variable_scope.variable_scope( - "root2", initializer=init_ops.constant_initializer(0.5)): - # input for G-LSTM with 2 groups - glstm_input = array_ops.ones([batch_size, num_units]) + "glstm_failure", initializer=init_ops.constant_initializer(0.5)): gcell = contrib_rnn_cell.GLSTMCell( num_units=num_units, number_of_groups=number_of_groups) gcell_zero_state = gcell.zero_state( batch_size=batch_size, dtype=dtypes.float32) - gh, gs = gcell(glstm_input, gcell_zero_state) - # input for LSTM cell simulating single G-LSTM group - lstm_input = array_ops.ones([batch_size, num_units / number_of_groups]) - # note division by number_of_groups. This cell one simulates G-LSTM group - cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups)) - cell_zero_state = cell.zero_state( - batch_size=batch_size, dtype=dtypes.float32) - h, g = cell(lstm_input, cell_zero_state) + # Try an input with statically-unknown innermost dimension. + glstm_input = array_ops.placeholder( + dtypes.float32, shape=[batch_size, None]) + with self.assertRaisesRegexp(ValueError, + "input size must be statically known"): + gcell(glstm_input, gcell_zero_state) - sess.run([variables.global_variables_initializer()]) - [gh_res, h_res] = sess.run([gh, h]) - self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)], - h_res, 1e-5) - self.assertAllClose(gh_res[:, int(num_units / number_of_groups):], - h_res, 1e-5) + # Try an input whose innermost dimension isn't divisible into groups. + glstm_input = array_ops.placeholder( + dtypes.float32, shape=[batch_size, 3]) + with self.assertRaisesRegexp( + ValueError, + r"input size \(3\) must be divisible by number_of_groups \(2\)"): + gcell(glstm_input, gcell_zero_state) class LayerNormBasicLSTMCellTest(test.TestCase): @@ -1168,7 +1200,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): h1 = array_ops.zeros([1, 2]) state1 = rnn_cell.LSTMStateTuple(c1, h1) state = (state0, state1) - single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False) + single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False) # pylint: disable=line-too-long cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)]) g, out_m = cell(x, state) sess.run([variables.global_variables_initializer()]) @@ -1200,7 +1232,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5) with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)) as vs: + "other", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros( [1, 3]) # Test BasicLSTMCell with input_size != num_units. c = array_ops.zeros([1, 2]) @@ -1549,7 +1581,7 @@ class WeightNormLSTMCellTest(test.TestCase): """Compared cell output with pre-calculated values.""" def _cell_output(self, cell): - """Calculate cell output""" + """Calculates cell output.""" with self.test_session() as sess: init = init_ops.constant_initializer(0.5) @@ -1576,7 +1608,7 @@ class WeightNormLSTMCellTest(test.TestCase): return actual_state_c, actual_state_h def testBasicCell(self): - """Tests cell w/o peepholes and w/o normalisation""" + """Tests cell w/o peepholes and w/o normalisation.""" def cell(): return contrib_rnn_cell.WeightNormLSTMCell(2, @@ -1592,7 +1624,7 @@ class WeightNormLSTMCellTest(test.TestCase): self.assertAllClose(expected_h, actual_h, 1e-5) def testNonbasicCell(self): - """Tests cell with peepholes and w/o normalisation""" + """Tests cell with peepholes and w/o normalisation.""" def cell(): return contrib_rnn_cell.WeightNormLSTMCell(2, @@ -1607,9 +1639,8 @@ class WeightNormLSTMCellTest(test.TestCase): self.assertAllClose(expected_c, actual_c, 1e-5) self.assertAllClose(expected_h, actual_h, 1e-5) - def testBasicCellWithNorm(self): - """Tests cell w/o peepholes and with normalisation""" + """Tests cell w/o peepholes and with normalisation.""" def cell(): return contrib_rnn_cell.WeightNormLSTMCell(2, @@ -1625,7 +1656,7 @@ class WeightNormLSTMCellTest(test.TestCase): self.assertAllClose(expected_h, actual_h, 1e-5) def testNonBasicCellWithNorm(self): - """Tests cell with peepholes and with normalisation""" + """Tests cell with peepholes and with normalisation.""" def cell(): return contrib_rnn_cell.WeightNormLSTMCell(2, diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py index 8109ebc718353300f94536c5d7ae3332da584a1d..645f82624bf67b96ffc8520289b293b45f0e69e2 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py @@ -40,7 +40,6 @@ from tensorflow.python.util import nest # pylint: disable=protected-access,invalid-name RNNCell = rnn_cell_impl.RNNCell -_like_rnncell = rnn_cell_impl._like_rnncell _WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME _BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME # pylint: enable=protected-access,invalid-name @@ -221,8 +220,7 @@ class EmbeddingWrapper(RNNCell): ValueError: if embedding_classes is not positive. """ super(EmbeddingWrapper, self).__init__(_reuse=reuse) - if not _like_rnncell(cell): - raise TypeError("The parameter cell is not RNNCell.") + rnn_cell_impl.assert_like_rnncell("cell", cell) if embedding_classes <= 0 or embedding_size <= 0: raise ValueError("Both embedding_classes and embedding_size must be > 0: " "%d, %d." % (embedding_classes, embedding_size)) @@ -301,8 +299,7 @@ class InputProjectionWrapper(RNNCell): super(InputProjectionWrapper, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) - if not _like_rnncell(cell): - raise TypeError("The parameter cell is not RNNCell.") + rnn_cell_impl.assert_like_rnncell("cell", cell) self._cell = cell self._num_proj = num_proj self._activation = activation @@ -356,8 +353,7 @@ class OutputProjectionWrapper(RNNCell): ValueError: if output_size is not positive. """ super(OutputProjectionWrapper, self).__init__(_reuse=reuse) - if not _like_rnncell(cell): - raise TypeError("The parameter cell is not RNNCell.") + rnn_cell_impl.assert_like_rnncell("cell", cell) if output_size < 1: raise ValueError("Parameter output_size must be > 0: %d." % output_size) self._cell = cell diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 4eb4fbcd92f0d7cb3bee712862c8950a1971b632..9e61fc54d10c1b75786450060e428c73974760a7 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -480,8 +480,7 @@ class LSTMBlockWrapper(base_layer.Layer): """Run this LSTM on inputs, starting from the given state. Args: - inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]` - or a list of `time_len` tensors of shape `[batch_size, input_size]`. + inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. initial_state: a tuple `(initial_cell_state, initial_output)` with tensors of shape `[batch_size, self._num_units]`. If this is not provided, the cell is expected to create a zero initial state of type `dtype`. diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index a6c2d9cdbb2b6f61d59960f708000e945c6115e9..2f6ae9f3678e58dae67bf777991641b10e42ef94 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -534,7 +534,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): initializer: (optional) The initializer to use for the weight and projection matrices, default None. num_unit_shards: (optional) int, default 1, How to split the weight - matrix. If > 1,the weight matrix is stored across num_unit_shards. + matrix. If > 1, the weight matrix is stored across num_unit_shards. forget_bias: (optional) float, default 1.0, The initial bias of the forget gates, used to reduce the scale of forgetting at the beginning of the training. @@ -993,7 +993,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell): initializer: (optional) The initializer to use for the weight and projection matrices, default None. num_unit_shards: (optional) int, default 1, How to split the weight - matrix. If > 1,the weight matrix is stored across num_unit_shards. + matrix. If > 1, the weight matrix is stored across num_unit_shards. forget_bias: (optional) float, default 1.0, The initial bias of the forget gates, used to reduce the scale of forgetting at the beginning of the training. @@ -1143,8 +1143,7 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): `state_is_tuple` is `False` or if attn_length is zero or less. """ super(AttentionCellWrapper, self).__init__(_reuse=reuse) - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError("The parameter cell is not RNNCell.") + rnn_cell_impl.assert_like_rnncell("cell", cell) if nest.is_sequence(cell.state_size) and not state_is_tuple: raise ValueError( "Cell returns tuple of states, but the flag " @@ -2059,16 +2058,19 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): initializers=None, name="conv_lstm_cell"): """Construct ConvLSTMCell. + Args: conv_ndims: Convolution dimensionality (1, 2 or 3). input_shape: Shape of the input as int tuple, excluding the batch size. output_channels: int, number of output channels of the conv LSTM. kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). - use_bias: Use bias in convolutions. + use_bias: (bool) Use bias in convolutions. skip_connection: If set to `True`, concatenate the input to the - output of the conv LSTM. Default: `False`. + output of the conv LSTM. Default: `False`. forget_bias: Forget bias. + initializers: Unused. name: Name of the module. + Raises: ValueError: If `skip_connection` is `True` and stride is different from 1 or if `input_shape` is incompatible with `conv_ndims`. @@ -2131,7 +2133,7 @@ class Conv1DLSTMCell(ConvLSTMCell): def __init__(self, name="conv_1d_lstm_cell", **kwargs): """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.""" - super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs) + super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs) class Conv2DLSTMCell(ConvLSTMCell): @@ -2142,7 +2144,7 @@ class Conv2DLSTMCell(ConvLSTMCell): def __init__(self, name="conv_2d_lstm_cell", **kwargs): """Construct Conv2DLSTM. See `ConvLSTMCell` for more details.""" - super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs) + super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs) class Conv3DLSTMCell(ConvLSTMCell): @@ -2153,19 +2155,23 @@ class Conv3DLSTMCell(ConvLSTMCell): def __init__(self, name="conv_3d_lstm_cell", **kwargs): """Construct Conv3DLSTM. See `ConvLSTMCell` for more details.""" - super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs) + super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs) def _conv(args, filter_size, num_features, bias, bias_start=0.0): - """convolution: + """Convolution. + Args: args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, batch x n, Tensors. filter_size: int tuple of filter height and width. num_features: int, number of features. + bias: Whether to use biases in the convolution layer. bias_start: starting value to initialize the bias; 0 by default. + Returns: A 3D, 4D, or 5D Tensor with shape [batch ... num_features] + Raises: ValueError: if some of the arguments has unspecified or wrong shape. """ @@ -2225,6 +2231,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell): O. Kuchaiev and B. Ginsburg "Factorization Tricks for LSTM Networks", ICLR 2017 workshop. + + In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each + sub-cell operates on an evenly-sized sub-vector of the input and produces an + evenly-sized sub-vector of the output. For example, a G-LSTM cell with 128 + units and 4 groups consists of 4 LSTMs sub-cells with 32 units each. If that + G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part + of the input and produces a 32-dim part of the output. """ def __init__(self, @@ -2298,7 +2311,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell): return self._output_size def _get_input_for_group(self, inputs, group_id, group_size): - """Slices inputs into groups to prepare for processing by cell's groups + """Slices inputs into groups to prepare for processing by cell's groups. Args: inputs: cell input or it's previous state, @@ -2320,9 +2333,12 @@ class GLSTMCell(rnn_cell_impl.RNNCell): """Run one step of G-LSTM. Args: - inputs: input Tensor, 2D, [batch x num_units]. - state: this must be a tuple of state Tensors, both `2-D`, - with column sizes `c_state` and `m_state`. + inputs: input Tensor, 2D, [batch x num_inputs]. num_inputs must be + statically-known and evenly divisible into groups. The innermost + vectors of the inputs are split into evenly-sized sub-vectors and fed + into the per-group LSTM sub-cells. + state: this must be a tuple of state Tensors, both `2-D`, with column + sizes `c_state` and `m_state`. Returns: A tuple containing: @@ -2337,11 +2353,24 @@ class GLSTMCell(rnn_cell_impl.RNNCell): Raises: ValueError: If input size cannot be inferred from inputs via - static shape inference. + static shape inference, or if the input shape is incompatible + with the number of groups. """ (c_prev, m_prev) = state self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] + + # If the input size is statically-known, calculate and validate its group + # size. Otherwise, use the output group size. + input_size = inputs.shape[1].value + if input_size is None: + raise ValueError("input size must be statically known") + if input_size % self._number_of_groups != 0: + raise ValueError( + "input size (%d) must be divisible by number_of_groups (%d)" % + (input_size, self._number_of_groups)) + input_group_size = int(input_size / self._number_of_groups) + dtype = inputs.dtype scope = vs.get_variable_scope() with vs.variable_scope(scope, initializer=self._initializer): @@ -2354,8 +2383,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell): with vs.variable_scope("group%d" % group_id): x_g_id = array_ops.concat( [ - self._get_input_for_group(inputs, group_id, - self._group_shape[0]), + self._get_input_for_group(inputs, group_id, input_group_size), self._get_input_for_group(m_prev, group_id, self._group_shape[0]) ], @@ -2684,7 +2712,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell): class SRUCell(rnn_cell_impl.LayerRNNCell): - """SRU, Simple Recurrent Unit + """SRU, Simple Recurrent Unit. Implementation based on Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755). @@ -2732,12 +2760,13 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): input_depth = inputs_shape[1].value + # pylint: disable=protected-access self._kernel = self.add_variable( rnn_cell_impl._WEIGHTS_VARIABLE_NAME, shape=[input_depth, 4 * self._num_units]) - + # pylint: enable=protected-access self._bias = self.add_variable( - rnn_cell_impl._BIAS_VARIABLE_NAME, + rnn_cell_impl._BIAS_VARIABLE_NAME, # pylint: disable=protected-access shape=[2 * self._num_units], initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) @@ -2746,7 +2775,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): def call(self, inputs, state): """Simple recurrent unit (SRU) with num_units cells.""" - U = math_ops.matmul(inputs, self._kernel) + U = math_ops.matmul(inputs, self._kernel) # pylint: disable=invalid-name x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split( value=U, num_or_size_splits=4, axis=1) @@ -2876,6 +2905,7 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell): Args: args: a 2D Tensor or a list of 2D, batch x n, Tensors. output_size: int, second dimension of W[i]. + norm: bool, whether to normalize the weights. bias: boolean, whether to add a bias term or not. bias_initializer: starting value to initialize the bias (default is all zeros). diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index dfa12e873a6aca806031c48d6f92e0432d0ea6e0..a9a32b7b25d6767cc1f944640722e128a9d728b5 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -74,7 +74,7 @@ class GatherTreeOp : public OpKernel { ctx, step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0), errors::InvalidArgument("batch size dimensions step_ids.shape[1] and " - "max_seqeuence_lengths.shape[0] must match. " + "max_sequence_lengths.shape[0] must match. " "but shapes are: ", step_ids_shape.DebugString(), " and ", max_sequence_lengths.shape().DebugString())); diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index b427dff88b2d586ccf8c512bb498cdaf879ac781..07b3ad71d4698b990fc5fbb1dc30fc787872d495 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -222,6 +222,9 @@ class AttentionWrapperTest(test.TestCase): self.assertEqual( (None, batch_size, None), tuple(state_alignment_history.get_shape().as_list())) + nest.assert_same_structure( + cell.state_size, + cell.zero_state(batch_size, dtypes.float32)) # Remove the history from final_state for purposes of the # remainder of the tests. final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access @@ -782,26 +785,31 @@ class AttentionWrapperTest(test.TestCase): wrapper.BahdanauAttention, wrapper.LuongAttention) expected_final_output = BasicDecoderOutput( - rnn_output=ResultSummary( - shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604), - sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334)) + rnn_output=ResultSummary(shape=(5, 3, 20), + dtype=dtype('float32'), + mean=0.11723966), + sample_id=ResultSummary(shape=(5, 3), + dtype=dtype('int32'), + mean=9.2666666666666675)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=ResultSummary( - shape=(5, 9), dtype=dtype('float32'), mean=-0.0036486709), - h=ResultSummary( - shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)), - attention=ResultSummary( - shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604), + c=ResultSummary(shape=(5, 9), + dtype=dtype('float32'), + mean=-0.003545674), + h=ResultSummary(shape=(5, 9), + dtype=dtype('float32'), + mean=-0.0018327223)), + attention=ResultSummary(shape=(5, 20), + dtype=dtype('float32'), + mean=0.11728073), time=3, alignments=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), + alignment_history=(), attention_state=( ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125), - ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)), - alignment_history=()) + ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125))) expected_final_alignment_history = ( ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125), ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125)) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 926554031775202d7f7d9018cf6ae4efb34fe96b..178328619f087789df040489cd150ba018cc8d14 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -27,6 +27,7 @@ from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops @@ -70,6 +71,98 @@ class TestGatherTree(test.TestCase): self.assertAllEqual(expected_result, res_) + def _test_gather_tree_from_array(self, + depth_ndims=0, + merged_batch_beam=False): + array = np.array( + [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0]], + [[2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 0]]]).transpose([1, 0, 2]) + parent_ids = np.array( + [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]], + [[0, 0, 0], [1, 1, 0], [2, 0, 1], [0, 1, 0]]]).transpose([1, 0, 2]) + expected_array = np.array( + [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [0, 0, 0]], + [[2, 3, 2], [7, 5, 7], [8, 9, 8], [11, 12, 0]]]).transpose([1, 0, 2]) + sequence_length = [[3, 3, 3], [4, 4, 3]] + + array = ops.convert_to_tensor( + array, dtype=dtypes.float32) + parent_ids = ops.convert_to_tensor( + parent_ids, dtype=dtypes.int32) + expected_array = ops.convert_to_tensor( + expected_array, dtype=dtypes.float32) + + max_time = array_ops.shape(array)[0] + batch_size = array_ops.shape(array)[1] + beam_width = array_ops.shape(array)[2] + + def _tile_in_depth(tensor): + # Generate higher rank tensors by concatenating tensor and tensor + 1. + for _ in range(depth_ndims): + tensor = array_ops.stack([tensor, tensor + 1], -1) + return tensor + + if merged_batch_beam: + array = array_ops.reshape( + array, [max_time, batch_size * beam_width]) + expected_array = array_ops.reshape( + expected_array, [max_time, batch_size * beam_width]) + + if depth_ndims > 0: + array = _tile_in_depth(array) + expected_array = _tile_in_depth(expected_array) + + sorted_array = beam_search_decoder.gather_tree_from_array( + array, parent_ids, sequence_length) + + with self.test_session() as sess: + sorted_array = sess.run(sorted_array) + expected_array = sess.run(expected_array) + self.assertAllEqual(expected_array, sorted_array) + + def test_gather_tree_from_array_scalar(self): + self._test_gather_tree_from_array() + + def test_gather_tree_from_array_1d(self): + self._test_gather_tree_from_array(depth_ndims=1) + + def test_gather_tree_from_array_1d_with_merged_batch_beam(self): + self._test_gather_tree_from_array(depth_ndims=1, merged_batch_beam=True) + + def test_gather_tree_from_array_2d(self): + self._test_gather_tree_from_array(depth_ndims=2) + + +class TestArrayShapeChecks(test.TestCase): + + def _test_array_shape_dynamic_checks(self, static_shape, dynamic_shape, + batch_size, beam_width, is_valid=True): + t = array_ops.placeholder_with_default( + np.random.randn(*static_shape).astype(np.float32), + shape=dynamic_shape) + + batch_size = array_ops.constant(batch_size) + check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access + + with self.test_session() as sess: + if is_valid: + sess.run(check_op) + else: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(check_op) + + def test_array_shape_dynamic_checks(self): + self._test_array_shape_dynamic_checks( + (8, 4, 5, 10), (None, None, 5, 10), 4, 5, is_valid=True) + self._test_array_shape_dynamic_checks( + (8, 20, 10), (None, None, 10), 4, 5, is_valid=True) + self._test_array_shape_dynamic_checks( + (8, 21, 10), (None, None, 10), 4, 5, is_valid=False) + self._test_array_shape_dynamic_checks( + (8, 4, 6, 10), (None, None, None, 10), 4, 5, is_valid=False) + self._test_array_shape_dynamic_checks( + (8, 4), (None, None), 4, 5, is_valid=False) + class TestEosMasking(test.TestCase): """Tests EOS masking used in beam search.""" @@ -319,7 +412,8 @@ class TestLargeBeamStep(test.TestCase): class BeamSearchDecoderTest(test.TestCase): - def _testDynamicDecodeRNN(self, time_major, has_attention): + def _testDynamicDecodeRNN(self, time_major, has_attention, + with_alignment_history=False): encoder_sequence_length = np.array([3, 2, 3, 1, 1]) decoder_sequence_length = np.array([2, 0, 1, 2, 3]) batch_size = 5 @@ -359,7 +453,7 @@ class BeamSearchDecoderTest(test.TestCase): cell=cell, attention_mechanism=attention_mechanism, attention_layer_size=attention_depth, - alignment_history=False) + alignment_history=with_alignment_history) cell_state = cell.zero_state( dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) if has_attention: @@ -420,6 +514,12 @@ class BeamSearchDecoderTest(test.TestCase): def testDynamicDecodeRNNBatchMajorYesAttention(self): self._testDynamicDecodeRNN(time_major=False, has_attention=True) + def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self): + self._testDynamicDecodeRNN( + time_major=False, + has_attention=True, + with_alignment_history=True) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 0a53fd66dbe4d28ea102773b9c5bae50b9d18e9c..be537798268b7938bb68e7d96ae2a1d51685433f 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -736,7 +736,7 @@ class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism): """Base attention mechanism for monotonic attention. Simply overrides the initial_alignments function to provide a dirac - distribution,which is needed in order for the monotonic attention + distribution, which is needed in order for the monotonic attention distributions to have the correct behavior. """ @@ -763,7 +763,7 @@ class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism): class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Bahadanau-style energy function. - This type of attention encorces a monotonic constraint on the attention + This type of attention enforces a monotonic constraint on the attention distributions; that is once the model attends to a given point in the memory it can't attend to any prior points at subsequence output timesteps. It achieves this by using the _monotonic_probability_fn instead of softmax to @@ -867,7 +867,7 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Luong-style energy function. - This type of attention encorces a monotonic constraint on the attention + This type of attention enforces a monotonic constraint on the attention distributions; that is once the model attends to a given point in the memory it can't attend to any prior points at subsequence output timesteps. It achieves this by using the _monotonic_probability_fn instead of softmax to @@ -1133,7 +1133,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): output_attention: Python bool. If `True` (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If `False`, the output at each time step is - the output of `cell`. This is the beahvior of Bhadanau-style + the output of `cell`. This is the behavior of Bhadanau-style attention mechanisms. In both cases, the `attention` tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated @@ -1152,9 +1152,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): is a list, and its length does not match that of `attention_layer_size`. """ super(AttentionWrapper, self).__init__(name=name) - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError( - "cell must be an RNNCell, saw type: %s" % type(cell).__name__) + rnn_cell_impl.assert_like_rnncell("cell", cell) if isinstance(attention_mechanism, (list, tuple)): self._is_multi = True attention_mechanisms = attention_mechanism @@ -1280,7 +1278,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): attention_state=self._item_or_tuple( a.state_size for a in self._attention_mechanisms), alignment_history=self._item_or_tuple( - () for _ in self._attention_mechanisms)) # sometimes a TensorArray + a.alignments_size if self._alignment_history else () + for a in self._attention_mechanisms)) # sometimes a TensorArray def zero_state(self, batch_size, dtype): """Return an initial (zero) state tuple for this `AttentionWrapper`. @@ -1320,22 +1319,26 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): cell_state = nest.map_structure( lambda s: array_ops.identity(s, name="checked_cell_state"), cell_state) + initial_alignments = [ + attention_mechanism.initial_alignments(batch_size, dtype) + for attention_mechanism in self._attention_mechanisms] return AttentionWrapperState( cell_state=cell_state, time=array_ops.zeros([], dtype=dtypes.int32), attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype), - alignments=self._item_or_tuple( - attention_mechanism.initial_alignments(batch_size, dtype) - for attention_mechanism in self._attention_mechanisms), + alignments=self._item_or_tuple(initial_alignments), attention_state=self._item_or_tuple( attention_mechanism.initial_state(batch_size, dtype) for attention_mechanism in self._attention_mechanisms), alignment_history=self._item_or_tuple( - tensor_array_ops.TensorArray(dtype=dtype, size=0, - dynamic_size=True) + tensor_array_ops.TensorArray( + dtype, + size=0, + dynamic_size=True, + element_shape=alignment.shape) if self._alignment_history else () - for _ in self._attention_mechanisms)) + for alignment in initial_alignments)) def call(self, inputs, state): """Perform a step of attention-wrapped RNN. diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index ed226239b860e2250072a28a5538b816642ec54b..7eb95e5a70de985dca0d4b565ba03bdf454b6161 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -59,8 +59,7 @@ class BasicDecoder(decoder.Decoder): Raises: TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. """ - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + rnn_cell_impl.assert_like_rnncell("cell", cell) if not isinstance(helper, helper_py.Helper): raise TypeError("helper must be a Helper, received: %s" % type(helper)) if (output_layer is not None diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 554eb24e5260724a905b099091bf8aea461554cf..184144f64a56358206014a0f75473b4a9b16617a 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import tf_logging from tensorflow.python.util import nest __all__ = [ @@ -121,14 +122,114 @@ def tile_batch(t, multiplier, name=None): return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) +def gather_tree_from_array(t, parent_ids, sequence_length): + """Calculates the full beams for `TensorArray`s. + + Args: + t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of + shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` + where `s` is the depth shape. + parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`. + sequence_length: The sequence length of shape `[batch_size, beam_width]`. + + Returns: + A `Tensor` which is a stacked `TensorArray` of the same size and type as + `t` and where beams are sorted in each `Tensor` according to `parent_ids`. + """ + max_time = parent_ids.shape[0].value or array_ops.shape(parent_ids)[0] + batch_size = parent_ids.shape[1].value or array_ops.shape(parent_ids)[1] + beam_width = parent_ids.shape[2].value or array_ops.shape(parent_ids)[2] + + # Generate beam ids that will be reordered by gather_tree. + beam_ids = array_ops.expand_dims( + array_ops.expand_dims(math_ops.range(beam_width), 0), 0) + beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) + + mask = array_ops.sequence_mask( + sequence_length, maxlen=max_time, dtype=dtypes.int32) + mask = array_ops.transpose(mask, perm=[2, 0, 1]) + + # Use beam_width + 1 to mark the end of beam. + masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1) + + max_sequence_lengths = math_ops.to_int32( + math_ops.reduce_max(sequence_length, axis=1)) + sorted_beam_ids = beam_search_ops.gather_tree( + step_ids=masked_beam_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=beam_width + 1) + + # For out of range steps, simply copy the same beam. + sorted_beam_ids = array_ops.where( + math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids) + + # Generate indices for gather_nd. + time_ind = array_ops.tile(array_ops.reshape( + math_ops.range(max_time), [-1, 1, 1]), [1, batch_size, beam_width]) + batch_ind = array_ops.tile(array_ops.reshape( + math_ops.range(batch_size), [-1, 1, 1]), [1, max_time, beam_width]) + batch_ind = array_ops.transpose(batch_ind, perm=[1, 0, 2]) + indices = array_ops.stack([time_ind, batch_ind, sorted_beam_ids], -1) + + # Gather from a tensor with collapsed additional dimensions. + gather_from = t + final_shape = array_ops.shape(gather_from) + gather_from = array_ops.reshape( + gather_from, [max_time, batch_size, beam_width, -1]) + ordered = array_ops.gather_nd(gather_from, indices) + ordered = array_ops.reshape(ordered, final_shape) + + return ordered + + def _check_maybe(t): - if isinstance(t, tensor_array_ops.TensorArray): - raise TypeError( - "TensorArray state is not supported by BeamSearchDecoder: %s" % t.name) if t.shape.ndims is None: raise ValueError( "Expected tensor (%s) to have known rank, but ndims == None." % t) +def _check_static_batch_beam_maybe(shape, batch_size, beam_width): + """Raises an exception if dimensions are known statically and can not be + reshaped to [batch_size, beam_size, -1]. + """ + reshaped_shape = tensor_shape.TensorShape([batch_size, beam_width, None]) + if (batch_size is not None and shape[0].value is not None + and (shape[0] != batch_size * beam_width + or (shape.ndims >= 2 and shape[1].value is not None + and (shape[0] != batch_size or shape[1] != beam_width)))): + tf_logging.warn("TensorArray reordering expects elements to be " + "reshapable to %s which is incompatible with the " + "current shape %s. Consider setting " + "reorder_tensor_arrays to False to disable TensorArray " + "reordering during the beam search." + % (reshaped_shape, shape)) + return False + return True + +def _check_batch_beam(t, batch_size, beam_width): + """Returns an Assert operation checking that the elements of the stacked + TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point, + the TensorArray elements have a known rank of at least 1. + """ + error_message = ("TensorArray reordering expects elements to be " + "reshapable to [batch_size, beam_size, -1] which is " + "incompatible with the dynamic shape of %s elements. " + "Consider setting reorder_tensor_arrays to False to disable " + "TensorArray reordering during the beam search." + % (t.name)) + rank = t.shape.ndims + shape = array_ops.shape(t) + if rank == 2: + condition = math_ops.equal(shape[1], batch_size * beam_width) + else: + condition = math_ops.logical_or( + math_ops.equal(shape[1], batch_size * beam_width), + math_ops.logical_and( + math_ops.equal(shape[1], batch_size), + math_ops.equal(shape[2], beam_width))) + return control_flow_ops.Assert(condition, [error_message]) + + class BeamSearchDecoder(decoder.Decoder): """BeamSearch sampling decoder. @@ -173,7 +274,8 @@ class BeamSearchDecoder(decoder.Decoder): initial_state, beam_width, output_layer=None, - length_penalty_weight=0.0): + length_penalty_weight=0.0, + reorder_tensor_arrays=True): """Initialize the BeamSearchDecoder. Args: @@ -188,6 +290,12 @@ class BeamSearchDecoder(decoder.Decoder): `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. Raises: TypeError: if `cell` is not an instance of `RNNCell`, @@ -195,14 +303,14 @@ class BeamSearchDecoder(decoder.Decoder): ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ - if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer + self._reorder_tensor_arrays = reorder_tensor_arrays if callable(embedding): self._embedding_fn = embedding @@ -300,12 +408,13 @@ class BeamSearchDecoder(decoder.Decoder): """ finished, start_inputs = self._finished, self._start_inputs + dtype = nest.flatten(self._initial_cell_state)[0].dtype log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) array_ops.zeros([self._batch_size], dtype=dtypes.int32), depth=self._beam_width, - on_value=0.0, - off_value=-np.Inf, - dtype=nest.flatten(self._initial_cell_state)[0].dtype) + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) initial_state = BeamSearchDecoderState( cell_state=self._initial_cell_state, @@ -342,6 +451,11 @@ class BeamSearchDecoder(decoder.Decoder): outputs.parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=self._end_token) + if self._reorder_tensor_arrays: + final_state = final_state._replace(cell_state=nest.map_structure( + lambda t: self._maybe_sort_array_beams( + t, outputs.parent_ids, final_state.lengths), + final_state.cell_state)) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state @@ -432,9 +546,10 @@ class BeamSearchDecoder(decoder.Decoder): returned unchanged. Raises: - TypeError: If `t` is an instance of `TensorArray`. ValueError: If the rank of `t` is not statically known. """ + if isinstance(t, tensor_array_ops.TensorArray): + return t _check_maybe(t) if t.shape.ndims >= 1: return self._split_batch_beams(t, s) @@ -455,15 +570,55 @@ class BeamSearchDecoder(decoder.Decoder): A reshaped version of t with shape `[batch_size, beam_width] + s`. Raises: - TypeError: If `t` is an instance of `TensorArray`. ValueError: If the rank of `t` is not statically known. """ + if isinstance(t, tensor_array_ops.TensorArray): + return t _check_maybe(t) if t.shape.ndims >= 2: return self._merge_batch_beams(t, s) else: return t + def _maybe_sort_array_beams(self, t, parent_ids, sequence_length): + """Maybe sorts beams within a `TensorArray`. + + Args: + t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape + `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where + `s` is the depth shape. + parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`. + sequence_length: The sequence length of shape `[batch_size, beam_width]`. + + Returns: + A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if + it is not a `TensorArray` or does not meet shape requirements. + """ + if not isinstance(t, tensor_array_ops.TensorArray): + return t + # pylint: disable=protected-access + if (not t._infer_shape or not t._element_shape + or t._element_shape[0].ndims is None + or t._element_shape[0].ndims < 1): + shape = ( + t._element_shape[0] if t._infer_shape and t._element_shape + else tensor_shape.TensorShape(None)) + tf_logging.warn("The TensorArray %s in the cell state is not amenable to " + "sorting based on the beam search result. For a " + "TensorArray to be sorted, its elements shape must be " + "defined and have at least a rank of 1, but saw shape: %s" + % (t.handle.name, shape)) + return t + shape = t._element_shape[0] + # pylint: enable=protected-access + if not _check_static_batch_beam_maybe( + shape, tensor_util.constant_value(self._batch_size), self._beam_width): + return t + t = t.stack() + with ops.control_dependencies( + [_check_batch_beam(t, self._batch_size, self._beam_width)]): + return gather_tree_from_array(t, parent_ids, sequence_length) + def step(self, time, inputs, state, name=None): """Perform a decoding step. @@ -570,7 +725,6 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam - scores_shape = array_ops.shape(scores) scores_flat = array_ops.reshape(scores, [batch_size, -1]) # Pick the next beams according to the specified successors function @@ -667,9 +821,9 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): Returns: The scores normalized by the length_penalty. """ - length_penality_ = _length_penalty( + length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) - return log_probs / length_penality_ + return log_probs / length_penalty_ def _length_penalty(sequence_lengths, penalty_factor): @@ -706,7 +860,7 @@ def _mask_probs(probs, eos_token, finished): unfinished beams remain unchanged. Args: - probs: Log probabiltiies of shape `[batch_size, beam_width, vocab_size]` + probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]` eos_token: An int32 id corresponding to the EOS token to allocate probability to. finished: A boolean tensor of shape `[batch_size, beam_width]` that @@ -759,6 +913,8 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)] or the original tensor if its dimensions are too small. """ + if isinstance(gather_from, tensor_array_ops.TensorArray): + return gather_from _check_maybe(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index f14974b9d5ca8cbcfd9f91086ca0a90ceff48f43..898493662d7594f9996400a9636378db3c6b4cd1 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest @@ -39,6 +40,7 @@ __all__ = ["Decoder", "dynamic_decode"] _transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access +_zero_state_tensors = rnn_cell_impl._zero_state_tensors # pylint: disable=protected-access @six.add_metaclass(abc.ABCMeta) @@ -133,16 +135,8 @@ class Decoder(object): def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" - def _t(s): - return (s if isinstance(s, ops.Tensor) else constant_op.constant( - tensor_shape.TensorShape(s).as_list(), - dtype=dtypes.int32, - name="zero_suffix_shape")) - def _create(s, d): - return array_ops.zeros( - array_ops.concat( - ([batch_size], _t(s)), axis=0), dtype=d) + return _zero_state_tensors(s, batch_size, d) return nest.map_structure(_create, size, dtype) @@ -212,7 +206,8 @@ def dynamic_decode(decoder, initial_time = constant_op.constant(0, dtype=dtypes.int32) def _shape(batch_size, from_shape): - if not isinstance(from_shape, tensor_shape.TensorShape): + if (not isinstance(from_shape, tensor_shape.TensorShape) or + from_shape.ndims == 0): return tensor_shape.TensorShape(None) else: batch_size = tensor_util.constant_value( diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 67011c8fef6c4f54db2626ffe7ae1299bddbb352..75a753ed89a5ea13b7b79f480511979c38f321e3 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -1,9 +1,7 @@ # Description: # TensorFlow Serving session bundle. -package( - default_visibility = ["//visibility:public"], -) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index 2d9df8f27ee98431f51fd39c168325b8f625dce9..40f484fd78302163ba36142dec057478fe899189 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -94,7 +94,7 @@ of thin wrapper functions in [variables.py](https://www.tensorflow.org/code/tensorflow/contrib/framework/python/ops/variables.py) which allow callers to easily define variables. -For example, to create a `weight` variable, initialize it using a truncated +For example, to create a `weights` variable, initialize it using a truncated normal distribution, regularize it with an `l2_loss` and place it on the `CPU`, one need only declare the following: diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py index b3343aef47d9f352c3bcbef4afbe8f9bf2560e6d..99ad48763031cc2f98009449cea050fd90d01eb5 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py @@ -115,8 +115,8 @@ class ParallelReader(io_ops.ReaderBase): reader needs to start reading from a new file since it has finished with the previous file). - A queue runner for enqueing in the `common_queue` is automatically added to - the TF QueueRunners collection. + A queue runner for enqueuing in the `common_queue` is automatically added + to the TF QueueRunners collection. Args: queue: A Queue or a mutable string Tensor representing a handle diff --git a/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py b/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py index 37e9c4754ca62fc02f9146632943a50c33f9423d..62bd20036126b41040ca4329c7f13ea7671a8045 100644 --- a/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py +++ b/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py @@ -36,9 +36,9 @@ def prefetch_queue(tensors, dynamic_pad=False, shared_name=None, name=None): - """Creates a queue to prefetech tensors from `tensors`. + """Creates a queue to prefetch tensors from `tensors`. - A queue runner for enqueing tensors into the prefetch_queue is automatically + A queue runner for enqueuing tensors into the prefetch_queue is automatically added to the TF QueueRunners collection. Example: diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index b3b61e1dfe5671a7fbbee20b0c577ee5fad0fb9b..f2d31dc8db5688dc9a3308267109214277436040 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -124,7 +124,7 @@ class BoundingBox(ItemHandler): super(BoundingBox, self).__init__(self._full_keys) def tensors_to_item(self, keys_to_tensors): - """Maps the given dictionary of tensors to a contatenated list of bboxes. + """Maps the given dictionary of tensors to a concatenated list of bboxes. Args: keys_to_tensors: a mapping of TF-Example keys to parsed tensors. diff --git a/tensorflow/contrib/solvers/python/ops/least_squares.py b/tensorflow/contrib/solvers/python/ops/least_squares.py index fb7c0eb649c5216736b239d1a423cdaf7079f582..6e164f53420675d149ded6c1f42ca87bd89b158c 100644 --- a/tensorflow/contrib/solvers/python/ops/least_squares.py +++ b/tensorflow/contrib/solvers/python/ops/least_squares.py @@ -33,7 +33,7 @@ def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"): r"""Conjugate gradient least squares solver. Solves a linear least squares problem \\(||A x - rhs||_2\\) for a single - righ-hand side, using an iterative, matrix-free algorithm where the action of + right-hand side, using an iterative, matrix-free algorithm where the action of the matrix A is represented by `operator`. The CGLS algorithm implicitly applies the symmetric conjugate gradient algorithm to the normal equations \\(A^* A x = A^* rhs\\). The iteration terminates when either diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index d791d467639b572e7831c1d1a582aa15585649b6..9305c6a11c4ec898c82553773e8e7277a54ab82e 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -41,7 +41,7 @@ def conjugate_gradient(operator, r"""Conjugate gradient solver. Solves a linear system of equations `A*x = rhs` for selfadjoint, positive - definite matrix `A` and righ-hand side vector `rhs`, using an iterative, + definite matrix `A` and right-hand side vector `rhs`, using an iterative, matrix-free algorithm where the action of the matrix A is represented by `operator`. The iteration terminates when either the number of iterations exceeds `max_iter` or when the residual norm has been reduced to `tol` diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index b58c83fdaf574fb349fac57c922f1178b7d13b66..80563c5e150dfb74ef11bc912e95345a1a015212 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -10,12 +10,6 @@ load( "tf_gen_op_wrapper_py", ) -tf_gen_op_wrapper_py( - name = "gen_summary_ops", - out = "gen_summary_ops.py", - deps = ["//tensorflow/core:summary_ops_op_lib"], -) - py_test( name = "summary_ops_test", srcs = ["summary_ops_test.py"], @@ -61,7 +55,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - ":gen_summary_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -72,6 +65,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:summary_op_util", + "//tensorflow/python:summary_ops_gen", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index b6249fc92f712b21197c2167fb5d1c4af1f48ca5..bc763fe655edc455e2538e536d6efab314c8228c 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -26,7 +26,6 @@ import time import six -from tensorflow.contrib.summary import gen_summary_ops from tensorflow.core.framework import graph_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -35,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_summary_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import summary_op_util @@ -110,7 +110,7 @@ class SummaryWriter(object): def __init__(self, resource): self._resource = resource - if context.in_eager_mode() and self._resource is not None: + if context.executing_eagerly() and self._resource is not None: self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") @@ -158,7 +158,7 @@ def initialize( @{tf.contrib.summary.SummaryWriter}. ValueError: If session wasn't passed and no default session. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return if context.context().summary_writer_resource is None: raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") @@ -269,7 +269,7 @@ def _make_summary_writer(name, factory, **kwargs): resource = gen_summary_ops.summary_writer(shared_name=name) # TODO(apassos): Consider doing this instead. # node = factory(resource, **kwargs) - # if not context.in_eager_mode(): + # if not context.executing_eagerly(): # ops.get_default_session().run(node) ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, factory(resource, **kwargs)) @@ -295,7 +295,7 @@ def all_summary_ops(): Returns: The summary ops. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return None return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access @@ -309,7 +309,7 @@ def summary_writer_initializer_op(): Raises: RuntimeError: If in Eager mode. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "tf.contrib.summary.summary_writer_initializer_op is only " "supported in graph mode.") @@ -328,8 +328,12 @@ def summary_writer_function(name, tensor, function, family=None): Returns: The result of writing the summary. """ + name_scope = ops.get_name_scope() + if name_scope: + # Add a slash to allow reentering the name scope. + name_scope += "/" def record(): - with summary_op_util.summary_scope( + with ops.name_scope(name_scope), summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): with ops.control_dependencies([function(tag, scope)]): return constant_op.constant(True) @@ -477,7 +481,7 @@ def graph(param, step=None, name=None): Raises: TypeError: If `param` isn't already a @{tf.Tensor} in graph mode. """ - if not context.in_eager_mode() and not isinstance(param, ops.Tensor): + if not context.executing_eagerly() and not isinstance(param, ops.Tensor): raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph " "mode, but was: %s" % type(param)) writer = context.context().summary_writer_resource diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index 2b7806f80d020e0064b0f5cf32fd765a9ee993d1..3aba04540eba12092d884cca10e23546eb91c91d 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -85,6 +85,38 @@ class DbTest(summary_test_util.SummaryDbTest): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'my_scalar') + def testScalarSummaryNameScope(self): + """Test record_summaries_every_n_global_steps and all_summaries().""" + with ops.Graph().as_default(), self.test_session() as sess: + global_step = training_util.get_or_create_global_step() + global_step.initializer.run() + with ops.device('/cpu:0'): + step_increment = state_ops.assign_add(global_step, 1) + sess.run(step_increment) # Increment global step from 0 to 1 + + logdir = tempfile.mkdtemp() + with summary_ops.create_file_writer(logdir, max_queue=0, + name='t2').as_default(): + with summary_ops.record_summaries_every_n_global_steps(2): + summary_ops.initialize() + with ops.name_scope('scope'): + summary_op = summary_ops.scalar('my_scalar', 2.0) + + # Neither of these should produce a summary because + # global_step is 1 and "1 % 2 != 0" + sess.run(summary_ops.all_summary_ops()) + sess.run(summary_op) + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 1) + + # Increment global step from 1 to 2 and check that the summary + # is now written + sess.run(step_increment) + sess.run(summary_ops.all_summary_ops()) + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scope/my_scalar') + def testSummaryGraphModeCond(self): with ops.Graph().as_default(), self.test_session(): training_util.get_or_create_global_step() diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index bb7215f879411e91a1c47b87f5caede63fffea74..c756f8b27055f9cf86a311e485d97745a3c7a95b 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.core.framework import types_pb2 from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops @@ -107,6 +108,20 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + def testSummaryNameScope(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + + with ops.name_scope('scope'): + summary_ops.scalar('scalar', 2.0) + + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scope/scalar') + def testSummaryGlobalStep(self): step = training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() diff --git a/tensorflow/contrib/tensor_forest/README.md b/tensorflow/contrib/tensor_forest/README.md index 8b24430c71c16c2ed6b2e1a530e19fbc9ebb1698..9e1491ea666b51ba0d367610778c659c543dacf6 100644 --- a/tensorflow/contrib/tensor_forest/README.md +++ b/tensorflow/contrib/tensor_forest/README.md @@ -116,7 +116,7 @@ a different `feature_bagging_fraction * num_features` sized subset of the input features. Defaults to 1.0 (no feature bagging). * `base_random_seed`. By default (`base_random_seed = 0`), the random number -generator for each tree is seeded by the current time (in microseconds) when +generator for each tree is seeded by a 64-bit random value when each tree is first created. Using a non-zero value causes tree training to be deterministic, in that the i-th tree's random number generator is seeded with the value `base_random_seed + i`. diff --git a/tensorflow/contrib/tensor_forest/kernels/data_spec.h b/tensorflow/contrib/tensor_forest/kernels/data_spec.h index 0a3abe56dfc4f611ac8ed0815e4c74a639d2477e..bb33400214e5ef37be73b538455eecf5ae481db4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/data_spec.h +++ b/tensorflow/contrib/tensor_forest/kernels/data_spec.h @@ -21,6 +21,7 @@ #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" namespace tensorflow { namespace tensorforest { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc index da600d34eacdf27514709240723e5bb730cfe7f0..63d4d9ba50603f65cc822ea74c97b923c29fea35 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -19,6 +19,7 @@ #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" #include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/random/random.h" namespace tensorflow { namespace tensorforest { @@ -122,9 +123,8 @@ ClassificationStats::ClassificationStats(const TensorForestParams& params, right_gini_.reset(new RunningGiniScores()); } - uint64 time_seed = static_cast(std::clock()); single_rand_ = std::unique_ptr( - new random::PhiloxRandom(time_seed)); + new random::PhiloxRandom(random::New64())); rng_ = std::unique_ptr( new random::SimplePhilox(single_rand_.get())); } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index c544a8c75e9bfe8fe6bbea8913e7be17d868bfef..95f75b4d7e6a961edf6b3da1dc1712e7ddaacf31 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -23,6 +23,7 @@ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/simple_philox.h" namespace tensorflow { @@ -44,18 +45,20 @@ class TensorDataSet { int column_count = 0; for (int i = 0; i < input_spec_.dense_size(); ++i) { for (int j = 0; j < input_spec_.dense(i).size(); ++j) { - decision_trees::FeatureId id; - id.mutable_id()->set_value(strings::StrCat(column_count)); - available_features_.push_back(id); ++column_count; } } + available_features_.reserve(column_count); + decision_trees::FeatureId id; + for (int i = 0; i < column_count; i++) { + id.mutable_id()->set_value(strings::StrCat(i)); + available_features_.emplace_back(id); + } // Set up the random number generator. if (split_sampling_random_seed_ == 0) { - uint64 time_seed = static_cast(std::clock()); single_rand_ = std::unique_ptr( - new random::PhiloxRandom(time_seed)); + new random::PhiloxRandom(random::New64())); } else { single_rand_ = std::unique_ptr( new random::PhiloxRandom(split_sampling_random_seed_)); diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 2e0a46ffe432341a423ac159deb7745d9ef15374..d833744d0c7e85b9f336f60a3becfd043bc3821d 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -13,7 +13,6 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), - go_api_version = 2, visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 65a0e903a74d066dcec6f2fdb70a22a0872b802f..906cc3f0344e7cb641589bd522e33d658150d3b5 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -47,7 +47,10 @@ tf_cuda_cc_test( tf_custom_op_library( name = "python/ops/_trt_engine_op.so", - srcs = ["ops/trt_engine_op.cc"], + srcs = [ + "ops/trt_calib_op.cc", + "ops/trt_engine_op.cc", + ], deps = [ ":trt_engine_op_kernel", ":trt_shape_function", @@ -71,11 +74,19 @@ tf_cuda_library( cc_library( name = "trt_engine_op_kernel", - srcs = ["kernels/trt_engine_op.cc"], - hdrs = ["kernels/trt_engine_op.h"], + srcs = [ + "kernels/trt_calib_op.cc", + "kernels/trt_engine_op.cc", + ], + hdrs = [ + "kernels/trt_calib_op.h", + "kernels/trt_engine_op.h", + ], copts = tf_copts(), + visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", @@ -87,7 +98,10 @@ cc_library( ) tf_gen_op_libs( - op_lib_names = ["trt_engine_op"], + op_lib_names = [ + "trt_engine_op", + "trt_calib_op", + ], deps = if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), @@ -107,7 +121,9 @@ tf_cuda_library( tf_gen_op_wrapper_py( name = "trt_engine_op", + gen_locally = True, deps = [ + ":trt_calib_op_op_lib", ":trt_engine_op_op_lib", ":trt_logging", ":trt_shape_function", @@ -139,6 +155,7 @@ py_library( deps = [ ":trt_convert_py", ":trt_ops_py", + "//tensorflow/python:errors", ], ) @@ -171,6 +188,27 @@ tf_py_wrap_cc( ], ) +tf_cuda_library( + name = "trt_resources", + srcs = [ + "resources/trt_int8_calibrator.cc", + "resources/trt_resource_manager.cc", + ], + hdrs = [ + "resources/trt_int8_calibrator.h", + "resources/trt_resource_manager.h", + "resources/trt_resources.h", + ], + deps = [ + ":trt_logging", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + # Library for the node-level conversion portion of TensorRT operation creation tf_cuda_library( name = "trt_conversion", @@ -185,6 +223,7 @@ tf_cuda_library( deps = [ ":segment", ":trt_logging", + ":trt_resources", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md index dfcce0fd00eedf3341850bbc23927dc3b2e2d2aa..6eafc1754ca5102c8adf04f00e33dc2f8ff970f6 100644 --- a/tensorflow/contrib/tensorrt/README.md +++ b/tensorflow/contrib/tensorrt/README.md @@ -1,40 +1,59 @@ -Using TensorRT in TensorFlow -============================ +# Using TensorRT in TensorFlow + This module provides necessary bindings and introduces TRT_engine_op -operator that wraps a subgraph in TensorRT. +operator that wraps a subgraph in TensorRT. This is still a work in progress +but should be useable with most common graphs. + +## Compilation -Compilation ------------ In order to compile the module, you need to have a local TensorRT -installation (libnvinfer.so and respective include files). During the +installation ( libnvinfer.so and respective include files ). During the configuration step, TensorRT should be enabled and installation path should be set. If installed through package managers (deb,rpm), configure script should find the necessary components from the system automatically. If installed from tar packages, user has to set path to location where the library is installed during configuration. - -``` +```shell bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ ``` After the installation of tensorflow package, TensorRT transformation -will be available. An example use is shown below. - -```python -import tensorflow as tf -import tensorflow.contrib.tensorrt as trt -#... create and train or load model -gdef = sess.graph.as_graph_def() -trt_gdef = trt.create_inference_graph( - gdef, #original graph_def - ["output"], #name of output node(s) - max_batch_size, #maximum batch size to run the inference - max_workspace_size_bytes) # max memory for TensorRT to use -tf.reset_default_graph() -tf.import_graph_def(graph_def=trt_gdef) -#...... run inference +will be available. An example use can be found in test/test_tftrt.py script + +## Installing TensorRT 3.0.4 + +In order to make use of TensorRT integration, you will need a local installation of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). Due to compiler compatibility, you will need to download and install the TensorRT 3.0.4 tarball for _Ubuntu 14.04_, i.e., **_TensorRT-3.0.4.Ubuntu-14.04.5.x86_64.cuda-9.0.cudnn7.0-tar.gz_**, even if you are using Ubuntu 16.04 or later. + +### Preparing TensorRT installation + +Once you have downloaded TensorRT-3.0.4.Ubuntu-14.04.5.x86_64.cuda-9.0.cudnn7.0-tar.gz, you will need to unpack it to an installation directory, which will be referred to as . Please replace with the full path of actual installation directory you choose in commands below. + +```shell +cd && tar -zxf /path/to/TensorRT-3.0.4.Ubuntu-14.04.5.x86_64.cuda-9.0.cudnn7.0-tar.gz ``` + +After unpacking the binaries, you have several options to use them: + +#### To run TensorFlow as a user without superuser privileges + +For a regular user without any sudo rights, you should add TensorRT to your `$LD_LIBRARY_PATH`: + + ```shell + export LD_LIBRARY_PATH=/TensorRT-3.0.4/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + ``` + +Then you are ready to use TensorFlow-TensorRT integration. `$LD_LIBRARY_PATH` must contain the path to TensorRT installation for TensorFlow-TensorRT integration to work. If you are using a VirtualEnv-like setup, you can add the command above to your `bin/activate` script or to your `.bashrc` script. + +#### To run TensorFlow as a superuser + + When running as a superuser, such as in a container or via sudo, the `$LD_LIBRARY_PATH` approach above may not work. The following is preferred when the user has superuser privileges: + + ```shell + echo "/TensorRT-3.0.4/lib" | sudo tee /etc/ld.so.conf.d/tensorrt304.conf && sudo ldconfig + ``` + + Please ensure that any existing deb package installation of TensorRT is removed before following these instructions to avoid package conflicts. \ No newline at end of file diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py index fd551d70b4385b14b84b7b98a6d16b0c03733d38..140ad4828208ae4844a49bf664955b50cd9e51cd 100644 --- a/tensorflow/contrib/tensorrt/__init__.py +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -18,6 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.tensorrt.python import * -# pylint: enable=unused-import,wildcard-import +from tensorflow.python.framework import errors + +# pylint: disable=unused-import,wildcard-import,g-import-not-at-top +try: + from tensorflow.contrib.tensorrt.python import * +except errors.NotFoundError as e: + no_trt_message = ( + '**** Failed to initialize TensorRT. This is either because the TensorRT' + ' installation path is not in LD_LIBRARY_PATH, or because you do not have' + ' it installed. If not installed, please go to' + ' https://developer.nvidia.com/tensorrt to download and install' + ' TensorRT ****') + print(no_trt_message) + raise e +# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 970f8104736d95d09ea3ffabb07f84d8591a8f9c..ff8cc6374d40dc0b49721a784e25015c76541d03 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include #include #include #include @@ -48,16 +49,33 @@ namespace tensorrt { namespace convert { namespace { -static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) { +bool IsTensorRTCandidate(const tensorflow::Node* node) { // LINT.IfChange // TODO(jie): Segmentation shouldn't associated with op name. // Split it into a registration for each kernel. static const std::set candidate_ops = { - "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu", - "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean" + "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) - return candidate_ops.count(node_def.op()); + return candidate_ops.count(node->type_string()); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, @@ -67,8 +85,10 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, const tensorflow::Node* node = graph.FindNodeId(node_id); for (const tensorflow::Edge* edge : node->in_edges()) { if (!subgraph_node_ids.count(edge->src()->id()) && - !edge->src()->IsSource()) { + !edge->src()->IsSource() && !edge->IsControlEdge()) { incoming_edges->insert(edge); + } else { + VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, "; } } } @@ -81,8 +101,11 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, const tensorflow::Node* node = graph.FindNodeId(node_id); for (const tensorflow::Edge* edge : node->out_edges()) { if (!subgraph_node_ids.count(edge->dst()->id()) && - !edge->dst()->IsSink()) { + !edge->dst()->IsSink() && !edge->IsControlEdge()) { + VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, "; outgoing_edges->insert(edge); + } else { + VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, "; } } } @@ -109,74 +132,150 @@ std::unordered_map> BuildTensorNameMap( } return result; } - -tensorflow::Status ConvertSubGraphToTensorRT( - const std::vector& output_names, - const std::set& subgraph_node_ids, - size_t max_batch_size, // Max batch size that engine will be created for - // Max amount of memory that engine will be allowed to consume, in bytes - size_t max_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& graph_properties, - tensorflow::Graph* graph) { - tensorflow::EdgeSet subgraph_incoming_edges; - GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges); - +// TODO(sami): convert references to pointers +struct ConvertGraphParams { + ConvertGraphParams( + tensorflow::Graph& inp_graph, + const std::vector& output_node_names, + const std::set& subgraph_node_id_numbers, + size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, + const tensorflow::grappler::GraphProperties& current_graph_properties, + std::unordered_map>* output_edges, + int engine_precision_mode) + : graph(inp_graph), + output_names(output_node_names), + subgraph_node_ids(subgraph_node_id_numbers), + max_batch_size(max_supported_batch_size), + max_workspace_size_bytes(max_consumed_workspace_size_bytes), + graph_properties(current_graph_properties), + output_edge_map(output_edges), + precision_mode(engine_precision_mode) {} + tensorflow::Graph& graph; + const std::vector& output_names; + const std::set& subgraph_node_ids; + size_t max_batch_size; + size_t max_workspace_size_bytes; + const tensorflow::grappler::GraphProperties& graph_properties; + std::unordered_map>* output_edge_map; + int precision_mode; std::vector> subgraph_inputs; + std::vector> subgraph_outputs; + tensorflow::EdgeSet subgraph_incoming_edges; + tensorflow::EdgeSet subgraph_outgoing_edges; +}; - // Collect inputs by looking for incoming edges - for (const tensorflow::Edge* edge : subgraph_incoming_edges) { - subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); +static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { + GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, + &p->subgraph_incoming_edges); + for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { + p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); } + auto output_name_to_index_map = BuildTensorNameMap(p->output_names); std::set> subgraph_outputs_set; // Collect outputs referenced from output_names - auto output_name_to_index_map = BuildTensorNameMap(output_names); - for (int node_id : subgraph_node_ids) { - tensorflow::Node* node = graph->FindNodeId(node_id); + for (int node_id : p->subgraph_node_ids) { + tensorflow::Node* node = p->graph.FindNodeId(node_id); if (output_name_to_index_map.count(node->name())) { for (int index : output_name_to_index_map.at(node->name())) { subgraph_outputs_set.insert({node_id, index}); } } } - // Collect outputs referenced from outgoing edges - tensorflow::EdgeSet subgraph_outgoing_edges; - GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges); - for (const tensorflow::Edge* edge : subgraph_outgoing_edges) { + GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, + &p->subgraph_outgoing_edges); + for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); } - // Impose an ordering on the outputs - std::vector> subgraph_outputs( - subgraph_outputs_set.begin(), subgraph_outputs_set.end()); - // Build TensorRT node and add it to the graph + p->subgraph_outputs.reserve(subgraph_outputs_set.size()); + p->subgraph_outputs.insert(p->subgraph_outputs.begin(), + subgraph_outputs_set.begin(), + subgraph_outputs_set.end()); + return tensorflow::Status::OK(); +}; + +tensorflow::Status GetCalibNode(ConvertGraphParams* params) { + TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); tensorflow::NodeDef trt_node_def; - TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef( - *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs, - max_batch_size, max_workspace_size_bytes, graph_properties, - &trt_node_def)); + SubGraphParams s(params->graph, params->subgraph_node_ids, + params->subgraph_inputs, params->subgraph_outputs, + params->max_batch_size, params->max_workspace_size_bytes, + params->graph_properties, params->output_edge_map, + &trt_node_def, params->precision_mode); + TF_RETURN_IF_ERROR(InjectCalibrationNode(s)); tensorflow::Status status; - tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status); + tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); + + TF_RETURN_IF_ERROR(status); + + for (auto in_edge : + params->subgraph_incoming_edges) { // loop over incoming edges and + // attach them to calib node + // tensorflow::Node* src_node = in_edge->src(); + auto src_output = in_edge->src_output(); + auto dst_node = in_edge->dst(); + auto dst_input = in_edge->dst_input(); + VLOG(1) << " update edge " << trt_node->name() << ":" << src_output + << " -> " << dst_node->name() << ":" << dst_input; + TF_RETURN_IF_ERROR( + params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { + TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); + tensorflow::NodeDef trt_node_def; + + SubGraphParams s(params->graph, params->subgraph_node_ids, + params->subgraph_inputs, params->subgraph_outputs, + params->max_batch_size, params->max_workspace_size_bytes, + params->graph_properties, params->output_edge_map, + &trt_node_def, params->precision_mode); + TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s)); + tensorflow::Status status; + tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); + + // AddNode does not wire edges. + // Re-map incoming edges to use the new TRT node instead of the orig subgraph + std::map, int> subgraph_edge_to_input_map; + for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { + subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); + } + for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { + std::pair old_src = {edge->src()->id(), edge->src_output()}; + int new_src_output = subgraph_edge_to_input_map.at(old_src); + params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, + new_src_output); + params->graph.RemoveEdge(edge); + } + + VLOG(2) << "new wiring edges: " << trt_node->in_edges().size(); + for (const tensorflow::Edge* edge : trt_node->in_edges()) { + VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + } + TF_RETURN_IF_ERROR(status); // Re-map outgoing edges to use the new TRT node instead of the orig subgraph std::map, int> subgraph_edge_to_output_map; - for (size_t i = 0; i < subgraph_outputs.size(); ++i) { - subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i}); + for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) { + subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); } TF_RETURN_IF_ERROR(status); - for (const tensorflow::Edge* edge : subgraph_outgoing_edges) { + for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { std::pair old_src = {edge->src()->id(), edge->src_output()}; int new_src_output = subgraph_edge_to_output_map.at(old_src); - TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(), - edge->dst_input())); + TF_RETURN_IF_ERROR(params->graph.UpdateEdge( + trt_node, new_src_output, edge->dst(), edge->dst_input())); } // Remove the original subgraph - for (int node_id : subgraph_node_ids) { - tensorflow::Node* node = graph->FindNodeId(node_id); + for (int node_id : params->subgraph_node_ids) { + tensorflow::Node* node = params->graph.FindNodeId(node_id); // Don't remove the input placeholders if (node->type_string() == "Placeholder") { continue; } - graph->RemoveNode(node); + params->graph.RemoveNode(node); } return tensorflow::Status::OK(); } @@ -194,12 +293,39 @@ tensorflow::Status BuildNodeMap( } } // namespace +tensorflow::Status ConvertCalibGraphToInferGraph( + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) { + VLOG(0) << "Starting Calib Conversion"; + tensorflow::Graph graph(tensorflow::OpRegistry::Global()); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( + tensorflow::GraphConstructorOptions(), graph_def, &graph)); + // get calib nodes + std::vector calib_nodes; + for (auto node : graph.op_nodes()) { + if (node->type_string() == "TRTCalibOp") { + VLOG(1) << "Found Calib Node"; + calib_nodes.push_back(node); + } + } + VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size(); + if (calib_nodes.size() == 0) + return tensorflow::errors::FailedPrecondition( + "Graph doesn't contain any calibration nodes!." + " Please generate calibration graph and run calibration first"); + for (auto n : calib_nodes) { + TF_RETURN_IF_ERROR( + tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n)); + } + graph.ToGraphDef(infer_graph); + return tensorflow::Status::OK(); +} tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, - size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) { - // Optimization pass + size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, + int precision_mode = FP32MODE, int minimum_segment_size = 3) { + // optimization pass tensorflow::grappler::GrapplerItem item; item.fetch = output_names; tensorflow::GraphDef gdef; @@ -209,16 +335,23 @@ tensorflow::Status ConvertGraphDefToTensorRT( tensorflow::grappler::LayoutOptimizer optimizer; tensorflow::grappler::Cluster* cluster; - // Virtual cluster + // virtual cluster tensorflow::DeviceProperties device_properties; + device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); cluster = new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); + // single machine + int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); + int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); + VLOG(2) << "cpu_cores: " << num_cpu_cores; + VLOG(2) << "gpus: " << num_gpus; + TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef)); - // Constant folding + // constant folding item.graph = gdef; tensorflow::grappler::ConstantFolding fold(nullptr); TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef)); @@ -226,7 +359,6 @@ tensorflow::Status ConvertGraphDefToTensorRT( // AJ refactoring shape inference through grappler/GraphProperties. tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false)); - // Build full graph tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), gdef.library()); @@ -243,7 +375,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( } // TODO(sami): this should be passed as a knob!!!! - segment_options.minimum_segment_size = 2; + segment_options.minimum_segment_size = minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( gdef, IsTensorRTCandidate, segment_options, &segments)); @@ -252,14 +384,38 @@ tensorflow::Status ConvertGraphDefToTensorRT( } std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); + std::unordered_map> output_edge_map; + int count = 0; + float total_num_nodes_in_segments = 0.; + for (auto s : segments) { + total_num_nodes_in_segments += s.size(); + } for (const std::set& subgraph_node_names : segments) { std::set subgraph_node_ids; + size_t max_mem_per_engine = + max_workspace_size_bytes * + ((float)subgraph_node_names.size() / total_num_nodes_in_segments); + std::stringstream oss; for (const string& node_name : subgraph_node_names) { + oss << " " << node_name; subgraph_node_ids.insert(node_map.at(node_name)->id()); } - TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT( - output_names, subgraph_node_ids, max_batch_size, - max_workspace_size_bytes, static_graph_properties, &graph)); + VLOG(2) << "Subgraph nodes" << oss.str(); + ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size, + max_mem_per_engine, static_graph_properties, + &output_edge_map, precision_mode); + if (precision_mode == INT8MODE) { + TF_RETURN_IF_ERROR(GetCalibNode(&p)); + } else { + tensorflow::Status status = ConvertSubGraphToTensorRT(&p); + if (status != tensorflow::Status::OK()) { + LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count + << " due to: \"" << status.ToString() + << "\" SKIPPING......( " << subgraph_node_names.size() + << " nodes)"; + } + count++; + } } graph.ToGraphDef(new_graph_def); return tensorflow::Status::OK(); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 154ad3f2e8fb0ae702448097fbdece510df30223..e01e4a5328061ad527b2dac6e2e4ef1559bd914d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -28,14 +28,20 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// This method converts an already generated calibration graph which was used in +// calibration runs to an inference graph +tensorflow::Status ConvertCalibGraphToInferGraph( + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def); + // max_batch_size: maximum batch size which can be used for inference for // optimization targets inference run with max batch size. -// max_workspace_size_bytes: The upper bound of memory allowence for +// max_workspace_size_bytes: The upper bound of memory allowance for // engine building. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, - size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def); + size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, + int precision_mode, int minimum_segment_size); } // namespace convert } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 9ee717dd7fb1eff4a11fb104cf5806ec8ab853d2..e920a797fe428620ef62a2b67c07f35d85ef5211 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -24,6 +24,10 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" @@ -32,6 +36,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tensor_coding.h" @@ -39,7 +44,6 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorrt/include/NvInfer.h" // Check if the types are equal. Cast to int first so that failure log message @@ -49,7 +53,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { - +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; namespace { inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, @@ -65,7 +70,8 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, *trt_dtype = nvinfer1::DataType::kHALF; break; default: - return tensorflow::errors::InvalidArgument("Unsupported data type"); + return tensorflow::errors::InvalidArgument( + "Unsupported data type " + tensorflow::DataTypeString(tf_dtype)); } return tensorflow::Status::OK(); } @@ -112,6 +118,18 @@ static std::vector> CreateSamePadding( return padding; } +string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { + size_t last_scope_separator = 0; + for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) { + if (op_name_a[i] != op_name_b[i]) { + break; + } else if (op_name_a[i] == '/') { + last_scope_separator = i + 1; + } + } + return op_name_a.substr(0, last_scope_separator); +} + class TRT_ShapedWeights { public: TRT_ShapedWeights(tensorflow::DataType type, const void* values, @@ -244,6 +262,11 @@ std::vector TFAttrs::get>(string key) const { return std::vector(attr.begin(), attr.end()); } +template <> +std::vector TFAttrs::get>(string key) const { + auto attr = this->at(key)->list().s(); + return std::vector(attr.begin(), attr.end()); +} template <> nvinfer1::Dims TFAttrs::get(string key) const { auto values = this->get>(key); @@ -266,6 +289,17 @@ tensorflow::DataType TFAttrs::get(string key) const { return this->at(key)->type(); } +template <> +float TFAttrs::get(string key) const { + return this->at(key)->f(); +} + +template <> +bool TFAttrs::get(string key) const { + return this->at(key)->b(); +} + +// TODO(jie): reorder4 & reorder2 should be merged? template void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, nvinfer1::DimsNCHW istrides, T* odata, @@ -283,29 +317,86 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, } } +template +void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, + T* odata, nvinfer1::DimsHW ostrides) { + for (int h = 0; h < shape.h(); ++h) { + for (int w = 0; w < shape.w(); ++w) { + odata[h * ostrides.h() + w * ostrides.w()] = + idata[h * ostrides.h() + w * ostrides.w()]; + } + } +} + +// TODO(jie): fallback to tensorflow!! +void ReorderCKtoKC(const TRT_ShapedWeights& iweights, + TRT_ShapedWeights* oweights) { + int c = iweights.shape_.d[0]; + int k = iweights.shape_.d[1]; + oweights->shape_.d[0] = k; + oweights->shape_.d[1] = c; + nvinfer1::DimsHW istrides = {1, k}; + nvinfer1::DimsHW ostrides = {c, 1}; + switch (iweights.type_) { + case tensorflow::DataType::DT_FLOAT: { + Reorder2({k, c}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); + break; + } + case tensorflow::DataType::DT_HALF: { + Reorder2({k, c}, static_cast(iweights.GetValues()), + istrides, static_cast( + const_cast(oweights->GetValues())), + ostrides); + break; + } + default: + LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got " + << DataTypeString(iweights.type_); + } +} + void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, - TRT_ShapedWeights* oweights) { + TRT_ShapedWeights* oweights, int num_groups) { CHECK_EQ(iweights.type_, oweights->type_); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); int r = iweights.shape_.d[0]; int s = iweights.shape_.d[1]; - int c = iweights.shape_.d[2]; - int k = iweights.shape_.d[3]; - oweights->shape_.d[0] = k; - oweights->shape_.d[1] = c; + // TRT requires GKcRS, while TF depthwise has RSCK + // where c=1, C=G + VLOG(2) << "num_groups: " << num_groups; + int c = iweights.shape_.d[2] / num_groups; + VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c; + int k = iweights.shape_.d[3] * num_groups; + VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k; + oweights->shape_.d[0] = k / num_groups; + oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; oweights->shape_.d[3] = s; nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; switch (iweights.type_) { - case tensorflow::DataType::DT_FLOAT: + case tensorflow::DataType::DT_FLOAT: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), istrides, static_cast(const_cast(oweights->GetValues())), ostrides); break; + } + case tensorflow::DataType::DT_HALF: { + Reorder4( + {k, c, r, s}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); + break; + } + default: - LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!"; + LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got " + << DataTypeString(iweights.type_); } } @@ -323,12 +414,11 @@ inline std::shared_ptr infer_object(T* obj) { return std::shared_ptr(obj, InferDeleter()); } -// Logger for GIE info/warning/errors class Converter; using OpConverter = std::function const&, + const std::vector&, std::vector*)>; class Converter { @@ -336,40 +426,67 @@ class Converter { std::unordered_map op_registry_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; - + tensorflow::tensorrt::TRTWeightStore* weight_store_; + bool fp16_; void register_op_converters(); - - std::vector get_inputs( - const tensorflow::NodeDef& node_def) { - std::vector inputs; - for (const auto& input_name : node_def.input()) { - VLOG(2) << "Retrieve input: " << input_name; - inputs.push_back(trt_tensors_.at(input_name)); + tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def, + std::vector* inputs) { + for (auto const& input_name : node_def.input()) { + /************************************************************************* + * TODO(jie) handle case 1) here + * Normalizes the inputs and extracts associated metadata: + * 1) Inputs can contain a colon followed by a suffix of characters. + * That suffix may be a single number (e.g. inputName:1) or several + * word characters separated from a number by a colon + * (e.g. inputName:foo:1). The + * latter case is used to denote inputs and outputs of functions. + * 2) Control dependency inputs contain caret at the beginning and we + * remove this and annotate the edge as a control dependency. + ************************************************************************/ + string name = input_name[0] == '^' ? input_name.substr(1) : input_name; + auto first = name.find_first_of(':'); + if (first != string::npos && first + 2 == name.size() && + name[first + 1] == '0') + name.erase(first); + + VLOG(2) << "retrieve input: " << name; + if (trt_tensors_.count(name)) { + inputs->push_back(trt_tensors_.at(name)); + } else { + string str("Node "); + StrAppend(&str, node_def.name(), " should have an input named '", name, + "' but it is not available"); + LOG(WARNING) << "input: " << name << " not available for node at " + << node_def.name(); + return tensorflow::errors::InvalidArgument(str); + } } - return inputs; + return tensorflow::Status::OK(); } public: - explicit Converter(nvinfer1::INetworkDefinition* trt_network) - : trt_network_(trt_network) { + explicit Converter(nvinfer1::INetworkDefinition* trt_network, + tensorflow::tensorrt::TRTWeightStore* ws, bool fp16) + : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) { this->register_op_converters(); } - + tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; } TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, nvinfer1::Dims shape) { TRT_ShapedWeights weights(type, nullptr, shape); // TODO(jie): check weights size_bytes. 0 means type error - temp_bufs_.push_back(std::vector(weights.size_bytes())); - weights.SetValues(temp_bufs_.back().data()); + weight_store_->store_.push_back(std::vector(weights.size_bytes())); + weights.SetValues(weight_store_->store_.back().data()); return weights; } - + bool isFP16() { return fp16_; }; TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) { return this->get_temp_weights(weights.type_, weights.shape_); } tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) { - std::vector inputs = this->get_inputs(node_def); + std::vector inputs; + TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); if (!op_registry_.count(op)) { return tensorflow::errors::Unimplemented( @@ -382,7 +499,7 @@ class Converter { TRT_TensorOrWeights output = outputs.at(i); // TODO(jie): tf protobuf seems to be omitting the :0 suffix string output_name = node_def.name(); - if (i != 0) output_name = output_name + ":" + std::to_string(i); + if (i != 0) output_name = StrCat(output_name, ":", i); if (output.is_tensor()) { output.tensor()->setName(output_name.c_str()); } @@ -434,6 +551,19 @@ class Converter { } }; +TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx, + const TRT_ShapedWeights& weights_src) { + auto dtype_new = tensorflow::DataType::DT_HALF; + TRT_ShapedWeights weights = + ctx.get_temp_weights(dtype_new, weights_src.shape_); + const float* src = static_cast(weights_src.GetValues()); + Eigen::half* dst = const_cast( + static_cast(weights.GetValues())); + for (int64_t i = 0; i < weights_src.count(); i++) { + dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]); + } + return weights; +} // **************************************************************************** // Constant folding functions // TODO(jie): once optimizer kicks in, we should have done constant folding @@ -448,7 +578,7 @@ struct LambdaFactory { switch (op) { case OP_CATEGORY::RSQRT: { VLOG(2) << "RSQRT GETS DONE"; - return [](T t) -> T { return 1.0 / std::sqrt(t); }; + return [](T t) -> T { return 1.0 / sqrt(t); }; } case OP_CATEGORY::NEG: return [](T t) -> T { return -t; }; @@ -534,6 +664,22 @@ struct LambdaFactory { } }; +template <> +std::function LambdaFactory::unary() { + switch (op) { + case OP_CATEGORY::RSQRT: { + VLOG(2) << "RSQRT GETS DONE"; + return [](Eigen::half t) -> Eigen::half { + return Eigen::half(1.0 / sqrt(float(t))); + }; + } + case OP_CATEGORY::NEG: + return [](Eigen::half t) -> Eigen::half { return -t; }; + default: + VLOG(2) << "Not supported op for unary: " << static_cast(op); + return nullptr; + } +} tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, LambdaFactory unary_op) { @@ -545,6 +691,14 @@ tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); break; } + case tensorflow::DataType::DT_HALF: { + auto inp = static_cast(iweights.GetValues()); + auto oup = + static_cast(const_cast(oweights->GetValues())); + std::transform(inp, inp + iweights.count(), oup, + unary_op.unary()); + break; + } default: return tensorflow::errors::Unimplemented( "Data type not supported: " + @@ -588,6 +742,32 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, } break; } + case tensorflow::DataType::DT_HALF: { + auto inp_l = static_cast(iweights_l.GetValues()); + auto inp_r = static_cast(iweights_r.GetValues()); + auto oup = + static_cast(const_cast(oweights->GetValues())); + + if (iweights_l.count() != iweights_r.count()) { + // We only supports broadcast of RankZero + if (iweights_l.count() == 1) { + VLOG(2) << "I bet it is not working!" << (*inp_l); + std::transform(inp_r, inp_r + iweights_r.count(), oup, + binary_op.broadcast_l(*inp_l)); + } else if (iweights_r.count() == 1) { + VLOG(2) << "I bet it is not working!" << (*inp_r); + std::transform(inp_l, inp_l + iweights_l.count(), oup, + binary_op.broadcast_r(*inp_r)); + } else { + return tensorflow::errors::Unimplemented( + "Binary op with non-rankZero broadcast not supported"); + } + } else { + std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup, + binary_op.binary()); + } + break; + } default: return tensorflow::errors::Unimplemented( "Data type not supported: " + @@ -599,7 +779,7 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, tensorflow::Status ConstantFoldUnary( Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { TRT_ShapedWeights weights_input = inputs.at(0).weights(); @@ -613,13 +793,12 @@ tensorflow::Status ConstantFoldUnary( CHECK_EQ(weights_input.type_, TFAttrs(node_def).get("T")); - // Maybe I should do a switch LambdaFactory unary_op; if (node_def.op() == "Rsqrt") { // Compute rsqrt unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; auto ret = UnaryCompute(weights_input, &weights_output, unary_op); - // PAss the output + // Pass the output if (ret == tensorflow::Status::OK()) { outputs->push_back(TRT_TensorOrWeights(weights_output)); } @@ -631,11 +810,11 @@ tensorflow::Status ConstantFoldUnary( } // TODO(jie,ben) broadcast is needed yet not implemented -// Let's get the simple stuff working first. Maybe we should fall bakc to TF +// Let's get the simple stuff working first. Maybe we should fall back to TF // approach for constant folding tensorflow::Status ConstantFoldBinary( Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); @@ -648,12 +827,12 @@ tensorflow::Status ConstantFoldBinary( "Binary op implicit broadcast not supported: " + node_def.op()); // TODO(jie): constant fold should really fall back to TF. - int nb_dims = weights_input_l.shape_.nbDims; + int num_dims = weights_input_l.shape_.nbDims; nvinfer1::Dims output_shape; - output_shape.nbDims = nb_dims; - VLOG(2) << "nb_dims: " << nb_dims + output_shape.nbDims = num_dims; + VLOG(2) << "nb_dims: " << num_dims << ", the other: " << weights_input_r.shape_.nbDims; - for (int i = 0; i < nb_dims; i++) { + for (int i = 0; i < num_dims; i++) { if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { output_shape.d[i] = weights_input_l.shape_.d[i]; } else if (weights_input_l.shape_.d[i] == 1 || @@ -678,7 +857,6 @@ tensorflow::Status ConstantFoldBinary( // Allocate output weights TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); - // Maybe I should do a switch LambdaFactory binary_op; if (node_def.op() == "Sub") { binary_op.op = LambdaFactory::OP_CATEGORY::SUB; @@ -712,48 +890,94 @@ tensorflow::Status BinaryTensorOpWeight( // Maybe this part has to be moved into the block of rsqrt later // Check type consistency - auto dtype = TFAttrs(node_def).get("T"); - CHECK_EQ_TYPE(tensor->getType(), dtype); // Cast to int for error messages nvinfer1::DataType ttype; - TF_CHECK_OK(ConvertDType(weights.type_, &ttype)); - CHECK_EQ_TYPE(ttype, dtype); // Cast to int for error message + TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &ttype)); // Check scale mode auto dims_w = weights.shape_; auto dims_t = tensor->getDimensions(); - // Default to channel-wise + // default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; + // TODO(jie): maybe use a permutation instead to support more cases; + bool permutation_flag = false; + if (weights.count() == 1) { VLOG(2) << "UNIFORM"; scale_mode = nvinfer1::ScaleMode::kUNIFORM; } else { - // No broadcasting on Batch dimension; - assert(dims_w.d[0] == 1); - - // Broadcasting on Channel dimension only allowed in kUNIFORM - assert(dims_w.d[1] == dims_t.d[0]); - assert(dims_w.nbDims == dims_t.nbDims); - - // Default is element; - for (int i = 2; i < dims_w.nbDims; i++) { - if (dims_w.d[i] != dims_t.d[i - 1]) { - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - break; + // no broadcasting on Batch dimension; + VLOG(2) << "WEIGHTS DIM: " << dims_w.nbDims + << " tensor DIM: " << dims_t.nbDims; + if (dims_w.nbDims == dims_t.nbDims + 1) { + if (dims_w.d[0] == 1) { + for (int i = 1; i < dims_w.nbDims; i++) { + dims_w.d[i - 1] = dims_w.d[i]; + } + dims_w.nbDims--; + } else { + return tensorflow::errors::InvalidArgument( + "Binary op cannot operate on batch, " + node_def.name()); } } - if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) { + + if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) { scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - for (int i = 2; i < dims_w.nbDims; i++) { - if (dims_w.d[i] != 1) - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + // default is element; + for (int i = 1; i < dims_w.nbDims; i++) { + if (dims_w.d[i] != dims_t.d[i]) { + // if dimension does not match, switch back to channel; + VLOG(2) << "channel"; + scale_mode = nvinfer1::ScaleMode::kCHANNEL; + break; + } + } + // if channel as candidate, validate it + if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { + for (int i = 1; i < dims_w.nbDims; i++) { + if (dims_w.d[i] != 1) + return tensorflow::errors::InvalidArgument( + "Weight shape not compatible at, " + node_def.name()); + } + } else { + VLOG(2) << "elementwise"; + } + } else if (dims_w.nbDims == 1 && + dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) { + // channel wise and broadcast required; + permutation_flag = true; + scale_mode = nvinfer1::ScaleMode::kCHANNEL; + } else { + return tensorflow::errors::InvalidArgument( + "Weight shape not compatible at, " + node_def.name()); + } + } + + // transpose last dimension + std::vector permutation(dims_t.nbDims + 1); + if (permutation_flag) { + if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) { + // we swap the last dimension into channel for trt. + // because of tensorflow default broadcasting rules. + for (int i = 0; i < static_cast(permutation.size()); i++) { + permutation[i] = i; } + permutation[1] = dims_t.nbDims; + permutation[dims_t.nbDims] = 1; + tensor = ctx.TransposeTensor(const_cast(tensor), + permutation); + } else { + return tensorflow::errors::InvalidArgument( + "Transpose cannot be applied, " + node_def.name()); } } - // Prepare weights + if (ctx.isFP16()) { + weights = ConvertFP32ToFP16(ctx, weights); + } + + // prepare weights TRT_ShapedWeights shift_weights(weights.type_); TRT_ShapedWeights scale_weights(weights.type_); TRT_ShapedWeights power_weights(weights.type_); @@ -779,88 +1003,24 @@ tensorflow::Status BinaryTensorOpWeight( scale_weights, power_weights); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + // transpose back dimension + if (permutation_flag) { + output_tensor = ctx.TransposeTensor(output_tensor, permutation); + } // Pass the output outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status BinaryTensorOpTensor( - Converter& ctx, const tensorflow::NodeDef& node_def, - const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, - std::vector* outputs) { - static const std::unordered_map ops{ - {"Add", nvinfer1::ElementWiseOperation::kSUM}, - {"Mul", nvinfer1::ElementWiseOperation::kPROD}, - // {"max", nvinfer1::ElementWiseOperation::kMAX}, - // {"min", nvinfer1::ElementWiseOperation::kMIN}, - {"Sub", nvinfer1::ElementWiseOperation::kSUB}, - {"Div", nvinfer1::ElementWiseOperation::kDIV}, - }; - - // FIXME assume type matches input weights - // Get trt type & shape - TFAttrs attrs(node_def); - // Maybe this part has to be moved into the block of rsqrt later - nvinfer1::DataType dtype = attrs.get("T"); - - // Check type consistency - CHECK_EQ_TYPE(tensor_l->getType(), dtype); - CHECK_EQ_TYPE(tensor_r->getType(), dtype); - auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + - " not supported at: " + - node_def.name()); - - nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), op_pair->second); - - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - - // Pass the output - outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); -} +enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; -tensorflow::Status ConvertPlaceholder( +tensorflow::Status ConvertConv2DHelper( Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, - std::vector* outputs) { - VLOG(2) << "Placeholder should have been replace already"; - return tensorflow::errors::Unimplemented(", cannot convert Placeholder op"); - // OK this make sense since we are supposed to replace it with input - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("dtype"); - nvinfer1::Dims dims = attrs.get("shape"); - - dims.nbDims--; - for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; - - nvinfer1::ITensor* output = - ctx.network()->addInput(node_def.name().c_str(), dtype, dims); - if (!output) { - return tensorflow::errors::InvalidArgument("Failed to create Input layer"); - } - outputs->push_back(TRT_TensorOrWeights(output)); - return tensorflow::Status::OK(); -} + const std::vector& inputs, + std::vector* outputs, int group) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); -tensorflow::Status ConvertConv2D(Converter& ctx, - const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); - // TODO(jie): handle NHWC/NCHW transpose; - TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); - ReorderRSCKToKCRS(weights_rsck, &weights); - TRT_ShapedWeights biases(weights.type_); - int noutput = weights.shape_.d[0]; - nvinfer1::DimsHW kernel_size; - kernel_size.h() = weights.shape_.d[2]; - kernel_size.w() = weights.shape_.d[3]; TFAttrs attrs(node_def); int h_index = 2; @@ -874,11 +1034,35 @@ tensorflow::Status ConvertConv2D(Converter& ctx, // TODO(jie): transpose it } + // tensor after transpose (NCHW) + auto tensor_dim = tensor->getDimensions(); + + int num_groups = group; + if (num_groups == 0) // depthwise convolution + num_groups = tensor_dim.d[0]; + VLOG(2) << "groups count: " << num_groups; + + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + if (ctx.isFP16()) { + weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); + } + + TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); + ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); + TRT_ShapedWeights biases(weights.type_); + int noutput = weights.shape_.d[0] * num_groups; + nvinfer1::DimsHW kernel_size; + kernel_size.h() = weights.shape_.d[2]; + kernel_size.w() = weights.shape_.d[3]; + VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); + // TODO(jie): stride. (NHWC/NCHW) auto tf_stride = attrs.get>("strides"); + VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; + VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2] + << tf_stride[3]; nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); - auto tensor_dim = tensor->getDimensions(); std::vector> padding; // TODO(jie): padding. if (attrs.get("padding") == "SAME") { @@ -919,10 +1103,11 @@ tensorflow::Status ConvertConv2D(Converter& ctx, layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); nvinfer1::ITensor* output_tensor = layer->getOutput(0); auto dim_after = output_tensor->getDimensions(); - VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] + VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", " << dim_after.d[2] << ", " << dim_after.d[3]; if (data_format == "NHWC") { @@ -935,11 +1120,101 @@ tensorflow::Status ConvertConv2D(Converter& ctx, return tensorflow::Status::OK(); } +tensorflow::Status ConvertConv2DHelper( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs, ConvolutionType type) { + switch (type) { + case ConvolutionType::DEFAULT: + return ConvertConv2DHelper(ctx, node_def, inputs, outputs, 1); + case ConvolutionType::DEPTHWISE_CONV: + return ConvertConv2DHelper(ctx, node_def, inputs, outputs, 0); + } + return tensorflow::errors::Unimplemented("unsupported convolution type at, " + + node_def.name()); +} + +tensorflow::Status BinaryTensorOpTensor( + Converter& ctx, const tensorflow::NodeDef& node_def, + const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, + std::vector* outputs) { + static const std::unordered_map ops{ + {"Add", nvinfer1::ElementWiseOperation::kSUM}, + {"Mul", nvinfer1::ElementWiseOperation::kPROD}, + {"Sub", nvinfer1::ElementWiseOperation::kSUB}, + {"Div", nvinfer1::ElementWiseOperation::kDIV}, + }; + + // FIXME assume type matches input weights + // get trt type & shape + TFAttrs attrs(node_def); + // maybe this part has to be moved into the block of rsqrt later + nvinfer1::DataType dtype = attrs.get("T"); + + // check type consistency + CHECK_EQ_TYPE(tensor_l->getType(), dtype); + CHECK_EQ_TYPE(tensor_r->getType(), dtype); + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); + + nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), op_pair->second); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // pass the output + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPlaceholder( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + VLOG(2) << "Placeholder should have been replace already"; + return tensorflow::errors::Unimplemented("cannot convert Placeholder op"); + // OK this make sense since we are supposed to replace it with input + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("dtype"); + nvinfer1::Dims dims = attrs.get("shape"); + + dims.nbDims--; + for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; + + nvinfer1::ITensor* output = + ctx.network()->addInput(node_def.name().c_str(), dtype, dims); + if (!output) { + return tensorflow::errors::InvalidArgument("Failed to create Input layer"); + } + outputs->push_back(TRT_TensorOrWeights(output)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertConv2D(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + return ConvertConv2DHelper(ctx, node_def, inputs, outputs, + ConvolutionType::DEFAULT); +} + +tensorflow::Status ConvertConv2DDepthwise( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + return ConvertConv2DHelper(ctx, node_def, inputs, outputs, + ConvolutionType::DEPTHWISE_CONV); +} + tensorflow::Status ConvertPool(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TFAttrs attrs(node_def); int h_index = 2; @@ -957,6 +1232,8 @@ tensorflow::Status ConvertPool(Converter& ctx, // TODO(jie): support other pooling type if (node_def.op() == "MaxPool") type = nvinfer1::PoolingType::kMAX; + else if (node_def.op() == "AvgPool") + type = nvinfer1::PoolingType::kAVERAGE; else return tensorflow::errors::Unimplemented("Only supports Max pool"); @@ -1019,9 +1296,9 @@ tensorflow::Status ConvertPool(Converter& ctx, tensorflow::Status ConvertActivation( Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = ctx.network()->addActivation( *const_cast(tensor), nvinfer1::ActivationType::kRELU); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -1031,17 +1308,20 @@ tensorflow::Status ConvertActivation( tensorflow::Status ConvertScale(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) return tensorflow::errors::Unimplemented( "Only supports tensor op weight for now, at " + node_def.name()); // Implement tensor binaryOp weight [channel wise] for now; - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - // TODO(jie): handle NHWC/NCHW transpose; TRT_ShapedWeights weights = inputs.at(1).weights(); + if (ctx.isFP16()) { + weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); + } + TRT_ShapedWeights empty_weights(weights.type_); TFAttrs attrs(node_def); @@ -1055,12 +1335,29 @@ tensorflow::Status ConvertScale(Converter& ctx, } else { VLOG(2) << "NCHW !!!!"; } - nvinfer1::IScaleLayer* layer = ctx.network()->addScale( - *const_cast(tensor), nvinfer1::ScaleMode::kCHANNEL, - weights, empty_weights, empty_weights); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (data_format == "NHWC") { + auto dims = tensor->getDimensions(); + VLOG(2) << "tensor dimensions: " << dims.nbDims; + for (int i = 0; i < dims.nbDims; i++) { + VLOG(2) << "i: " << dims.d[i]; + } + dims = weights.shape_; + VLOG(2) << "tensor dimensions: " << dims.nbDims; + for (int i = 0; i < dims.nbDims; i++) { + VLOG(2) << "i: " << dims.d[i]; + } + + nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; + if (weights.shape_.d[0] == 1) { + mode = nvinfer1::ScaleMode::kUNIFORM; + } + + nvinfer1::IScaleLayer* layer = + ctx.network()->addScale(*const_cast(tensor), mode, + weights, empty_weights, empty_weights); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + if (data_format == "NHWC") { // TODO(jie): transpose it back! output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); } else { @@ -1072,7 +1369,7 @@ tensorflow::Status ConvertScale(Converter& ctx, tensorflow::Status ConvertConst(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { const auto& weights_tensor = node_def.attr().at("value").tensor(); @@ -1091,22 +1388,96 @@ tensorflow::Status ConvertConst(Converter& ctx, VLOG(2) << "SCALAR!!!" << node_def.name(); nvinfer1::Dims scalar_shape; if (tensor.dims() > 0) { - VLOG(2) << "Dimensions: " << tensor.dims(); - weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(), - GetTensorShape(tensor)); + VLOG(2) << "dimensions: " << tensor.dims(); + VLOG(2) << "size: " << weights_tensor.float_val_size(); + scalar_shape = GetTensorShape(tensor); + for (int i = 0; i < scalar_shape.nbDims; i++) + VLOG(2) << scalar_shape.d[i]; + if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) { + if (weights_tensor.float_val_size() == 1 || + scalar_shape.d[0] == weights_tensor.float_val_size()) { + scalar_shape.nbDims = 1; + // no dimension provided. flatten it + scalar_shape.d[0] = weights_tensor.float_val_size(); + scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; + } else { + LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); + } + } } else { VLOG(2) << "Dimensions: " << tensor.dims(); scalar_shape.nbDims = 1; - scalar_shape.d[0] = 1; + // no dimension provided. flatten it + scalar_shape.d[0] = weights_tensor.float_val_size(); + scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; + for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { + scalar_shape.d[i] = 0; + scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; + } + } + size_t len_data = tensorflow::DataTypeSize(dtype); + for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; + ctx.weight_store()->store_.push_back(std::vector(len_data)); + void* dst = static_cast(&(ctx.weight_store()->store_.back()[0])); + std::vector tensor_data( + weights_tensor.float_val().begin(), + weights_tensor.float_val() + .end()); // make a local copy first to flatten + memcpy(dst, tensor_data.data(), len_data); // store into weight store + weights = TRT_ShapedWeights(dtype, dst, scalar_shape); + } else if (!weights_tensor.int_val().empty()) { + VLOG(2) << "int!!!" << node_def.name(); + nvinfer1::Dims scalar_shape; + if (tensor.dims() > 0) { + VLOG(2) << "dimensions: " << tensor.dims(); + scalar_shape = GetTensorShape(tensor); + if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) { + if (weights_tensor.int_val_size() == 1 || + scalar_shape.d[0] == weights_tensor.int_val_size()) { + scalar_shape.nbDims = 1; + // no dimension provided. flatten it + scalar_shape.d[0] = weights_tensor.int_val_size(); + scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; + } else { + LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); + } + } + } else { + VLOG(2) << "dimensions: " << tensor.dims(); + scalar_shape.nbDims = 1; + // no dimension provided. flatten it + scalar_shape.d[0] = weights_tensor.int_val_size(); scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { scalar_shape.d[i] = 0; scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } - weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(), - scalar_shape); } + // we should not have converted //if (ctx.isFP16()) { + size_t len_data = tensorflow::DataTypeSize(dtype); + for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; + size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32); + len_data = std::max(len_data, len_tensor); + ctx.weight_store()->store_.push_back(std::vector(len_data)); + void* dst = static_cast(&(ctx.weight_store()->store_.back()[0])); + std::vector tensor_data( + weights_tensor.int_val().begin(), + weights_tensor.int_val().end()); // make a local copy first to flatten + // doesn't have to be contigous + memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.tensor_content().empty()) { + // obsolete method. + // After optimization path, we do not see weights in this format. + // fp16 conversion technically should be needed here. VLOG(2) << "TENSOR!!!" << node_def.name(); const auto& content = weights_tensor.tensor_content(); @@ -1130,7 +1501,7 @@ tensorflow::Status ConvertConst(Converter& ctx, tensorflow::Status ConvertIdentity( Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { outputs->push_back(inputs.at(0)); return tensorflow::Status::OK(); @@ -1138,7 +1509,7 @@ tensorflow::Status ConvertIdentity( tensorflow::Status ConvertBinary(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 2) return tensorflow::errors::FailedPrecondition( @@ -1165,7 +1536,7 @@ tensorflow::Status ConvertBinary(Converter& ctx, tensorflow::Status ConvertUnary(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 1) return tensorflow::errors::FailedPrecondition( @@ -1183,7 +1554,7 @@ tensorflow::Status ConvertUnary(Converter& ctx, tensorflow::Status ConvertReduce(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) @@ -1191,7 +1562,7 @@ tensorflow::Status ConvertReduce(Converter& ctx, "Input expects tensor and weights, at" + node_def.name()); // Implement tensor binaryOp weight [channel wise] for now; - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); auto dims = tensor->getDimensions(); // Restore implicit batch dimension int nb_dims = dims.nbDims + 1; @@ -1229,6 +1600,7 @@ tensorflow::Status ConvertReduce(Converter& ctx, return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + node_def.name()); if (index_list_data[i] == 1) permuted_index = 1; + idx_set.emplace(index_list_data[i]); } @@ -1236,7 +1608,7 @@ tensorflow::Status ConvertReduce(Converter& ctx, nvinfer1::DimsHW pool_kernel; if (permuted_index == 1) { for (int i = 2; i < nb_dims; i++) { - if (idx_set.count(i)) { + if (idx_set.count(i) == 0) { permuted_index = i; break; } @@ -1271,12 +1643,13 @@ tensorflow::Status ConvertReduce(Converter& ctx, output_tensor = ctx.TransposeTensor( const_cast(output_tensor), permutation_order); } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } tensorflow::Status ConvertPad(Converter& ctx, const tensorflow::NodeDef& node_def, - std::vector const& inputs, + const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) @@ -1284,7 +1657,7 @@ tensorflow::Status ConvertPad(Converter& ctx, "Input expects tensor and weights, at" + node_def.name()); // Implement tensor binaryOp weight [channel wise] for now; - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); auto dims = tensor->getDimensions(); // Restore implicit batch dimension int nb_dims = dims.nbDims + 1; @@ -1371,19 +1744,318 @@ tensorflow::Status ConvertPad(Converter& ctx, return tensorflow::Status::OK(); } +tensorflow::Status ConvertConcat(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + // not including the last input (axis) here + int input_size = static_cast(inputs.size()) - 1; + + if (!inputs.at(0).is_tensor()) + return tensorflow::errors::InvalidArgument( + "Concat in TRT support only Tensor input, at " + node_def.name()); + + // We are retrieving the axis + TRT_ShapedWeights axis = inputs.at(input_size).weights(); + + TFAttrs attrs(node_def); + auto index_type = attrs.get("Tidx"); + + // TODO(jie): handle data type + // Only expect to handle INT32 as index attributes for now + if (index_type != tensorflow::DataType::DT_INT32) + return tensorflow::errors::Unimplemented( + "Tidx supports only DT_INT32, at " + node_def.name()); + + int index = *(static_cast(const_cast(axis.GetValues()))); + + // TODO(jie): early termination with no-op (attr_size==1) + + auto dim = inputs.at(0).tensor()->getDimensions(); + // dimension check + if (index > dim.nbDims + 1) + return tensorflow::errors::InvalidArgument( + "Concatenate on axis out of dimension range, at " + node_def.name()); + + if (index == 0) + return tensorflow::errors::InvalidArgument( + "Concatenate on batch dimension not supported, at " + node_def.name()); + + // incase we need permutation; + std::vector permutation_order(dim.nbDims + 1); + + for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i; + + if (index != 1) { + permutation_order[1] = index - 1; + permutation_order[index - 1] = 1; + } + + std::vector inputs_vec; + // Shap chack (all input tensor should have same shape) + // starting from 0 since we are probably also doing transpose here; + for (int i = 0; i < input_size; i++) { + auto tensor_i = inputs.at(i).tensor(); + auto dim_i = tensor_i->getDimensions(); + if (dim_i.nbDims != dim.nbDims) + return tensorflow::errors::InvalidArgument( + "Concatenate receives inputs with inconsistent dimensions, at " + + node_def.name()); + + for (int j = 0; j < dim.nbDims; j++) { + // check dimension consistency on non-concatenate axis + if (j != index - 1 && dim_i.d[j] != dim.d[j]) + return tensorflow::errors::InvalidArgument( + "Concatenate receives inputs with inconsistent shape, at" + + node_def.name()); + } + + // TRT does concatenation only on channel! + if (index != 1) + tensor_i = ctx.TransposeTensor(const_cast(tensor_i), + permutation_order); + + inputs_vec.push_back(tensor_i); + } + + // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation( + const_cast(inputs_vec.data()), + inputs_vec.size()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + if (index != 1) { + output_tensor = ctx.TransposeTensor(output_tensor, permutation_order); + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertFusedBatchNorm( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + TFAttrs attrs(node_def); + float epsilon = attrs.get("epsilon"); + auto data_format = attrs.get("data_format"); + if (data_format != "NCHW") { + return tensorflow::errors::Unimplemented( + "only data_format=NCHW is supported, at " + node_def.name()); + } + bool is_training = attrs.get("is_training"); + if (is_training) { + return tensorflow::errors::Unimplemented( + "only is_training=false is supported, at " + node_def.name()); + } + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + + // Check parameter types + auto parameter_type = inputs.at(1).weights().type_; + if ((parameter_type != tensorflow::DataType::DT_FLOAT) && + (parameter_type != tensorflow::DataType::DT_HALF)) { + return tensorflow::errors::Unimplemented( + "only float32 or float16 weight data type is supported, for node " + + node_def.name() + " got " + tensorflow::DataTypeString(parameter_type)); + } + for (int i = 1; i < 5; i++) { + if (inputs.at(i).weights().type_ != parameter_type) { + return tensorflow::errors::Unimplemented( + "Inconsistent parameter type for batchnormis not supported, at: " + + node_def.name()); + } + } + + TRT_ShapedWeights dummy_power_weights(parameter_type); + size_t nweight = 0; + for (int i = 1; i < 5; i++) { + nweight = std::max(nweight, (size_t)inputs.at(i).weights().count()); + } + TRT_ShapedWeights* ptr_shape_weights = nullptr; + for (int i = 1; i < 5; i++) { + if (inputs.at(i).weights().count() == nweight) { + ptr_shape_weights = + const_cast(&(inputs.at(i).weights())); + } else if (inputs.at(i).weights().count() != 1) { + return tensorflow::errors::InvalidArgument( + "Inconsistent batchnorm parameter count, at: " + node_def.name()); + } + } + // We could technically have two weights with different shape. + // that requires two addScale op, arguably less performant + TRT_ShapedWeights combined_scale_weights = + ctx.get_temp_weights_like(*ptr_shape_weights); + TRT_ShapedWeights combined_offset_weights = + ctx.get_temp_weights_like(*ptr_shape_weights); + + const Eigen::half* cast_vals_array[4]; + const float* vals_array[4]; + for (int j = 0; j < 4; j++) { + cast_vals_array[j] = + static_cast(inputs.at(j + 1).weights().GetValues()); + vals_array[j] = + static_cast(inputs.at(j + 1).weights().GetValues()); + } + Eigen::half* cast_combined_scale_vals = const_cast( + static_cast(combined_scale_weights.GetValues())); + Eigen::half* cast_combined_offset_vals = const_cast( + static_cast(combined_offset_weights.GetValues())); + float* combined_scale_vals = const_cast( + static_cast(combined_scale_weights.GetValues())); + float* combined_offset_vals = const_cast( + static_cast(combined_offset_weights.GetValues())); + + for (size_t i = 0; i < nweight; ++i) { + float batchnorm_data[4]; + for (int j = 0; j < 4; j++) { + if (inputs.at(j + 1).weights().count() != 1) { + if (parameter_type == tensorflow::DT_FLOAT) { + batchnorm_data[j] = vals_array[j][i]; + } else if (parameter_type == tensorflow::DT_HALF) { + batchnorm_data[j] = + Eigen::half_impl::half_to_float(cast_vals_array[j][i]); + } + } else { + if (parameter_type == tensorflow::DT_FLOAT) { + batchnorm_data[j] = vals_array[j][0]; + } else if (parameter_type == tensorflow::DT_HALF) { + batchnorm_data[j] = + Eigen::half_impl::half_to_float(cast_vals_array[j][0]); + } + } + } + float scale = batchnorm_data[0]; + float offset = batchnorm_data[1]; + float mean = batchnorm_data[2]; + float variance = batchnorm_data[3]; + float combined_scale_val = scale / sqrtf(variance + epsilon); + float combined_offset_val = offset - mean * combined_scale_val; + if (parameter_type == tensorflow::DT_FLOAT) { + combined_scale_vals[i] = combined_scale_val; + combined_offset_vals[i] = combined_offset_val; + } else if (parameter_type == tensorflow::DT_HALF) { + cast_combined_scale_vals[i] = Eigen::half(combined_scale_val); + cast_combined_offset_vals[i] = Eigen::half(combined_offset_val); + } + } + + nvinfer1::ScaleMode mode = nweight == 1 ? nvinfer1::ScaleMode::kUNIFORM + : nvinfer1::ScaleMode::kCHANNEL; + nvinfer1::IScaleLayer* layer = + ctx.network()->addScale(*const_cast(tensor), mode, + combined_offset_weights.GetWeightsForTRT(), + combined_scale_weights.GetWeightsForTRT(), + dummy_power_weights.GetWeightsForTRT()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertMatMul(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + + // TODO(jie): transpose! + TFAttrs attrs(node_def); + + TRT_ShapedWeights weights_ck = inputs.at(1).weights(); + TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck); + ReorderCKtoKC(weights_ck, &weights); + TRT_ShapedWeights biases(weights.type_); + + int noutput = weights.shape_.d[0]; + + nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected( + *const_cast(tensor), noutput, weights, biases); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertReshape( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at" + node_def.name()); + + // implement tensor binaryOp weight [channel wise] for now; + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + auto dims = tensor->getDimensions(); + // restore implicit batch dimension + + TRT_ShapedWeights shape = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + + auto padding_type = attrs.get("Tshape"); + + if (shape.shape_.nbDims != 1) + return tensorflow::errors::InvalidArgument( + "reshape new shape is not 1 dimensional, at " + node_def.name()); + + // Only expect to handle INT32 as attributes for now + if (padding_type != tensorflow::DataType::DT_INT32) + return tensorflow::errors::Unimplemented( + "reshape new shape supports only DT_INT32, at " + node_def.name()); + + auto shape_data = static_cast(const_cast(shape.GetValues())); + + if (shape_data[0] != -1) + return tensorflow::errors::InvalidArgument( + "reshape new shape first dimension is not -1, at " + node_def.name()); + + auto shape_num_dims = shape.shape_.d[0]; + VLOG(2) << "shape dimensions: " << shape_num_dims; + int volume_w = 1; + for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i]; + + int volume_t = 1; + for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i]; + + VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w; + if (volume_w != volume_t) + return tensorflow::errors::InvalidArgument( + "volume does not agree between tensor and new shape, at " + + node_def.name()); + + nvinfer1::IShuffleLayer* layer = + ctx.network()->addShuffle(*const_cast(tensor)); + + nvinfer1::Dims reshape_dims; + VLOG(2) << "new dimension: " << shape_num_dims - 1; + reshape_dims.nbDims = shape_num_dims - 1; + for (int32_t i = 0; i < reshape_dims.nbDims; ++i) { + reshape_dims.d[i] = shape_data[i + 1]; + } + layer->setReshapeDimensions(reshape_dims); + VLOG(2) << "new dimension: " << shape_num_dims - 1; + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + auto dims_output = output_tensor->getDimensions(); + VLOG(2) << "output tensor dimension:" << dims_output.nbDims; + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + void Converter::register_op_converters() { // vgg_16 slim implementation op_registry_["Placeholder"] = ConvertPlaceholder; op_registry_["Conv2D"] = ConvertConv2D; + op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; + op_registry_["AvgPool"] = ConvertPool; // This could be really handled as ConvertBinary op_registry_["BiasAdd"] = ConvertScale; op_registry_["Const"] = ConvertConst; - // op_registry_["MatMul"] = ConvertFullyConnected; // Not used in vgg // TODO(ben,jie): this is a temp hack. op_registry_["Identity"] = ConvertIdentity; // Identity should be removed - // op_registry_["AvgPool"] = ConvertPool; + op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed // resnet_50_v1 slim implementation op_registry_["Add"] = ConvertBinary; @@ -1393,26 +2065,373 @@ void Converter::register_op_converters() { op_registry_["Mean"] = ConvertReduce; op_registry_["Pad"] = ConvertPad; // TODO(ben,jie): Add more ops + + op_registry_["ConcatV2"] = ConvertConcat; + op_registry_["MatMul"] = ConvertMatMul; + op_registry_["Reshape"] = ConvertReshape; + op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; + op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; } } // namespace +tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) { + return tensorflow::errors::Unimplemented("Not implemented yet"); +} +tensorflow::Status ConvertCalibrationNodeToEngineNode( + tensorflow::Graph& graph, tensorflow::Node* c_node) { + const auto ndef = c_node->def(); + + TFAttrs attrs(ndef); + std::vector segment_nodes( + attrs.get>("segment_nodes")); + std::vector output_nodes( + attrs.get>("segment_output_names")); + std::vector input_names( + attrs.get>("input_names")); + string res_name = attrs.get("resource_name"); + VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name; + string engine_name = "my_trt_op"; + { + const auto node_id = tensorflow::str_util::Split(res_name, "_"); + engine_name += node_id.back(); + } + std::map node_maps; + + for (auto n : graph.op_nodes()) { + node_maps.insert({n->name(), n}); + } + VLOG(1) << "Output Nodes:"; + std::vector out_types; + std::vector out_edges; + for (auto& i : output_nodes) { + auto node_port = tensorflow::str_util::Split(i, ":"); + VLOG(1) << " " << i << " in graph " << node_maps.count(i); + auto out_node_name = node_port.at(0); + if (node_port.size() > 1) { + VLOG(1) << "Multi port output" << node_port.at(0) << " " + << node_port.at(1) << " size=" << node_port.size(); + } + auto node_it = node_maps.find(out_node_name); + if (node_it != node_maps.end()) { + tensorflow::Node* out_node = node_it->second; + int port = 0; + if (node_port.size() == 2) { + port = std::strtoul(node_port.at(1).c_str(), nullptr, 10); + out_types.push_back(out_node->output_type(port)); + } else { + out_types.push_back(out_node->output_type(0)); + } + for (auto out_edge : out_node->out_edges()) { + if (out_edge->src_output() == port) { + out_edges.push_back(out_edge); + break; + } + } + } else { + LOG(WARNING) << " couldn't find output node " << out_node_name; + } + } + VLOG(1) << "Input Nodes:"; + for (auto& i : input_names) { + VLOG(1) << " " << i << " in graph " << node_maps.count(i); + } + auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); + auto resmgr = trt_rm->getManager("TRTCalibOps"); + tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; + auto status = resmgr->Lookup(res_name, res_name, &calib_res); + if (!status.ok() || !calib_res->calibrator_) { + return tensorflow::errors::FailedPrecondition( + "You must run calibration" + " and inference conversion in the same proces"); + } + + calib_res->calibrator_->setDone(); + calib_res->thr_->join(); + delete calib_res->thr_; + if (!calib_res->engine_) { + LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run " + "calibration graph?"; + return tensorflow::errors::FailedPrecondition( + "Calibration graph needs to be executed on" + " calibration data before convertsion to inference graph"); + } + auto weight_rmgr = trt_rm->getManager("WeightStore"); + TF_CHECK_OK(weight_rmgr->Delete( + res_name, res_name)); + auto engine_plan = calib_res->engine_->serialize(); + calib_res->engine_->destroy(); + calib_res->network_->destroy(); + calib_res->builder_->destroy(); + calib_res->thr_ = nullptr; + calib_res->engine_ = nullptr; + calib_res->builder_ = nullptr; + tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); + std::vector income_edges; + for (const auto in_edge : c_node->in_edges()) { + auto src = in_edge->src(); + int dest_port = in_edge->dst_input(); + income_edges.emplace_back(src->name(), in_edge->src_output(), + c_node->input_type(dest_port)); + } + tensorflow::gtl::ArraySlice input_list( + income_edges); + op_builder.Input(input_list); + tensorflow::NodeDef engine_node; + const char* engine_plan_data = static_cast(engine_plan->data()); + string engine_plan_string(engine_plan_data, + engine_plan_data + engine_plan->size()); + status = op_builder.Attr("serialized_engine", engine_plan_string) + .Attr("input_nodes", input_names) + .Attr("output_nodes", output_nodes) + .Attr("OutT", out_types) + .Finalize(&engine_node); + if (!status.ok()) { + LOG(ERROR) << "Engine Node creation failed"; + return status; + } + auto trt_engine_node = graph.AddNode(engine_node, &status); + TF_RETURN_IF_ERROR(status); + for (size_t i = 0; i < out_edges.size(); i++) { + VLOG(1) << "Connecting trt_engine_node output " << i << " with " + << out_edges.at(i)->dst()->name() << " port " + << out_edges.at(i)->dst_input(); + TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i, + out_edges.at(i)->dst(), + out_edges.at(i)->dst_input())); + } + VLOG(1) << "Segment nodes:"; + for (auto& i : segment_nodes) { + VLOG(1) << " " << i << " in graph " << node_maps.count(i); + auto it = node_maps.find(i); + if (it != node_maps.end()) { + graph.RemoveNode(it->second); + } + } + graph.RemoveNode(c_node); + return tensorflow::Status::OK(); +} + +tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { + // Visit nodes in reverse topological order and construct the TRT network. + + // Toposort + std::vector order_vec; + tensorflow::GetPostOrder(s.graph, &order_vec); + // Select just the subgraph + std::list order; + for (tensorflow::Node* node : order_vec) { + if (s.subgraph_node_ids.count(node->id())) { + order.push_front(node); // we want topological order to construct the + // network layer by layer + } + } + // topological order is needed to build TRT network + static int static_id = 0; + string subgraph_name_scope; + if (!order.empty()) { + subgraph_name_scope = order.front()->name(); + } + for (const tensorflow::Node* node : order) { + subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name()); + } + // TODO(sami,ben,jie): proper naming! + string calib_op_name = + StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id); + string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id); + static_id++; + auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); + auto op_rmgr = trt_rmgr->getManager("TRTCalibOps"); + auto op_res = new tensorflow::tensorrt::TRTCalibrationResource(); + TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res)); + op_res->logger_ = new tensorflow::tensorrt::Logger(); + op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_)); + + if (!op_res->builder_) { + return tensorflow::errors::Internal( + "failed to create TensorRT builder object"); + } + + op_res->network_ = op_res->builder_->createNetwork(); + if (!op_res->network_) { + return tensorflow::errors::Internal( + "failed to create TensorRT network object"); + } + + // Build the network + auto weight_rmgr = trt_rmgr->getManager("WeightStore"); + auto ws = new tensorflow::tensorrt::TRTWeightStore(); + TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws)); + Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE); + std::vector input_names; + std::vector input_dtypes; + for (const std::pair& input : s.input_inds) { + VLOG(2) << "parsing input. Node id= " << input.first; + int node_id = input.first; + int output_idx = input.second; + tensorflow::Node* node = s.graph.FindNodeId(node_id); + auto node_name = node->name(); + input_names.push_back(node_name); // insert original node name without port + // TODO(jie): alternative :) + if (!s.graph_properties.HasOutputProperties(node_name)) + return tensorflow::errors::Internal("failed to find input node: " + + node_name); + + auto op_info_vec = s.graph_properties.GetOutputProperties(node_name); + if (static_cast(op_info_vec.size()) < output_idx) + return tensorflow::errors::Internal( + "accessing output index of: ", output_idx, ", at node: ", node_name, + "with output entry from shape_map: ", op_info_vec.size()); + + auto op_info = op_info_vec.at(output_idx); + + tensorflow::DataType tf_dtype = op_info.dtype(); + input_dtypes.push_back(tf_dtype); + + nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); + auto type_status = ConvertDType(tf_dtype, &dtype); + if (type_status != tensorflow::Status::OK()) { + LOG(WARNING) << "Data type conversion for input '" << node_name + << "' failed"; + return type_status; + } + TF_CHECK_OK(ConvertDType(tf_dtype, &dtype)); + + VLOG(2) << "accessing output index of: " << output_idx + << ", at node: " << node_name + << "with output entry from shape_map: " << op_info_vec.size(); + + // TODO(ben,jie): update TRT input format/dimension + nvinfer1::DimsCHW input_dim_psuedo_chw; + for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1; + + for (int i = 1; i < op_info.shape().dim_size(); i++) { + VLOG(2) << "dimension: " << i + << " , size: " << op_info.shape().dim(i).size(); + input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size(); + } + + // TODO(ben,jie): proper way to restore input tensor name? + auto input_tensor_name = node_name; + if (output_idx != 0) input_tensor_name = StrCat(node_name, ":", output_idx); + + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + input_tensor_name.c_str(), dtype, input_dim_psuedo_chw); + + if (!input_tensor) + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer"); + VLOG(2) << "input tensor name :" << input_tensor_name; + + if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) + return tensorflow::errors::AlreadyExists( + "output tensor already exists for op: " + input_tensor_name); + } + + VLOG(2) << "finished sorting"; + + for (const tensorflow::Node* node : order) { + const tensorflow::NodeDef& node_def = node->def(); + VLOG(2) << "converting node: " << node_def.name() << " , " << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); + } + + VLOG(2) << "finished conversion"; + + // Gather output metadata + std::vector output_names; + std::vector output_dtypes; + int trt_engine_op_output_idx = 0; + for (const std::pair& output : s.output_inds) { + int node_id = output.first; + int output_idx = output.second; + tensorflow::Node* node = s.graph.FindNodeId(node_id); + string op_name = node->name(); + string tensor_name = op_name; + + s.output_edge_map->insert( + {trt_engine_op_output_idx == 0 + ? engine_name + : StrCat(engine_name, ":", trt_engine_op_output_idx), + {output_idx, tensor_name}}); + trt_engine_op_output_idx++; + if (output_idx != 0) { + tensor_name = StrCat(tensor_name, ":", output_idx); + } + VLOG(1) << "output tensor name: " << tensor_name; + output_names.push_back(tensor_name); + auto tensor_or_weights = converter.get_tensor(tensor_name); + if (!tensor_or_weights.is_tensor()) { + return tensorflow::errors::InvalidArgument("Output node'" + tensor_name + + "' is weights not tensor"); + } + nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + if (!tensor) { + return tensorflow::errors::NotFound("Output tensor not found: " + + tensor_name); + } + converter.network()->markOutput(*tensor); + tensorflow::DataType tf_dtype = node->output_type(output_idx); + output_dtypes.push_back(tf_dtype); + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; + TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); + tensor->setType(trt_dtype); + } + + VLOG(2) << "finished output"; + + // Build the engine + op_res->builder_->setMaxBatchSize(s.max_batch_size); + op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes); + + // Build the TRT op + // TODO(sami,ben,jie): proper naming! + tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); + std::vector income_edges; + for (size_t i = 0; i < input_names.size(); ++i) { + int output_idx = s.input_inds.at(i).second; + // we wired up the input here already, it is redundant to do it again in + // ConvertSubGraphToTensorRT(convert_graph.cc) + auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( + input_names.at(i), output_idx, input_dtypes.at(i)); + VLOG(1) << calib_op_name << " input " << i << " = " << input_names.at(i) + << ":" << output_idx + << " dType= " << tensorflow::DataTypeString(input_dtypes.at(i)); + income_edges.push_back(incoming_edge); + } + tensorflow::gtl::ArraySlice input_list( + income_edges); + op_builder.Input(input_list); + std::vector segment_names; + segment_names.reserve(s.subgraph_node_ids.size()); + for (int i : s.subgraph_node_ids) { + auto node = s.graph.FindNodeId(i); + segment_names.push_back(node->name()); + } + LOG(INFO) << "finished op preparation"; + + auto status = op_builder.Attr("segment_nodes", segment_names) + .Attr("input_names", input_names) + .Attr("segment_output_names", output_names) + .Attr("resource_name", calib_op_name) + .Finalize(s.trt_node); + + LOG(INFO) << status.ToString(); + LOG(INFO) << "finished op building"; + + return tensorflow::Status::OK(); +} tensorflow::Status ConvertSubGraphToTensorRTNodeDef( - const tensorflow::Graph& graph, const std::set& subgraph_node_ids, - const std::vector>& input_inds, - const std::vector>& output_inds, size_t max_batch_size, - size_t max_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& graph_properties, - tensorflow::NodeDef* trt_node) { + tensorrt::convert::SubGraphParams& s) { // Visit nodes in reverse topological order and construct the TRT network. // Toposort std::vector order_vec; - tensorflow::GetPostOrder(graph, &order_vec); + tensorflow::GetPostOrder(s.graph, &order_vec); // Select just the subgraph std::list order; for (tensorflow::Node* node : order_vec) { - if (subgraph_node_ids.count(node->id())) { + if (s.subgraph_node_ids.count(node->id())) { // We want topological order to contstruct the // network layer by layer order.push_front(node); @@ -1434,46 +2453,94 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( "Failed to create TensorRT network object"); } + string subgraph_name_scope; + if (!order.empty()) { + subgraph_name_scope = order.front()->name(); + } + for (const tensorflow::Node* node : order) { + subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name()); + } + static int static_id = 0; + // TODO(sami,ben,jie): proper naming! + string engine_name = StrCat(subgraph_name_scope, "my_trt_op"); + engine_name = StrCat(engine_name, static_id++); + auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); + auto weight_rmgr = trt_rmgr->getManager("WeightStore"); + auto ws = new tensorflow::tensorrt::TRTWeightStore(); + TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws)); + // Build the network - Converter converter(trt_network.get()); + Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE); std::vector input_names; std::vector input_dtypes; - for (std::pair const& input : input_inds) { + for (const std::pair& input : s.input_inds) { + VLOG(2) << "parsing input!!!!!"; int node_id = input.first; int output_idx = input.second; - tensorflow::Node* node = graph.FindNodeId(node_id); + tensorflow::Node* node = s.graph.FindNodeId(node_id); auto node_name = node->name(); - input_names.push_back(node_name); // Insert original node name without port - // TODO(jie): alternative :) - if (!graph_properties.HasOutputProperties(node_name)) - return tensorflow::errors::Internal("Failed to find input node: " + - node_name); + // input_names should use the node name in the graph + // here it should be the input tensor name -> matching the binding + // insert original node name without port + auto tensor_name = node_name; + if (output_idx != 0) { + tensor_name = StrCat(tensor_name, ":", output_idx); + } - auto op_info_vec = graph_properties.GetOutputProperties(node_name); - if (static_cast(op_info_vec.size()) < output_idx) - return tensorflow::errors::Internal( - "Accessing output index of: " + std::to_string(output_idx) + - ", at node: " + node_name + " with output entry from shape_map: " + - std::to_string(op_info_vec.size())); + VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name + << " idx: " << output_idx; - auto op_info = op_info_vec.at(output_idx); + auto shape_inference_node_name = node_name; + auto shape_inference_output_idx = output_idx; + // rewire the shape inference to original node in the graph + if (s.output_edge_map->count(tensor_name)) { + shape_inference_node_name = s.output_edge_map->at(tensor_name).second; + shape_inference_output_idx = s.output_edge_map->at(tensor_name).first; + } + if (shape_inference_output_idx < 0) continue; + VLOG(2) << "shapeinference name: " << shape_inference_node_name + << " idx: " << shape_inference_output_idx; + + if (!s.graph_properties.HasOutputProperties(shape_inference_node_name)) + return tensorflow::errors::Internal("failed to find input node: " + + shape_inference_node_name); + auto op_info_vec = + s.graph_properties.GetOutputProperties(shape_inference_node_name); + if (static_cast(op_info_vec.size()) <= shape_inference_output_idx) + return tensorflow::errors::Internal( + "accessing output index of: ", shape_inference_output_idx, + ", at node: ", shape_inference_node_name, + " with output entry from shape_map: ", op_info_vec.size()); + + auto op_info = op_info_vec.at(shape_inference_output_idx); tensorflow::DataType tf_dtype = op_info.dtype(); input_dtypes.push_back(tf_dtype); nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(ConvertDType(tf_dtype, &dtype)); + auto type_status = ConvertDType(tf_dtype, &dtype); + if (type_status != tensorflow::Status::OK()) { + LOG(WARNING) << "Type conversion failed for " << node_name; + return type_status; + } - VLOG(2) << "Accessing output index of: " << std::to_string(output_idx) + VLOG(2) << "Accessing output index of: " << output_idx << ", at node: " << node_name - << " with output entry from shape_map: " - << std::to_string(op_info_vec.size()); - + << " with output entry from shape_map: " << op_info_vec.size(); // TODO(ben,jie): update TRT input format/dimension nvinfer1::DimsCHW input_dim_psuedo_chw; for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1; + // TODO(jie): TRT 3.x only support 4 dimensional input tensor. + // update the code once TRT 4.0 comes out. + if (op_info.shape().dim_size() != 4) { + string err_str = "Require 4 dimensional input."; + StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ", + shape_inference_node_name); + return tensorflow::errors::Unimplemented(err_str); + } + for (int i = 1; i < op_info.shape().dim_size(); i++) { VLOG(2) << "dimension: " << i << " , size: " << op_info.shape().dim(i).size(); @@ -1482,9 +2549,11 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( // TODO(ben,jie): proper way to restore input tensor name? auto input_tensor_name = node_name; - if (output_idx != 0) - input_tensor_name = node_name + ":" + std::to_string(output_idx); + if (output_idx != 0) { + input_tensor_name = StrCat(node_name, ":", output_idx); + } + input_names.push_back(input_tensor_name); nvinfer1::ITensor* input_tensor = converter.network()->addInput( input_tensor_name.c_str(), dtype, input_dim_psuedo_chw); @@ -1511,20 +2580,28 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( // Gather output metadata std::vector output_names; std::vector output_dtypes; - for (std::pair const& output : output_inds) { + int trt_engine_op_output_idx = 0; + for (const std::pair& output : s.output_inds) { int node_id = output.first; int output_idx = output.second; - tensorflow::Node* node = graph.FindNodeId(node_id); + tensorflow::Node* node = s.graph.FindNodeId(node_id); string op_name = node->name(); string tensor_name = op_name; + + s.output_edge_map->insert( + {trt_engine_op_output_idx == 0 + ? engine_name + : StrCat(engine_name, ":", trt_engine_op_output_idx), + {output_idx, tensor_name}}); + trt_engine_op_output_idx++; if (output_idx != 0) - tensor_name = tensor_name + ":" + std::to_string(output_idx); + tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); VLOG(2) << "Output tensor name: " << tensor_name; output_names.push_back(tensor_name); auto tensor_or_weights = converter.get_tensor(tensor_name); if (!tensor_or_weights.is_tensor()) { - return tensorflow::errors::InvalidArgument( - "Output node is weights not tensor"); + return tensorflow::errors::InvalidArgument("Output node '" + tensor_name + + "' is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); if (!tensor) { @@ -1540,19 +2617,25 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( } VLOG(2) << "Finished output"; - // TODO(jie): static_id is not thread safe. - static int static_id = 0; // Build the engine - trt_builder->setMaxBatchSize(max_batch_size); - trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes); - VLOG(0) << "Starting build engine " << static_id; - // TODO(ben,jie): half2 and int8 mode support + trt_builder->setMaxBatchSize(s.max_batch_size); + trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes); + VLOG(0) << "Max batch size= " << s.max_batch_size + << " max workspace size= " << s.max_workspace_size_bytes; + if (s.precision_mode == FP16MODE) { + trt_builder->setHalf2Mode(true); + VLOG(0) << "Using FP16 precision mode"; + } + LOG(INFO) << "starting build engine"; string engine_plan_string; { auto trt_engine = infer_object(trt_builder->buildCudaEngine(*converter.network())); VLOG(0) << "Built network"; + if (trt_engine.get() == nullptr) { + return tensorflow::errors::Internal("Engine building failure"); + } auto engine_plan = infer_object(trt_engine->serialize()); VLOG(0) << "Serialized engine"; const char* engine_plan_data = @@ -1560,18 +2643,20 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( engine_plan_string = string(engine_plan_data, engine_plan_data + engine_plan->size()); } - - VLOG(0) << "Finished engine"; + TF_RETURN_IF_ERROR(weight_rmgr->Delete( + engine_name, engine_name)); + LOG(INFO) << "finished engine " << engine_name << " containing " + << s.subgraph_node_ids.size() << " nodes"; // Build the TRT op - // TODO(sami,ben,jie): proper naming! - tensorflow::NodeDefBuilder op_builder( - tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp"); + tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); std::vector income_edges; + VLOG(2) << "input edge size: " << input_names.size(); for (size_t i = 0; i < input_names.size(); ++i) { - int output_idx = input_inds.at(i).second; - // We wired up the input here already, it is redundant to do it again in - // ConvertSubGraphToTensorRT(convert_graph.cc) + VLOG(2) << "input edges: " << i << " " << input_names.at(i); + int output_idx = s.input_inds.at(i).second; + // we wired up the input here already, it is redundant to do it again in + // ConvertSubGraphToTensorRT(convert_graph.cc) auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( input_names.at(i), output_idx, input_dtypes.at(i)); income_edges.push_back(incoming_edge); @@ -1586,7 +2671,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( .Attr("input_nodes", input_names) .Attr("output_nodes", output_names) .Attr("OutT", output_dtypes) - .Finalize(trt_node); + .Finalize(s.trt_node); VLOG(0) << status.ToString() << " finished op building"; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 2e7fd19566e1ed3719b932c7443a9c3f652b2d3e..954a1e72f8604371fc00e088a67b4d411314dda6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ #include +#include +#include #include #include @@ -32,16 +34,49 @@ namespace tensorflow { namespace tensorrt { namespace convert { -tensorflow::Status ConvertSubGraphToTensorRTNodeDef( - const tensorflow::Graph& graph, const std::set& subgraph_node_ids, - const std::vector>& - input_inds, // {node_id, output_idx} - const std::vector>& - output_inds, // {node_id, output_idx} - size_t max_batch_size, size_t max_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& graph_prop, - tensorflow::NodeDef* trt_node); +const int FP32MODE = 0; +const int FP16MODE = 1; +const int INT8MODE = 2; +struct SubGraphParams { + SubGraphParams( + tensorflow::Graph& inp_graph, + const std::set& subgraph_node_id_numbers, + const std::vector>& input_indices, + const std::vector>& output_indices, + size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, + const tensorflow::grappler::GraphProperties& current_graph_properties, + std::unordered_map>* output_edges, + tensorflow::NodeDef* constructed_trt_node, + int engine_precision_mode = FP32MODE) + : graph(inp_graph), + subgraph_node_ids(subgraph_node_id_numbers), + input_inds(input_indices), + output_inds(output_indices), + max_batch_size(max_supported_batch_size), + max_workspace_size_bytes(max_consumed_workspace_size_bytes), + graph_properties(current_graph_properties), + output_edge_map(output_edges), + trt_node(constructed_trt_node), + precision_mode(engine_precision_mode) {} + + tensorflow::Graph& graph; + const std::set& subgraph_node_ids; + const std::vector>& input_inds; // {node_id, output_idx} + const std::vector>& output_inds; // {node_id, output_idx} + size_t max_batch_size; + size_t max_workspace_size_bytes; + const tensorflow::grappler::GraphProperties& graph_properties; + std::unordered_map>* output_edge_map; + tensorflow::NodeDef* trt_node; + const int precision_mode; +}; + +// TODO(sami): Replace references with const reference or pointers +tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); +tensorflow::Status InjectCalibrationNode(SubGraphParams& params); +tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, + tensorflow::Node* c_node); } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..aea44fd8a2fcc4c359a6cb0c98ae34711708326e --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/stream_executor.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); + OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); + OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ + } + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(FATAL) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } + } +} + +void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { + // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. + auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibOps"); + tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); + + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + int num_inputs = ctx->num_inputs(); + // first run instantiate calibrator + if (calib_res->calibrator_ == nullptr) { + dev_tensors_.resize(num_inputs); + int batch_size = ctx->input(0).dim_size(0); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + OP_REQUIRES_OK(ctx, + ctx->allocate_persistent(t.dtype(), t.shape(), + &dev_tensors_.at(i), nullptr)); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + device_buffers_.emplace(input_names_.at(i), + std::pair( + device_address, device_tensor->TotalBytes())); + } + + calib_res->calibrator_ = + new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); + string label(resource_name_); + calib_res->thr_ = new std::thread([calib_res, label]() { + VLOG(1) << "Starting calibration thread, Calibration Resource @ " + << calib_res; + calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); + calib_res->builder_->setInt8Mode(true); + calib_res->engine_ = calib_res->builder_->buildCudaEngine( + *calib_res->network_); // will loop until we terminate calibrator + VLOG(1) << "Calibration loop terminated " << label; + }); + VLOG(1) << "initialized calibrator resource"; + } // calibrator initialized + + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), + device_tensor->TotalBytes()); // use the tensor so FW keeps it + input_data.emplace(input_names_.at(i), data_address); + ctx->set_output(i, t); + } + VLOG(2) << "Filled map for sending"; + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(ctx->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + calib_res->calibrator_->setBatch(input_data, *stream); + VLOG(2) << "Passed calibration data"; + // TODO(aaroey): make sure we wait for the completion of calibration on the + // last batch in future PR. +}; + +#undef TYPECASE + +REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); + +} // namespace tensorrt +} // namespace tensorflow +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h new file mode 100644 index 0000000000000000000000000000000000000000..23df9db32f077a080eaff7479fcbe90d6a504c42 --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H +#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H + +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +namespace tensorflow { +namespace tensorrt { +// TODO(sami): Convert this to async kernel! +class TRTCalibOp : public OpKernel { + public: + explicit TRTCalibOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + private: + string resource_name_; + std::vector segment_nodes_; + std::vector input_names_; + std::vector shapes_; + std::unordered_map> device_buffers_; + std::vector dev_tensors_; +}; +} // namespace tensorrt +} // namespace tensorflow +#endif +#endif +#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8efdf63ebebc4d7a199c60635ca64348d2b30505..b32371b642f38b0851955a4a3beab97b86e1f6a0 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -24,8 +24,12 @@ limitations under the License. #include "cuda/include/cuda_runtime_api.h" namespace tensorflow { -namespace tensorrt { static ::tensorflow::tensorrt::Logger logger; +namespace gpu = ::perftools::gputools; +using IRuntime = nvinfer1::IRuntime; +using Dims = nvinfer1::Dims; + +namespace tensorrt { TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { // read serialized_engine @@ -40,10 +44,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { // TODO(samikama) runtime should be taken from a resourcemanager as well. // Only engine should be in the op and context and runtime should be taken // from resourcemanager - nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); + // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same + // gpu where the input/output is also located. + int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id; + cudaSetDevice(gpu_id); + int device; + cudaGetDevice(&device); + if (gpu_id != device) LOG(FATAL) << "set device failed!"; + + // TODO(samikama) runtime should be taken from a resourcemanager as well. + // Only engine should be in the op and context and runtime should be taken + // from resourcemanager + + IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), nullptr)); - trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); @@ -55,7 +70,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) { size_t binding_index; int num_batch = 0; - bool valid = true; for (int i = 0; i < context->num_inputs(); i++) { // Grab the input tensor binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); @@ -64,8 +78,12 @@ void TRTEngineOp::Compute(OpKernelContext* context) { const TensorShape& input_shape = input_tensor.shape(); if (i == 0) { num_batch = input_shape.dim_size(0); + if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { + LOG(FATAL) << "input tensor batch larger than max_batch_size: " + << trt_engine_ptr_->getMaxBatchSize(); + } } else if (num_batch != input_shape.dim_size(0)) { - valid = false; + LOG(FATAL) << "input data inconsistent batch size"; break; } switch (trt_engine_ptr_->getBindingDataType(binding_index)) { @@ -81,9 +99,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) { } } - // Might want a different way to inform the user of batch size inconsistency - if (!valid) LOG(WARNING) << "input data inconsistent batch size"; - for (int i = 0; i < static_cast(output_nodes_.size()); i++) { // This is bad that we have to reallocate output buffer every run. // Create an output tensor @@ -126,9 +141,11 @@ void TRTEngineOp::Compute(OpKernelContext* context) { ->implementation() ->CudaStreamMemberHack())); - // execution handled by TF since we are getting stream from TF. - // it is safe for CPU pointer array (buffers) to go out of scope after enqueue - trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr); + // TODO(jie): trt enqueue does not return error + auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], + *stream, nullptr); + VLOG(2) << "enqueue returns: " << ret; + // sync should be done by TF. } REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc index 7add8cb8b3d2a04206ee4174e79a1a4b86e05f30..dda0dc9e712eb726800abfb6084f4f708d04825b 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.cc +++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc @@ -27,19 +27,19 @@ void Logger::log(Severity severity, const char* msg) { // Suppress info-level messages switch (severity) { case Severity::kINFO: { // Mark TRT info messages as debug! - VLOG(2) << msg; + VLOG(2) << name_ << " " << msg; break; } case Severity::kWARNING: { - LOG(WARNING) << msg; + LOG(WARNING) << name_ << " " << msg; break; } case Severity::kERROR: { - LOG(ERROR) << msg; + LOG(ERROR) << name_ << " " << msg; break; } case Severity::kINTERNAL_ERROR: { - LOG(FATAL) << msg; + LOG(FATAL) << name_ << " " << msg; break; } // This is useless for now. But would catch it in future if enum changes. It diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index d71f66b933a8068a6276a7e070755e0075543bb5..7f3544f8cfda8dce13881e1f8f4388b640e315f4 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -27,9 +27,11 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { - private: + public: + Logger(string name = "DefaultLogger") : name_(name){}; void log(nvinfer1::ILogger::Severity severity, const char* msg) override; + private: string name_; }; diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4835e5065068ec7a59995eb7f6126b31aecf6704 --- /dev/null +++ b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +namespace tensorflow { + +REGISTER_OP("TRTCalibOp") + .Attr("segment_nodes: list(string)") // names of the ops in segment + .Attr("segment_output_names: list(string)") // names of the output ops in + // segment + .Attr("input_names: list(string)") // names of the inputs for + // passing into tensorrt + .Attr("resource_name: string") + .Attr("InT: list({int8, float16, float32})") + .Input("in_tensor: InT") + .Output("out_tensor: InT") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + for (int i = 0; i < c->num_inputs(); i++) { + c->set_output(i, c->input(i)); + } + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 7e050a768ce97af1fc1d2c85cb52640b4c6a6a97..0b2321b5fc7bcbd53c01d1c97cafcfcb229a83ef 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -20,5 +20,6 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 9454862f857ab743712ce409ff007de55e72a68e..338475d90ea55ab2c1bb8df77f27a71a4a36a5dd 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -20,11 +20,17 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl as _impl +from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.util import compat +# pylint: enable=unused-import,line-too-long # TODO(skama): get outputs from session when implemented as c++ @@ -32,22 +38,33 @@ from tensorflow.python.framework import ops def create_inference_graph(input_graph_def, outputs, max_batch_size=1, - max_workspace_size_bytes=2 << 20): - """Python wrapper for the TRT transormation. - + max_workspace_size_bytes=2 << 20, + precision_mode="FP32", + minimum_segment_size=3): + """Python wrapper for the TRT transformation. Args: input_graph_def: GraphDef object containing a model to be transformed. - outputs: List of tensors or node names for the model outputs. + outputs: list of tensors or node names for the model outputs. max_batch_size: max size for the input batch max_workspace_size_bytes: parameter to control memory allocation (in Bytes) + precision_mode: one of 'FP32', 'FP16' and 'INT8' + minimum_segment_size: the minimum number of nodes required for a subgraph to + be replaced by TRTEngineOp. Returns: New GraphDef with TRTEngineOps placed in graph replacing subgraphs. Raises: + ValueError: if the provided precision mode is invalid. RuntimeError: if the returned status message is malformed. """ + supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2} + if precision_mode.upper() not in supported_precision_modes: + raise ValueError(("precision mode '{}' is not supported." + "It should be one of {}").format( + precision_mode, "{'FP32', 'FP16', 'INT8'}")) + mode = supported_precision_modes[precision_mode.upper()] def py2bytes(inp): return inp @@ -83,7 +100,7 @@ def create_inference_graph(input_graph_def, # pair or strings where first one is encoded status and the second # one is the transformed graphs protobuf string. out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes) + max_workspace_size_bytes, mode, minimum_segment_size) status = to_string(out[0]) output_graph_def_string = out[1] del input_graph_def_str # Save some memory @@ -101,3 +118,46 @@ def create_inference_graph(input_graph_def, output_graph_def.ParseFromString(output_graph_def_string) del output_graph_def_string # Save some memory return output_graph_def + + +def calib_graph_to_infer_graph(calibration_graph_def): + """Convert an existing calibration graph to inference graph. + + Args: + calibration_graph_def: the calibration GraphDef object with calibration data + Returns: + New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. + Raises: + RuntimeError: if the returned status message is malformed. + """ + + def py2string(inp): + return inp + + def py3string(inp): + return inp.decode("utf-8") + + if _six.PY2: + to_string = py2string + else: + to_string = py3string + + graph_str = calibration_graph_def.SerializeToString() + out = calib_convert(graph_str) + status = to_string(out[0]) + output_graph_def_string = out[1] + del graph_str # Save some memory + if len(status) < 2: + raise _impl.UnknownError(None, None, status) + if status[:2] != "OK": + msg = status.split(";") + if len(msg) == 1: + raise RuntimeError("Status message is malformed {}".format(status)) + # pylint: disable=protected-access + raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), + int(msg[0])) + # pylint: enable=protected-access + output_graph_def = graph_pb2.GraphDef() + output_graph_def.ParseFromString(output_graph_def_string) + del output_graph_def_string # Save some memory + return output_graph_def diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc7c93f869f5ef7c8eaa2a87eed26cfe69597fdb --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -0,0 +1,129 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" + +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" + +namespace tensorflow { +namespace tensorrt { + +// set the batch size before constructing the thread to execute engine +int TRTInt8Calibrator::getBatchSize() const { return batch_size_; } + +TRTInt8Calibrator::TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name) + : batch_size_(batch_size), + done_(false), + dev_buffers_(dev_buffers), + calib_running_(false), + batch_is_set_(false), + engine_name_(engine_name) {} + +bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, + const cudaStream_t stream) { + tensorflow::mutex_lock lock(cond_mtx_); + while ((calib_running_ || batch_is_set_) && + !done_) { // wait while calibration is running + cond_.wait(lock); + } + if (done_) return false; + CHECK(!calib_running_ && !batch_is_set_); + VLOG(1) << "Set Batch Waiting finished"; + for (const auto it : data) { + auto devptr = dev_buffers_.find(it.first); + if (devptr == dev_buffers_.end()) { + LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first + << "' does not match with the buffer names"; + } + const auto& d = devptr->second; + + // TODO(aaroey): we should not use sync copy on default stream. Make sure + // stream->ThenMemcpy() is used in future PRs. + // TODO(sami,aaroey): Need to figure out a way to ensure synchronization + // between stream, perhaps using a tensor? + auto status = cudaMemcpyAsync(d.first, it.second, d.second, + cudaMemcpyDeviceToDevice, stream); + if (status != cudaSuccess) { + LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first + << "' failed with " << status; + } + } + + // TODO(Sami, aaorey): Find an alternative way! + cudaStreamSynchronize( + stream); // we have to wait for the stream before returning! + batch_is_set_ = true; + cond_.notify_all(); + return true; +} + +bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, + int num_bindings) { + tensorflow::mutex_lock lock(cond_mtx_); + calib_running_ = false; + cond_.notify_all(); + while ((!batch_is_set_ && !done_)) { // wait until new batch arrives + cond_.wait(lock); + + } + if (done_) { + return false; + } + + for (int i = 0; i < num_bindings; i++) { + auto it = dev_buffers_.find(names[i]); + if (it == dev_buffers_.end()) { + LOG(FATAL) << "Calibration engine asked for unknown tensor name '" + << names[i] << "' at position " << i; + } + + bindings[i] = it->second.first; + } + batch_is_set_ = false; + calib_running_ = true; + return true; +} + +const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { + return nullptr; +} + +void TRTInt8Calibrator::setDone() { + tensorflow::mutex_lock lock(cond_mtx_); + done_ = true; + cond_.notify_all(); +} + +void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, + std::size_t length) {} +TRTInt8Calibrator::~TRTInt8Calibrator() { + VLOG(1) << "Destroying calibrator for " << engine_name_; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h new file mode 100644 index 0000000000000000000000000000000000000000..d77aa2c5ab184756adaee38f88180b3c128ebe03 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ + +#include +#include +#include +#include +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "cuda/include/cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +// This class provides a 1 element queue to match TFs push model to +// TRTs pull model for calibration. When TRT implements a means for +// a push calibration This class should be updated accordingly + +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { + public: + TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], + int num_bindings) override; + bool setBatch(const std::unordered_map& data, + const cudaStream_t stream); + void setDone(); + const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; + ~TRTInt8Calibrator(); + + private: + const int batch_size_; + tensorflow::mutex cond_mtx_; // mutex for condition_variable + tensorflow::condition_variable cond_; // condition variable to implement + // producer-consumer queue for + // calibration + bool done_; + const std::unordered_map> + dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with + // buffer names + bool calib_running_; + bool batch_is_set_; + string engine_name_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif +#endif +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..e663eed4dd6704e2f41bde1dfabd411e86669ecd --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace tensorrt { + +std::shared_ptr +tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { + // mutex is held for lookup only. Most instantiations where mutex will be held + // longer will be during op creation and should be ok. + tensorflow::mutex_lock lock(map_mutex_); + auto s = managers_.find(op_name); + if (s == managers_.end()) { + auto it = managers_.emplace( + op_name, std::make_shared(op_name)); + VLOG(1) << "Returning a new manager " << op_name; + return it.first->second; + } + VLOG(1) << "Returning old manager " << op_name; + return s->second; +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..5f8ad491d3c13e8911b0b95c3e95e19afe4d59c0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#include + +#include +#include +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace tensorrt { + +class TRTResourceManager { + TRTResourceManager() = default; + + public: + static std::shared_ptr instance() { + static std::shared_ptr instance_( + new TRTResourceManager); + return instance_; + } + // returns a manager for given op, if it doesn't exists it creates one + std::shared_ptr getManager(const string& op_name); + + private: + std::unordered_map> + managers_; + tensorflow::mutex map_mutex_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCE_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h new file mode 100644 index 0000000000000000000000000000000000000000..3c85968ae7acf5c5fc567be6805a5d226b1094c7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_ + +#include +#include +#include +#include +#include +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/framework/resource_mgr.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +class TRTCalibrationResource : public tensorflow::ResourceBase { + public: + TRTCalibrationResource() + : calibrator_(nullptr), + builder_(nullptr), + network_(nullptr), + engine_(nullptr), + logger_(nullptr), + thr_(nullptr) {} + string DebugString() override { + std::stringstream oss; + oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl + << " Builder = " << std::hex << builder_ << std::dec << std::endl + << " Network = " << std::hex << network_ << std::dec << std::endl + << " Engine = " << std::hex << engine_ << std::dec << std::endl + << " Logger = " << std::hex << logger_ << std::dec << std::endl + << " Thread = " << std::hex << thr_ << std::dec << std::endl; + return oss.str(); + } + ~TRTCalibrationResource() { + VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + } + TRTInt8Calibrator* calibrator_; + nvinfer1::IBuilder* builder_; + nvinfer1::INetworkDefinition* network_; + nvinfer1::ICudaEngine* engine_; + tensorflow::tensorrt::Logger* logger_; + // TODO(sami): Use threadpool threads! + std::thread* thr_; +}; + +class TRTWeightStore : public tensorflow::ResourceBase { + public: + TRTWeightStore() {} + std::list> store_; + string DebugString() override { + std::stringstream oss; + size_t lenBytes = 0; + for (const auto& v : store_) { + lenBytes += v.size() * sizeof(uint8_t); + } + oss << " Number of entries = " << store_.size() << std::endl + << " Total number of bytes = " + << store_.size() * sizeof(std::vector) + lenBytes << std::endl; + return oss.str(); + } + virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } +}; + +class TRTEngineResource : public tensorflow::ResourceBase { + public: + TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){}; + string DebugString() override { return string(""); } + nvinfer1::IRuntime* runtime_; + nvinfer1::IExecutionContext* ctx_; +}; + +} // namespace tensorrt +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCEMGR_TRTRESOURCES_H_ +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 6193f0b0a13f6985d5fc8dd4c6fc09b15f72f139..8fc4697c513057c668d31a341cb13f60dc107e81 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -80,13 +80,20 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, std::vector in_edges(dst->in_edges().begin(), dst->in_edges().end()); for (const tensorflow::Edge* in_edge : in_edges) { - if (in_edge->src() != src) { - tensorflow::Edge* e = const_cast(in_edge); - if (e->src() == graph->source_node()) { - graph->AddEdge(e->src(), e->src_output(), src, - tensorflow::Graph::kControlSlot); - } else { - graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); + if (in_edge->IsControlEdge()) { + if (in_edge->src() != src) { + tensorflow::Edge* e = const_cast(in_edge); + graph->AddControlEdge(e->src(), src); + } + } else { + if (in_edge->src() != src) { + tensorflow::Edge* e = const_cast(in_edge); + if (e->src() == graph->source_node()) { + graph->AddEdge(e->src(), e->src_output(), src, + tensorflow::Graph::kControlSlot); + } else { + graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); + } } } } @@ -94,12 +101,19 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, std::vector out_edges(dst->out_edges().begin(), dst->out_edges().end()); for (const tensorflow::Edge* out_edge : out_edges) { - tensorflow::Edge* e = const_cast(out_edge); - if (e->dst() == graph->sink_node()) { - graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), - e->dst_input()); + if (out_edge->IsControlEdge()) { + tensorflow::Edge* e = const_cast(out_edge); + graph->AddControlEdge(src, e->dst()); } else { - graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); + tensorflow::Edge* e = const_cast(out_edge); + if (e->dst() == graph->sink_node()) { + VLOG(1) << " edge to sink node " << src->name() << " -> " + << e->dst()->name(); + graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), + e->dst_input()); + } else { + graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); + } } } @@ -118,7 +132,7 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, - const std::function& candidate_fn, + const std::function& candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { // Create a Graph representation of the GraphDef. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), @@ -136,7 +150,7 @@ tensorflow::Status SegmentGraph( for (int i = 0; i < graph.num_node_ids(); ++i) { tensorflow::Node* node = graph.FindNodeId(i); if (options.exclude_node_list.count(node->name()) != 0 || - !candidate_fn(node->def())) { + !candidate_fn(node)) { node = nullptr; } node_segments.emplace_back(node); @@ -155,7 +169,7 @@ tensorflow::Status SegmentGraph( for (const tensorflow::Node* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name(); + VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { @@ -169,8 +183,12 @@ tensorflow::Status SegmentGraph( while (true) { std::set contract_edges; for (const tensorflow::Edge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name(); - + VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " + << out_edge->dst()->id() << " <- " << node->id() << " )"; + if (out_edge->IsControlEdge()) { + VLOG(2) << "... ... Control Edge, Skipping"; + continue; + } // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { VLOG(2) << "... ... not a TRT candidate"; @@ -196,7 +214,8 @@ tensorflow::Status SegmentGraph( const tensorflow::Node* src = contract_edge->src(); const tensorflow::Node* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name(); + VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" + << src->id() << " <- " << dst->id(); node_segments[src->id()].Merge(&node_segments[dst->id()]); // Contracting the edge leaves disconnected graph edges. diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index ee6e2b3ed26cd1fabc0e952d882d549046cd9a30..7e8685f44a8c8a20fd7159ee40a8835531e78e9f 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -20,10 +20,12 @@ limitations under the License. #include #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { + namespace tensorrt { namespace segment { @@ -46,7 +48,7 @@ struct SegmentOptions { // @return the status. tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, - const std::function& candidate_fn, + const std::function& candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); } // namespace segment diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 74cbc5f2b376b76324eed06d251767da6f928e3e..7ddabec268d4ef7b5c679001e5fb99aa7d83aec0 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -35,7 +35,7 @@ class SegmentTest : public ::testing::Test { TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name); - std::function MakeCandidateFn( + std::function MakeCandidateFn( const std::set& node_names); protected: @@ -60,10 +60,10 @@ bool SegmentTest::GetGraphDef(TF_Graph* graph, return ret; } -std::function SegmentTest::MakeCandidateFn( +std::function SegmentTest::MakeCandidateFn( const std::set& node_names) { - return [node_names](const NodeDef& node) -> bool { - return node_names.find(node.name()) != node_names.end(); + return [node_names](const Node* node) -> bool { + return node_names.find(node->name()) != node_names.end(); }; } diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index c78f6f222457a875525e768eacc9a4ebf28ad504..ad01bedd8fa066e914b05b20dbc47d9aabe790d9 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -60,6 +60,7 @@ def get_simple_graph_def(): def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() @@ -74,15 +75,65 @@ def run_graph(gdef, dumm_inp): return val +# Use real data that is representative of the inference dataset +# for calibration. For this test script it is random data. +def run_calibration(gdef, dumm_inp): + """Run given calibration graph multiple times.""" + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + with csess.Session( + config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + # run over real calibration data here, we are mimicking a calibration set of + # 30 different batches. Use as much calibration data as you want + for _ in range(30): + val = sess.run(out, {inp: dumm_inp}) + return val + + if "__main__" in __name__: inp_dims = (100, 24, 24, 2) dummy_input = np.random.random_sample(inp_dims) - gdef = get_simple_graph_def() + orig_graph = get_simple_graph_def() # use a frozen graph for inference # Get optimized graph - trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0]) - o1 = run_graph(gdef, dummy_input) + trt_graph = trt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" + minimum_segment_size=2 # minimum number of nodes in an engine + ) + o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) assert np.array_equal(o1, o2) assert np.array_equal(o3, o2) # sanity check + fp16_graph = trt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" + minimum_segment_size=2 # minimum number of nodes in an engine + ) + int8_calib_gdef = trt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" + minimum_segment_size=2 # minimum number of nodes in an engine + ) + o4 = run_graph(fp16_graph, dummy_input) + _ = run_calibration(int8_calib_gdef, dummy_input) + int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) + o5 = run_graph(int8_graph, dummy_input) + assert np.allclose(o1, o4) + assert np.allclose(o1, o5) print("Pass") diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d679945d569c1784448b6cb09c2f431b9cda56d7..46480e99a113afb34702b0ecd71468d4bdc83f98 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -64,13 +64,17 @@ PyObject* pair_helper(std::pair* in) { %ignoreall %unignore tensorflow; %unignore trt_convert; +%unignore calib_convert; %{ + std::pair trt_convert( string graph_def_string, // The serialized GraphDef string. std::vector output_names, size_t max_batch_size, - size_t max_workspace_size_bytes + size_t max_workspace_size_bytes, + int precision_mode, + int minimum_segment_size // Unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -90,16 +94,64 @@ std::pair trt_convert( return std::pair{out_status, ""}; } + if(precision_mode < 0 || precision_mode > 2){ + out_status = "InvalidArgument;Invalid precision_mode"; + return std::pair{out_status, ""}; + } if (!output_names.size()) { out_status = "InvalidArgument;Size of the output_names vector is 0"; return std::pair{out_status, ""}; - // return ""; } tensorflow::GraphDef outGraph; tensorflow::Status conversion_status = tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &outGraph); + &outGraph, precision_mode, minimum_segment_size); + if (!conversion_status.ok()) { + auto retCode = (int)conversion_status.code(); + char buff[2000]; + snprintf(buff, 2000, "%d;%s", retCode, + conversion_status.error_message().c_str()); + out_status = buff; + return std::pair{out_status, ""}; + } + string result; + if (!outGraph.SerializeToString(&result)) { + out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; + return std::pair{out_status, ""}; + } + out_status = "OK;All good!"; + return std::pair{out_status, result}; +#else + // Returns FAILED_PRECONDITION. + return std::pair{"9;TensorRT is not enabled!", ""}; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +} + +std::pair calib_convert(string graph_def_string // const tensorflow::GraphDef& + // unfortunately we can't use TF_Status here since it + // is in c/c_api and brings in a lot of other libraries + // which in turn declare ops. These ops are included + // statically in our library and cause an abort when + // module is loaded due to double registration + // until Tensorflow properly exposes these headers + // we have to work around this by returning a string + // and converting it to exception on python side. + //,TF_Status* out_status) { +) { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + string out_status; + + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(graph_def_string)) { + out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; + return std::pair{out_status, ""}; + } + + tensorflow::GraphDef outGraph; + tensorflow::Status conversion_status = + tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, + &outGraph); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -122,10 +174,13 @@ std::pair trt_convert( } %} +std::pair calib_convert(string graph_def_string); + std::pair trt_convert(string graph_def_string, std::vector output_names, size_t max_batch_size, - size_t max_workspace_size_bytes); + size_t max_workspace_size_bytes, + int precision_mode, int minimum_segment_size); %unignoreall diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py index f2065c666255984c8ab770fc10f682b1eabad095..15a415df303df5be44e89c00005cb253ae2af286 100644 --- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py +++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import summary_pb2 +from tensorflow.python.framework import test_util from tensorflow.python.summary.writer import writer from tensorflow.python.summary.writer import writer_cache @@ -85,7 +86,11 @@ class FakeSummaryWriter(object): if expected_added_graphs is not None: test_case.assertEqual(expected_added_graphs, self._added_graphs) if expected_added_meta_graphs is not None: - test_case.assertEqual(expected_added_meta_graphs, self._added_meta_graphs) + test_case.assertEqual(len(expected_added_meta_graphs), + len(self._added_meta_graphs)) + for expected, actual in zip(expected_added_meta_graphs, + self._added_meta_graphs): + test_util.assert_meta_graph_protos_equal(test_case, expected, actual) if expected_session_logs is not None: test_case.assertEqual(expected_session_logs, self._added_session_logs) diff --git a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv index b49a0662c29b1d810f4be31ca1f318f0571f533e..9b15b4f0b26f11ac3281ca4206654872984628b6 100644 --- a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv +++ b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv @@ -1,100 +1,100 @@ -0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0. -1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0. -2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0. -3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0. -4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0. -5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0. -6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0. -7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0. -8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0. -9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0. -10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0. -11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0. -12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0. -13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0. -14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0. -15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0. -16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0. -17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0. -18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0. -19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0. -20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0. -21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0. -22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0. -23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0. -24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0. -25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0. -26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0. -27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0. -28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0. -29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0. -30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0. -31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0. -32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0. -33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0. -34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0. -35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0. -36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0. -37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0. -38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0. -39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0. -40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0. -41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0. -42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0. -43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0. -44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0. -45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0. -46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0. -47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0. -48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0. -49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0. -50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0. -51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0. -52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0. -53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0. -54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0. -55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0. -56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0. -57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0. -58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0. -59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0. -60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0. -61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0. -62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0. -63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0. -64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0. -65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0. -66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0. -67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0. -68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0. -69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0. -70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0. -71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0. -72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0. -73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0. -74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0. -75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0. -76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0. -77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0. -78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0. -79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0. -80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0. -81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0. -82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0. -83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0. -84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0. -85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0. -86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0. -87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0. -88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0. -89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0. -90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0. -91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0. -92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0. -93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0. -94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0. -95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0. -96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0. -97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0. -98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0. -99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0. +0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.,strkeya +1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.,strkeyb +2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.,strkey +3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.,strkey +4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.,strkey +5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.,strkey +6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.,strkey +7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.,strkey +8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.,strkey +9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.,strkey +10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.,strkeyc +11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.,strkey +12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.,strkey +13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.,strkey +14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.,strkey +15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.,strkey +16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.,strkey +17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.,strkey +18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.,strkey +19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.,strkey +20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.,strkey +21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.,strkey +22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.,strkey +23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.,strkey +24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.,strkey +25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.,strkey +26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.,strkey +27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.,strkey +28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.,strkey +29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.,strkey +30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.,strkey +31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.,strkey +32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.,strkey +33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.,strkey +34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.,strkey +35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.,strkey +36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.,strkey +37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.,strkeyd +38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.,strkey +39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.,strkey +40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.,strkey +41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.,strkey +42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.,strkey +43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.,strkey +44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.,strkey +45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.,strkey +46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.,strkey +47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.,strkey +48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.,strkey +49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.,strkey +50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.,strkey +51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.,strkey +52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.,strkey +53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.,strkey +54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.,strkey +55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.,strkey +56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.,strkey +57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.,strkey +58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.,strkey +59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.,strkey +60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.,strkey +61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.,strkey +62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.,strkey +63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.,strkey +64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.,strkey +65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.,strkey +66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.,strkey +67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.,strkey +68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.,strkey +69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.,strkey +70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.,strkey +71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.,strkey +72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.,strkey +73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.,strkey +74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.,strkey +75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.,strkey +76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.,strkey +77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.,strkey +78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.,strkey +79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.,strkey +80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.,strkey +81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.,strkey +82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.,strkey +83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.,strkey +84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.,strkey +85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.,strkey +86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.,strkey +87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.,strkey +88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.,strkey +89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.,strkey +90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.,strkey +91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.,strkey +92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.,strkey +93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.,strkey +94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.,strkey +95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.,strkey +96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.,strkey +97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.,strkey +98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.,strkey +99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.,strkey diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py index 7659dd308a7ee1b70d6688b85e4f6157ddee0540..e77628ddd390374d6336e3583e07ce03cdec7aea 100644 --- a/tensorflow/contrib/timeseries/examples/known_anomaly.py +++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py @@ -46,12 +46,21 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300): # Indicate the format of our exogenous feature, in this case a string # representing a boolean value. - string_feature = tf.contrib.layers.sparse_column_with_keys( - column_name="is_changepoint", keys=["no", "yes"]) + string_feature = tf.feature_column.categorical_column_with_vocabulary_list( + key="is_changepoint", vocabulary_list=["no", "yes"]) # Specify the way this feature is presented to the model, here using a one-hot # encoding. - one_hot_feature = tf.contrib.layers.one_hot_column( - sparse_id_column=string_feature) + one_hot_feature = tf.feature_column.indicator_column( + categorical_column=string_feature) + + def _exogenous_update_condition(times, features): + del times # unused + # Make exogenous updates sparse by setting an update condition. This in + # effect allows missing exogenous features: if the condition evaluates to + # False, no update is performed. Otherwise we sometimes end up with "leaky" + # updates which add unnecessary uncertainty to the model even when there is + # no changepoint. + return tf.equal(tf.squeeze(features["is_changepoint"], axis=-1), "yes") estimator = tf.contrib.timeseries.StructuralEnsembleRegressor( periodicities=12, @@ -60,13 +69,7 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300): cycle_num_latent_values=3, num_features=1, exogenous_feature_columns=[one_hot_feature], - # Make exogenous updates sparse by setting an update condition. This in - # effect allows missing exogenous features: if the condition evaluates to - # False, no update is performed. Otherwise we sometimes end up with - # "leaky" updates which add unnecessary uncertainty to the model even when - # there is no changepoint. - exogenous_update_condition= - lambda times, features: tf.equal(features["is_changepoint"], "yes")) + exogenous_update_condition=_exogenous_update_condition) reader = tf.contrib.timeseries.CSVReader( csv_file_name, # Indicate the format of our CSV file. First we have two standard columns, diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index f37cafcc502dc9415db0829b9b067b862f87dca7..b1c7475442c58b9a190c818b752760a4fb4fe6f0 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -59,10 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): num_units: The number of units in the model's LSTMCell. num_features: The dimensionality of the time series (features per timestep). - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects representing features which are inputs to the model but are - not predicted by it. These must then be present for training, - evaluation, and prediction. + exogenous_feature_columns: A list of `tf.feature_column`s representing + features which are inputs to the model but are not predicted by + it. These must then be present for training, evaluation, and + prediction. dtype: The floating point data type to use. """ super(_LSTMModel, self).__init__( @@ -189,12 +189,16 @@ def train_and_predict( export_directory=None): """Train and predict using a custom time series model.""" # Construct an Estimator from our LSTM model. + categorical_column = tf.feature_column.categorical_column_with_hash_bucket( + key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ # Exogenous features are not part of the loss, but can inform # predictions. In this example the features have no extra information, but # are included as an API example. - tf.contrib.layers.real_valued_column( - "2d_exogenous_feature", dimension=2)] + tf.feature_column.numeric_column( + "2d_exogenous_feature", shape=(2,)), + tf.feature_column.embedding_column( + categorical_column=categorical_column, dimension=10)] estimator = ts_estimators.TimeSeriesRegressor( model=_LSTMModel(num_features=5, num_units=128, exogenous_feature_columns=exogenous_feature_columns), @@ -205,7 +209,11 @@ def train_and_predict( csv_file_name, column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,) + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5 - + ("2d_exogenous_feature",) * 2)) + + ("2d_exogenous_feature",) * 2 + + ("categorical_exogenous_feature",)), + # Data types other than for `times` need to be specified if they aren't + # float32. In this case one of our exogenous features has string dtype. + column_dtypes=((tf.int64,) + (tf.float32,) * 7 + (tf.string,))) train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( reader, batch_size=4, window_size=32) estimator.train(input_fn=train_input_fn, steps=training_steps) @@ -215,7 +223,9 @@ def train_and_predict( predict_exogenous_features = { "2d_exogenous_feature": numpy.concatenate( [numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])], - axis=-1)} + axis=-1), + "categorical_exogenous_feature": numpy.array( + ["strkey"] * 100)[None, :, None]} (predictions,) = tuple(estimator.predict( input_fn=tf.contrib.timeseries.predict_continuation_input_fn( evaluation, steps=100, @@ -226,20 +236,36 @@ def train_and_predict( [evaluation["mean"][0], predictions["mean"]], axis=0)) all_times = numpy.concatenate([times, predictions["times"]], axis=0) - # Export the model in SavedModel format. + # Export the model in SavedModel format. We include a bit of extra boilerplate + # for "cold starting" as if we didn't have any state from the Estimator, which + # is the case when serving from a SavedModel. If Estimator output is + # available, the result of "Estimator.evaluate" can be passed directly to + # `tf.contrib.timeseries.saved_model_utils.predict_continuation` as the + # `continue_from` argument. + with tf.Graph().as_default(): + filter_feature_tensors, _ = evaluation_input_fn() + with tf.train.MonitoredSession() as session: + # Fetch the series to "warm up" our state, which will allow us to make + # predictions for its future values. This is just a dictionary of times, + # values, and exogenous features mapping to numpy arrays. The use of an + # input_fn is just a convenience for the example; they can also be + # specified manually. + filter_features = session.run(filter_feature_tensors) if export_directory is None: export_directory = tempfile.mkdtemp() input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() export_location = estimator.export_savedmodel( export_directory, input_receiver_fn) - # Predict using the SavedModel + # Warm up and predict using the SavedModel with tf.Graph().as_default(): with tf.Session() as session: signatures = tf.saved_model.loader.load( session, [tf.saved_model.tag_constants.SERVING], export_location) + state = tf.contrib.timeseries.saved_model_utils.cold_start_filter( + signatures=signatures, session=session, features=filter_features) saved_model_output = ( tf.contrib.timeseries.saved_model_utils.predict_continuation( - continue_from=evaluation, signatures=signatures, + continue_from=state, signatures=signatures, session=session, steps=100, exogenous_features=predict_exogenous_features)) # The exported model gives the same results as the Estimator.predict() diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index ff140efd48104e386826eab7abbc94bec220f9df..4f6527a5465ca01ed34150a26ba26d73a858cd74 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -70,7 +70,7 @@ class ARModel(model.TimeSeriesModel): input_window_size: Number of past time steps of data to look at when doing the regression. output_window_size: Number of future time steps to predict. Note that - setting it to > 1 empiricaly seems to give a better fit. + setting it to > 1 empirically seems to give a better fit. num_features: number of input features per time step. num_time_buckets: Number of buckets into which to divide (time % periodicity) for generating time based features. diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index f8355f366fe8e191ab570fd271bbe4a8bf71c73d..469cea4fd2fca65373eef85b1931a267e6e60238 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.layers.python.layers import feature_column - from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib @@ -31,11 +29,15 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.export import export_lib +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train +from tensorflow.python.util import nest class TimeSeriesRegressor(estimator_lib.Estimator): @@ -98,11 +100,11 @@ class TimeSeriesRegressor(estimator_lib.Estimator): def _serving_input_receiver_fn(): """A receiver function to be passed to export_savedmodel.""" placeholders = {} - placeholders[feature_keys.TrainEvalFeatures.TIMES] = ( - array_ops.placeholder( - name=feature_keys.TrainEvalFeatures.TIMES, - dtype=dtypes.int64, - shape=[default_batch_size, default_series_length])) + time_placeholder = array_ops.placeholder( + name=feature_keys.TrainEvalFeatures.TIMES, + dtype=dtypes.int64, + shape=[default_batch_size, default_series_length]) + placeholders[feature_keys.TrainEvalFeatures.TIMES] = time_placeholder # Values are only necessary when filtering. For prediction the default # value will be ignored. placeholders[feature_keys.TrainEvalFeatures.VALUES] = ( @@ -117,36 +119,57 @@ class TimeSeriesRegressor(estimator_lib.Estimator): dtype=self._model.dtype), shape=(default_batch_size, default_series_length, self._model.num_features))) - with ops.Graph().as_default(): - # Default placeholders have only an unknown batch dimension. Make them - # in a separate graph, then splice in the series length to the shapes - # and re-create them in the outer graph. - exogenous_feature_shapes = { - key: (value.get_shape(), value.dtype) for key, value - in feature_column.make_place_holder_tensors_for_base_features( - self._model.exogenous_feature_columns).items()} - for feature_key, (batch_only_feature_shape, value_dtype) in ( - exogenous_feature_shapes.items()): - batch_only_feature_shape = batch_only_feature_shape.with_rank_at_least( - 1).as_list() - feature_shape = ([default_batch_size, default_series_length] - + batch_only_feature_shape[1:]) - placeholders[feature_key] = array_ops.placeholder( - dtype=value_dtype, name=feature_key, shape=feature_shape) + if self._model.exogenous_feature_columns: + with ops.Graph().as_default(): + # Default placeholders have only an unknown batch dimension. Make them + # in a separate graph, then splice in the series length to the shapes + # and re-create them in the outer graph. + parsed_features = ( + feature_column.make_parse_example_spec( + self._model.exogenous_feature_columns)) + placeholder_features = parsing_ops.parse_example( + serialized=array_ops.placeholder( + shape=[None], dtype=dtypes.string), + features=parsed_features) + exogenous_feature_shapes = { + key: (value.get_shape(), value.dtype) for key, value + in placeholder_features.items()} + for feature_key, (batch_only_feature_shape, value_dtype) in ( + exogenous_feature_shapes.items()): + batch_only_feature_shape = ( + batch_only_feature_shape.with_rank_at_least(1).as_list()) + feature_shape = ([default_batch_size, default_series_length] + + batch_only_feature_shape[1:]) + placeholders[feature_key] = array_ops.placeholder( + dtype=value_dtype, name=feature_key, shape=feature_shape) # Models may not know the shape of their state without creating some # variables/ops. Avoid polluting the default graph by making a new one. We # use only static metadata from the returned Tensors. with ops.Graph().as_default(): self._model.initialize_graph() - model_start_state = self._model.get_start_state() - for prefixed_state_name, state_tensor in ts_head_lib.state_to_dictionary( - model_start_state).items(): + # Evaluate the initial state as same-dtype "zero" values. These zero + # constants aren't used, but are necessary for feeding to + # placeholder_with_default for the "cold start" case where state is not + # fed to the model. + def _zeros_like_constant(tensor): + return tensor_util.constant_value(array_ops.zeros_like(tensor)) + start_state = nest.map_structure( + _zeros_like_constant, self._model.get_start_state()) + batch_size_tensor = array_ops.shape(time_placeholder)[0] + for prefixed_state_name, state in ts_head_lib.state_to_dictionary( + start_state).items(): state_shape_with_batch = tensor_shape.TensorShape( - (default_batch_size,)).concatenate(state_tensor.get_shape()) - placeholders[prefixed_state_name] = array_ops.placeholder( + (default_batch_size,)).concatenate(state.shape) + default_state_broadcast = array_ops.tile( + state[None, ...], + multiples=array_ops.concat( + [batch_size_tensor[None], + array_ops.ones(len(state.shape), dtype=dtypes.int32)], + axis=0)) + placeholders[prefixed_state_name] = array_ops.placeholder_with_default( + input=default_state_broadcast, name=prefixed_state_name, - shape=state_shape_with_batch, - dtype=state_tensor.dtype) + shape=state_shape_with_batch) return export_lib.ServingInputReceiver(placeholders, placeholders) return _serving_input_receiver_fn @@ -333,11 +356,11 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): determine the model size. Learning autoregressive coefficients typically requires more steps and a smaller step size than other components. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments, `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]), and diff --git a/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py b/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py index 970b9aa8acd6f55db843a4e023052b122992baf4..56566ee2e3207abd81ef665da10f851c9dc98ccb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py +++ b/tensorflow/contrib/timeseries/python/timeseries/feature_keys.py @@ -72,3 +72,4 @@ class SavedModelLabels(object): """Names of signatures exported with export_savedmodel.""" PREDICT = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY FILTER = "filter" + COLD_START_FILTER = "cold_start_filter" diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 5c49e903abde6d7487d1ffdb83ff902ff6b63585..3d7e61529014ff5045c3b64fb945ceb9c902dd0d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -96,8 +96,12 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _train_ops(self, features): """Add training ops to the graph.""" mode = estimator_lib.ModeKeys.TRAIN - with variable_scope.variable_scope("model"): + with variable_scope.variable_scope( + "model", + # Use ResourceVariables to avoid race conditions. + use_resource=True): model_outputs = self.create_loss(features, mode) + train_op = optimizers.optimize_loss( model_outputs.loss, global_step=training_util.get_global_step(), @@ -112,7 +116,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _evaluate_ops(self, features): """Add ops for evaluation (aka filtering) to the graph.""" mode = estimator_lib.ModeKeys.EVAL - with variable_scope.variable_scope("model"): + with variable_scope.variable_scope("model", use_resource=True): model_outputs = self.create_loss(features, mode) metrics = {} # Just output in-sample predictions for the last chunk seen @@ -132,7 +136,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _predict_ops(self, features): """Add ops for prediction to the graph.""" - with variable_scope.variable_scope("model"): + with variable_scope.variable_scope("model", use_resource=True): prediction = self.model.predict(features=features) prediction[feature_keys.PredictionResults.TIMES] = features[ feature_keys.PredictionFeatures.TIMES] @@ -141,11 +145,17 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def _serving_ops(self, features): """Add ops for serving to the graph.""" - with variable_scope.variable_scope("model"): + with variable_scope.variable_scope("model", use_resource=True): prediction_outputs = self.model.predict(features=features) with variable_scope.variable_scope("model", reuse=True): filtering_outputs = self.create_loss( features, estimator_lib.ModeKeys.EVAL) + with variable_scope.variable_scope("model", reuse=True): + no_state_features = { + k: v for k, v in features.items() + if not k.startswith(feature_keys.State.STATE_PREFIX)} + cold_filtering_outputs = self.create_loss( + no_state_features, estimator_lib.ModeKeys.EVAL) return estimator_lib.EstimatorSpec( mode=estimator_lib.ModeKeys.PREDICT, export_outputs={ @@ -153,7 +163,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc export_lib.PredictOutput(prediction_outputs), feature_keys.SavedModelLabels.FILTER: export_lib.PredictOutput( - state_to_dictionary(filtering_outputs.end_state)) + state_to_dictionary(filtering_outputs.end_state)), + feature_keys.SavedModelLabels.COLD_START_FILTER: + export_lib.PredictOutput( + state_to_dictionary(cold_filtering_outputs.end_state)) }, # Likely unused, but it is necessary to return `predictions` to satisfy # the Estimator's error checking. diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py index 04225333b9377447f46d32663df76aece97a51e7..403c6e2cb4aeb665fb112b6322109a6a90f7a261 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py @@ -492,8 +492,7 @@ class CSVReader(ReaderBaseTimeSeriesParser): features_lists.setdefault(column_name, []).append(value) features = {} for column_name, values in features_lists.items(): - if (len(values) == 1 and - column_name != feature_keys.TrainEvalFeatures.VALUES): + if column_name == feature_keys.TrainEvalFeatures.TIMES: features[column_name] = values[0] else: features[column_name] = array_ops.stack(values, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 23452a81c397da3516016d72b7bc9b80f7d6447f..26793c80bfbb3c9394e81a5bbfae360deb95ca58 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -185,7 +185,7 @@ def batch_matrix_pow(matrices, powers): { matmul(A, power(matmul(A, A), (p - 1) / 2)) for odd p power(A, 0) = I - The power(A, 0) = I case is handeled by starting with accumulator set to the + The power(A, 0) = I case is handled by starting with accumulator set to the identity matrix; matrices with zero residual powers are passed through unchanged. diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index bac7d1ebf59b28d4688a3d1a69ecdc1fc12248e0..7644764a7459db3951fe9a2790389713dd412a8f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -21,18 +21,17 @@ from __future__ import print_function import abc import collections -from tensorflow.contrib import layers -from tensorflow.contrib.layers import feature_column - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope @@ -66,11 +65,11 @@ class TimeSeriesModel(object): Args: num_features: Number of features for the time series - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not + part of the series to be predicted. Passed to + `tf.feature_column.input_layer`. dtype: The floating point datatype to use. """ if exogenous_feature_columns: @@ -86,7 +85,7 @@ class TimeSeriesModel(object): @property def exogenous_feature_columns(self): - """`FeatureColumn` objects for features which are not predicted.""" + """`tf.feature_colum`s for features which are not predicted.""" return self._exogenous_feature_columns # TODO(allenl): Move more of the generic machinery for generating and @@ -265,11 +264,14 @@ class TimeSeriesModel(object): if not self._exogenous_feature_columns: return (0,) with ops.Graph().as_default(): - placeholder_features = ( - feature_column.make_place_holder_tensors_for_base_features( + parsed_features = ( + feature_column.make_parse_example_spec( self._exogenous_feature_columns)) - embedded = layers.input_from_feature_columns( - columns_to_tensors=placeholder_features, + placeholder_features = parsing_ops.parse_example( + serialized=array_ops.placeholder(shape=[None], dtype=dtypes.string), + features=parsed_features) + embedded = feature_column.input_layer( + features=placeholder_features, feature_columns=self._exogenous_feature_columns) return embedded.get_shape().as_list()[1:] @@ -308,13 +310,13 @@ class TimeSeriesModel(object): # Avoid shape warnings when embedding "scalar" exogenous features (those # with only batch and window dimensions); input_from_feature_columns # expects input ranks to match the embedded rank. - if tensor.get_shape().ndims == 1: + if tensor.get_shape().ndims == 1 and tensor.dtype != dtypes.string: exogenous_features_single_batch_dimension[name] = tensor[:, None] else: exogenous_features_single_batch_dimension[name] = tensor embedded_exogenous_features_single_batch_dimension = ( - layers.input_from_feature_columns( - columns_to_tensors=exogenous_features_single_batch_dimension, + feature_column.input_layer( + features=exogenous_features_single_batch_dimension, feature_columns=self._exogenous_feature_columns, trainable=True)) exogenous_regressors = array_ops.reshape( @@ -381,8 +383,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): may use _scale_back_data or _scale_back_variance to return predictions to the input scale. dtype: The floating point datatype to use. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects. See `TimeSeriesModel`. + exogenous_feature_columns: A list of `tf.feature_column`s objects. See + `TimeSeriesModel`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a diff --git a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py index 97f6d36a879532c12684ffdd700ef40b72750567..0461abdc19c08767114e3d26d1134ea4bc5481f8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py @@ -15,6 +15,7 @@ """Convenience functions for working with time series saved_models. @@predict_continuation +@@cold_start_filter @@filter_continuation """ @@ -30,10 +31,12 @@ from tensorflow.contrib.timeseries.python.timeseries import model_utils as _mode from tensorflow.python.util.all_util import remove_undocumented -def _colate_features_to_feeds_and_fetches(continue_from, signature, features, - graph): +def _colate_features_to_feeds_and_fetches(signature, features, graph, + continue_from=None): """Uses a saved model signature to construct feed and fetch dictionaries.""" - if _feature_keys.FilteringResults.STATE_TUPLE in continue_from: + if continue_from is None: + state_values = {} + elif _feature_keys.FilteringResults.STATE_TUPLE in continue_from: # We're continuing from an evaluation, so we need to unpack/flatten state. state_values = _head.state_to_dictionary( continue_from[_feature_keys.FilteringResults.STATE_TUPLE]) @@ -115,6 +118,55 @@ def predict_continuation(continue_from, return output +def cold_start_filter(signatures, session, features): + """Perform filtering using an exported saved model. + + Filtering refers to updating model state based on new observations. + Predictions based on the returned model state will be conditioned on these + observations. + + Starts from the model's default/uninformed state. + + Args: + signatures: The `MetaGraphDef` protocol buffer returned from + `tf.saved_model.loader.load`. Used to determine the names of Tensors to + feed and fetch. Must be from the same model as `continue_from`. + session: The session to use. The session's graph must be the one into which + `tf.saved_model.loader.load` loaded the model. + features: A dictionary mapping keys to Numpy arrays, with several possible + shapes (requires keys `FilteringFeatures.TIMES` and + `FilteringFeatures.VALUES`): + Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a + vector of length [number of features]. + Sequence; `TIMES` is a vector of shape [series length], `VALUES` either + has shape [series length] (univariate) or [series length x number of + features] (multivariate). + Batch of sequences; `TIMES` is a vector of shape [batch size x series + length], `VALUES` has shape [batch size x series length] or [batch + size x series length x number of features]. + In any case, `VALUES` and any exogenous features must have their shapes + prefixed by the shape of the value corresponding to the `TIMES` key. + Returns: + A dictionary containing model state updated to account for the observations + in `features`. + """ + filter_signature = signatures.signature_def[ + _feature_keys.SavedModelLabels.COLD_START_FILTER] + features = _input_pipeline._canonicalize_numpy_data( # pylint: disable=protected-access + data=features, + require_single_batch=False) + output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches( + signature=filter_signature, + features=features, + graph=session.graph) + output = session.run(output_tensors_by_name, feed_dict=feed_dict) + # Make it easier to chain filter -> predict by keeping track of the current + # time. + output[_feature_keys.FilteringResults.TIMES] = features[ + _feature_keys.FilteringFeatures.TIMES] + return output + + def filter_continuation(continue_from, signatures, session, features): """Perform filtering using an exported saved model. @@ -124,8 +176,8 @@ def filter_continuation(continue_from, signatures, session, features): Args: continue_from: A dictionary containing the results of either an Estimator's - evaluate method or a previous filter_continuation. Used to determine the - model state to start filtering from. + evaluate method or a previous filter step (cold start or + continuation). Used to determine the model state to start filtering from. signatures: The `MetaGraphDef` protocol buffer returned from `tf.saved_model.loader.load`. Used to determine the names of Tensors to feed and fetch. Must be from the same model as `continue_from`. diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py index 6257002647ed53bbde3ace11a6b45e4e2cdeb57d..951c6546d5fed77e0cfa98a4e774b804639d7dad 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py @@ -112,11 +112,11 @@ class StateSpaceModelConfiguration( exogenous_noise_decreases: If True, exogenous regressors can "set" model state, decreasing uncertainty. If both this parameter and exogenous_noise_increases are False, exogenous regressors are ignored. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py index 1afc58cfb240c52a9f001da787addfb7fbb46789..6746dd7b433466c473402e0e8374377093a73492 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py @@ -107,7 +107,7 @@ class VARMA(state_space_model.StateSpaceModel): Returns: the state transition matrix. It has shape - [self.state_dimendion, self.state_dimension]. + [self.state_dimension, self.state_dimension]. """ # Pad any unused AR blocks with zeros. The extra state is necessary if # ma_order >= ar_order. @@ -127,7 +127,7 @@ class VARMA(state_space_model.StateSpaceModel): Returns: the state noise transform matrix. It has shape - [self.state_dimendion, self.num_features]. + [self.state_dimension, self.num_features]. """ # Noise is broadcast, through the moving average coefficients, to # un-observed parts of the latent state. diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index c48e84ddfaac8ac9c07e061847315eab3fd72152..eea19e9465e482dfd1ea9a144435c23a2ecf1467 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -24,6 +24,7 @@ cc_library( name = "all_ops", deps = [ ":cross_replica_ops_op_lib", + ":host_compute_ops_op_lib", ":infeed_ops_op_lib", ":outfeed_ops_op_lib", ":replication_ops_op_lib", @@ -69,6 +70,7 @@ py_library( tf_gen_op_libs( op_lib_names = [ "cross_replica_ops", + "host_compute_ops", "infeed_ops", "outfeed_ops", "replication_ops", @@ -78,6 +80,7 @@ tf_gen_op_libs( deps = [ "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", ], ) @@ -85,6 +88,7 @@ tf_custom_op_library( name = "python/ops/_tpu_ops.so", srcs = [ "ops/cross_replica_ops.cc", + "ops/host_compute_ops.cc", "ops/infeed_ops.cc", "ops/outfeed_ops.cc", "ops/replication_ops.cc", @@ -101,6 +105,7 @@ tf_gen_op_wrapper_py( name = "tpu_ops", deps = [ ":cross_replica_ops_op_lib", + ":host_compute_ops_op_lib", ":infeed_ops_op_lib", ":outfeed_ops_op_lib", ":replication_ops_op_lib", @@ -163,6 +168,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":datasets", ":profiler", ":tpu_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", @@ -181,6 +187,33 @@ py_library( ], ) +py_library( + name = "datasets", + srcs = [ + "python/tpu/datasets.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + +tf_py_test( + name = "datasets_test", + srcs = ["python/tpu/datasets_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + ":datasets", + ], + grpc_enabled = True, +) + tf_py_test( name = "tpu_test", size = "small", @@ -238,6 +271,17 @@ tf_py_test( ], ) +tf_py_test( + name = "tpu_estimator_signals_test", + size = "small", + srcs = ["python/tpu/tpu_estimator_signals_test.py"], + additional_deps = [ + ":tpu_estimator", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/tpu/ops/host_compute_ops.cc b/tensorflow/contrib/tpu/ops/host_compute_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..48aeb81ac1311d3acd4972810f0a27a382f8b136 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/host_compute_ops.cc @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("_XlaSendFromHost") + .Input("inputs: Tinputs") + .Input("dynamic_key: string") + .Attr("Tinputs: list(type) >= 0") + .Attr("key: string") + .Attr("device_ordinal: int") + .SetIsStateful() + .SetShapeFn(::tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +A placeholder op for multiple values that will be sent from TensorFlow to a +running XLA computation. + +inputs: A list of tensors that will be sent to the XLA computation. +dynamic_key: The key sent at runtime by the compile node to identify which +execution the transfer corresponds to. +Tinputs: The element types of each element in `inputs`. +key: A key that is unique in the computation and associates the send with the consumer in +the XLA computation. +device_ordinal: The device to use. +)doc"); + +REGISTER_OP("_XlaRecvAtHost") + .Input("dynamic_key: string") + .Output("outputs: Toutputs") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .Attr("device_ordinal: int") + .SetIsStateful() + .SetShapeFn(::tensorflow::shape_inference::UnknownShape) + .Doc(R"doc( +A placeholder op for multiple values that will be sent to TensorFlow from a +running XLA computation. + +dynamic_key: The key sent at runtime by the compile node to identify which +execution the transfer corresponds to. +outputs: A list of tensors that will be received from the XLA computation. +Toutputs: The element types of each element in `outputs`. +key: A key that is unique in the computation and associates the send with the consumer in +the XLA computation. +device_ordinal: The device to use. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc index f8de8baa65339383c7f92284ee274a434f12f8c2..7bf5c21d0b526ee5e32448f75d39eca8add6d877 100644 --- a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc @@ -191,6 +191,7 @@ REGISTER_OP("ConfigureDistributedTPU") .Output("topology: string") .Attr("embedding_config: string = ''") .Attr("tpu_embedding_config: string = ''") + .Attr("is_global_init: bool = false") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( @@ -202,6 +203,7 @@ topology. tpu_embedding_config: Serialized tensorflow.tpu.TPUEmbeddingConfiguration that describes the embedding lookups of the program. embedding_config: Reserved. Do not use. +is_global_init: Reserved. Do not use. )doc"); REGISTER_OP("ShutdownDistributedTPU") diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc index cc32a265286951a1e4d59228da6b3ac83a75c5e9..72d37f774cc518c559b5953561957a799a7da568 100644 --- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc @@ -50,7 +50,7 @@ namespace tensorflow { // TPU Embeddings use dedicated ops to enforce Host/TPU consistency in the // state of embedding table variables. Before beginning training or inference, // the model must Load the optimizer parameters into the TPU memories. Before -// saving a checkpoint, the model must Retreieve the parameters back into the +// saving a checkpoint, the model must Retrieve the parameters back into the // host CPU memory. REGISTER_OP("TPUEmbeddingLoadGradientDescentParameters") @@ -263,7 +263,7 @@ REGISTER_OP("TPUEmbeddingReceiveActivations") .SetIsStateful() .SetShapeFn(tpu_embedding_config_util::ActivationShapes) .Doc(R"doc( -An op that receives embeddng activations on the TPU. +An op that receives embedding activations on the TPU. The TPU system performs the embedding lookups and aggregations specified by the arguments to TPUEmbeddingEnqueueSparseBatch. The results of these @@ -293,7 +293,7 @@ REGISTER_OP("TPUEmbeddingActivations") An op enabling differentiation of TPU Embeddings. This op simply returns its first input, which is assumed to have been sliced -from the Tensors returnd by TPUEmbeddingDequeueActivations. The presence of this +from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of this op, and its first argument being a trainable Variable, enables automatic differentiation of graphs containing embeddings via the TPU Embedding Python libraries. diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 198da0203a7d17249c4f50110713121b74d5ca4f..0a52d0b13b7c8749ad44377659714d297ffec3ee 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -18,7 +18,7 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) -tf_proto_library_cc( +tf_proto_library( name = "tpu_profiler_proto", srcs = ["tpu_profiler.proto"], has_services = 1, @@ -98,16 +98,36 @@ tf_cc_test( ], ) -tf_proto_library_cc( +tf_proto_library( name = "op_profile_proto", srcs = ["op_profile.proto"], cc_api_version = 2, visibility = ["//visibility:public"], ) -tf_proto_library_cc( +tf_proto_library( name = "tf_op_stats_proto", srcs = ["tf_op_stats.proto"], cc_api_version = 2, visibility = ["//visibility:public"], ) + +tf_proto_library( + name = "tpu_profiler_analysis_proto", + srcs = ["tpu_profiler_analysis.proto"], + has_services = 1, + cc_api_version = 2, + cc_grpc_version = 1, + protodeps = [":tpu_profiler_proto"] + tf_additional_all_protos(), + visibility = ["//visibility:public"], +) + +py_library( + name = "tpu_profiler_analysis_pb2_grpc", + srcs = ["tpu_profiler_analysis_pb2_grpc.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":tpu_profiler_analysis_proto_py", + ], +) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index b1ef9fde37fe0647965f0818895be37d2d56d207..e6811d4ad204edb318638c698090479436f38ecd 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/contrib/tpu/profiler/version.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -62,10 +65,13 @@ Status ValidateHostPortPair(const string& host_port) { } ProfileResponse Profile(const string& service_addr, int duration_ms, + const string& repository_root, const string& session_id, const ProfileOptions& opts) { ProfileRequest request; request.set_duration_ms(duration_ms); request.set_max_events(kMaxEvents); + request.set_repository_root(repository_root); + request.set_session_id(session_id); request.add_tools("input_pipeline"); request.add_tools("overview_page"); *request.mutable_opts() = opts; @@ -137,10 +143,17 @@ int main(int argc, char** argv) { opts.set_include_dataset_ops(FLAGS_include_dataset_ops); tensorflow::ProfileResponse response; + // Use the current timestamp as the run name. + tensorflow::string session_id = + tensorflow::tpu::GetCurrentTimeStampAsString(); + constexpr char kProfilePluginDirectory[] = "plugins/profile/"; + tensorflow::string repository_root = + ::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory); while (true) { std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, opts); + response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, + repository_root, session_id, opts); if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break; std::cout << "No trace event is collected. Automatically retrying." << std::endl @@ -158,10 +171,8 @@ int main(int argc, char** argv) { return 0; } - // Use the current timestamp as the run name. - tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString(); TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( - FLAGS_logdir, run, response, &std::cout)); + FLAGS_logdir, session_id, response, &std::cout)); // Print this at the end so that it's not buried in irrelevant LOG messages. std::cout << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index 2094294baad63ae73712c8648b588accd4551ef8..20ed7419fde36a0d112900093ed2f44c3af63d75 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -77,6 +77,8 @@ message StepInfoResult { // The infeed duration in picoseconds. // Can turn into a map if we want a variable number of ops. optional uint64 infeed_duration_ps = 3; + // The start time of this step in picoseconds. + optional uint64 begin_ps = 4; } // Result proto for a sequence of steps. @@ -155,6 +157,54 @@ message RunEnvironmentResult { repeated HostDependentJobInfoResult host_dependent_job_info = 6; } +// The types of host operations that are tracked. +enum HostOp { + // Invalid host op. + kINVALIDHostOp = 0; + // Each of host op type has two parts: + // (1) the stage where the op happens and (2) the op name. + // stage = Input Data Producer, op = Get Next Batch. + kInputDataProducerGetNextBatch = 1; + // stage = Input Data Producer, op = Session Run. + kInputDataProducerSessionRun = 2; + // stage = Input Data Producer, op = Forward Batch. + kInputDataProducerForwardBatch = 3; + // stage = Infeed Thread, op = Get Next Batch. + kInfeedThreadGetNextBatch = 4; + // stage = Infeed Thread, op = Session Run. + kInfeedThreadSessionRun = 5; + // stage = Infeed Thread, op = Forward Batch. + kInfeedThreadForwardBatch = 6; + // stage = Outfeed Thread, op = Get Next Batch. + kOutfeedThreadGetNextBatch = 7; + // stage = Outfeed Thread, op = Session Run. + kOutfeedThreadSessionRun = 8; + // stage = Outfeed Thread, op = Forward Batch. + kOutfeedThreadForwardBatch = 9; +} + +// Result proto for the host ops per TPU step. +message HostOpsPerTpuStep { + // Whether the data in this message is valid. + optional bool valid = 1 [default = false]; + // The current TPU step number. + optional uint32 tpu_step_num = 2; + // The beginning time of the current TPU step on the device in picoseconds. + optional uint64 tpu_step_begin_ps = 3; + // The ending time of the current TPU step on the device in picoseconds. + optional uint64 tpu_step_end_ps = 4; + // For each possible host operation, maps to the difference between the TPU + // step number that the host op targets and the current TPU step number. + // The key is HostOp, value is the step difference. + map step_diffs = 5; +} + +// Result proto for the host ops for all TPU steps. +message HostOpsResult { + // A sequence of HostOpsPerTpuStep (one for each TPU step) + repeated HostOpsPerTpuStep host_op_sequence = 1; +} + // Result proto for TfStatsHelper. message TfOpStats { // The result for the TF-metric database. @@ -171,4 +221,8 @@ message TfOpStats { optional double matrix_unit_utilization_percent = 6; // The run environment of this profiling session. optional RunEnvironmentResult run_environment = 7; + // The result for the host operations. + optional HostOpsResult host_ops = 8; + // A map from core ID to name. + map core_id_to_name_map = 9; } diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index f3f3302ceb3d27dbb21bdce753aeb2d7fcd77448..cddc3cd1b41d6e00409222170e69c429fe6f91f8 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -36,10 +36,17 @@ message ProfileRequest { // Optional profiling options that control how a TF session will be profiled. ProfileOptions opts = 4; + // The place where we will dump profile data. We will normally use + // MODEL_DIR/plugin/profile/ as our repository root. + string repository_root = 5; + + // The user provided profile session identifier. + string session_id = 6; + // In future, the caller will indicate which TF session is being profiled, and // only data relating to that program will be returned. For now, we assume // all activity during the profiling period is relevant. - // next-field: 5 + // next-field: 7 } message ProfileToolData { diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto new file mode 100644 index 0000000000000000000000000000000000000000..a4fc8d4e879eb85522f35663c9c628ecd5ef562c --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; +package tensorflow; + +import "tensorflow/contrib/tpu/profiler/tpu_profiler.proto"; + +message NewProfileSessionRequest { + ProfileRequest request = 1; + string repository_root = 2; + repeated string hosts = 3; +} + +message NewProfileSessionResponse { + // Auxiliary error_message. + string error_message = 1; + // If success, return session identifier for future reference. + string session_id = 2; +} + +message EnumProfileSessionsAndToolsRequest { + string repository_root = 1; +} + +message ProfileSessionInfo { + string session_id = 1; + // Which tool data is available for consumption. + repeated string available_tools = 2; +} + +message EnumProfileSessionsAndToolsResponse { + // Auxiliary error_message. + string error_message = 1; + // If success, the returned sessions information are stored here. + repeated ProfileSessionInfo sessions = 2; +} + +message ProfileSessionDataRequest { + string repository_root = 1; + string session_id = 2; + // Which tool + string tool_name = 3; + // Tool's specific parameters. e.g. TraceViewer's viewport etc + map parameters = 4; +} + +message ProfileSessionDataResponse { + // Auxiliary error_message. + string error_message = 1; + + // Output format. e.g. "json" or "proto" or "blob" + string output_format = 2; + + // TODO(jiesun): figure out whether to put bytes or oneof tool specific proto. + bytes output = 3; +} +//////////////////////////////////////////////////////////////////////////////// +// TPUProfileAnalysis service provide entry point for profiling TPU and for +// serving profiled data to Tensorboard through GRPC +//////////////////////////////////////////////////////////////////////////////// +service TPUProfileAnalysis { + // Starts a profiling session, blocks until it completes. + // TPUProfileAnalysis service delegate this to TPUProfiler service. + // Populate the profiled data in repository, then return status to caller. + rpc NewSession(NewProfileSessionRequest) returns (NewProfileSessionResponse) { + } + // Enumerate existing sessions and return available profile tools. + rpc EnumSessions(EnumProfileSessionsAndToolsRequest) + returns (EnumProfileSessionsAndToolsResponse) { + } + // Retrieve specific tool's data for specific session. + rpc GetSessionToolData(ProfileSessionDataRequest) + returns (ProfileSessionDataResponse) { + } +} diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..c28fef22a9d3736748b1b56135302d5ec7845720 --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +# +# Do not use pylint on generated code. +# pylint: disable=missing-docstring,g-short-docstring-punctuation,g-no-space-after-docstring-summary,invalid-name,line-too-long,unused-argument,g-doc-args +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import grpc + +from third_party.tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2 as third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2 + + +class TPUProfileAnalysisStub(object): + """////////////////////////////////////////////////////////////////////////////// + + TPUProfileAnalysis service provide entry point for profiling TPU and for + serving profiled data to Tensorboard through GRPC + ////////////////////////////////////////////////////////////////////////////// + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.NewSession = channel.unary_unary( + '/tensorflow.TPUProfileAnalysis/NewSession', + request_serializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + NewProfileSessionRequest.SerializeToString, + response_deserializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + NewProfileSessionResponse.FromString, + ) + self.EnumSessions = channel.unary_unary( + '/tensorflow.TPUProfileAnalysis/EnumSessions', + request_serializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + EnumProfileSessionsAndToolsRequest.SerializeToString, + response_deserializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + EnumProfileSessionsAndToolsResponse.FromString, + ) + self.GetSessionToolData = channel.unary_unary( + '/tensorflow.TPUProfileAnalysis/GetSessionToolData', + request_serializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + ProfileSessionDataRequest.SerializeToString, + response_deserializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + ProfileSessionDataResponse.FromString, + ) + + +class TPUProfileAnalysisServicer(object): + """////////////////////////////////////////////////////////////////////////////// + + TPUProfileAnalysis service provide entry point for profiling TPU and for + serving profiled data to Tensorboard through GRPC + ////////////////////////////////////////////////////////////////////////////// + """ + + def NewSession(self, request, context): + """Starts a profiling session, blocks until it completes. + TPUProfileAnalysis service delegate this to TPUProfiler service. + Populate the profiled data in repository, then return status to caller. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def EnumSessions(self, request, context): + """Enumerate existing sessions and return available profile tools. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetSessionToolData(self, request, context): + """Retrieve specific tool's data for specific session. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_TPUProfileAnalysisServicer_to_server(servicer, server): + rpc_method_handlers = { + 'NewSession': + grpc.unary_unary_rpc_method_handler( + servicer.NewSession, + request_deserializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + NewProfileSessionRequest.FromString, + response_serializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + NewProfileSessionResponse.SerializeToString, + ), + 'EnumSessions': + grpc.unary_unary_rpc_method_handler( + servicer.EnumSessions, + request_deserializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + EnumProfileSessionsAndToolsRequest.FromString, + response_serializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + EnumProfileSessionsAndToolsResponse.SerializeToString, + ), + 'GetSessionToolData': + grpc.unary_unary_rpc_method_handler( + servicer.GetSessionToolData, + request_deserializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + ProfileSessionDataRequest.FromString, + response_serializer= + third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2. + ProfileSessionDataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'tensorflow.TPUProfileAnalysis', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..465c668fd8b42f150892f8e4b52de76c6fe13fa9 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -0,0 +1,184 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Library of Cloud TPU helper functions for data loading.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import functional_ops + + +def _TextLineDataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = readers.TextLineDataset(filename, buffer_size=buffer_size) + return dataset + + +def _TFRecordDataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + +_FILETYPE_MAP = { + 'tfrecord': _TFRecordDataset, + 'textline': _TextLineDataset, + 'text': _TextLineDataset, +} + + +def StreamingFilesDataset(files, + filetype=None, + file_reader_job=None, + worker_job=None, + num_epochs=None, + filename_shuffle_buffer_size=None, + num_parallel_reads=None, + batch_transfer_size=None, + sloppy=None): + """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM). + + Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read + files local to your GCE VM. In order to train using files stored on your local + VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset + helper to generate a dataset to feed your Cloud TPU with files from your GCE + VM. + + The resulting dataset may return an OutOfRangeError if there are no files + found as a result of the fileglob expansion. + + Note: StreamingFilesDataset assumes that the session is using a + TPUClusterResolver and has therefore a worker and a coordinator job. File + loading will be done on the coordinator job. + + Args: + files: A string glob to match files, or a `tf.data.Dataset` generating file + names. + filetype: A string (one of 'tfrecord', or 'textline') or a single-argument + TensorFlow function that when given a filename returns a dataset. + file_reader_job: An optional string that corresponds to the job that should + perform the file reads. + worker_job: An optional string that corresponds to the job that should + process the tensors (i.e. your GPU or TPU worker). + num_epochs: The number of epochs through the training set that should be + generated. By default, it will repeat infinitely. + filename_shuffle_buffer_size: An optional integer whose value controls the + shuffling of the file names. If you would like to read from the files in + the same order, set to 0 or False. + num_parallel_reads: An optional integer controlling the number of files to + read from concurrently. (Set to 1 for no parallelism.) + batch_transfer_size: An optional integer controlling the batching used to + amortize the remote function invocation overhead. Set to a very large + number to increase throughput. Set to a very small number to reduce memory + consumption. Set to False to skip batching. + sloppy: (Optional.) If `False`, read input data while maintaining a + deterministic order. (This may have significant performance impacts.) + sloppy defaults to: True. + Returns: + A `tf.data.Dataset` with an infinite stream of elements generated by a + parallel interleaving of the set of files matched (or generated) by `files` + with a type is the output of the dataset specified by `filetype`. + + Raises: + ValueError: if any argument is not of the expected type. + """ + if filetype is None: + filetype = 'tfrecord' + + if isinstance(filetype, str): + if filetype not in _FILETYPE_MAP: + raise ValueError('Unexpected filetype: %s' % filetype) + reader_fn = _FILETYPE_MAP[filetype] + elif callable(filetype): + reader_fn = filetype + else: + raise ValueError('filetype should be a string or a callable') + + file_reader_job = file_reader_job or 'coordinator' + + worker_job = worker_job or 'worker' + + if filename_shuffle_buffer_size is None: + filename_shuffle_buffer_size = 4096 + + num_parallel_reads = num_parallel_reads or 8 + + if batch_transfer_size is None: + batch_transfer_size = 256 + + if sloppy is None: + sloppy = True + + with ops.device('/job:%s' % file_reader_job): + if isinstance(files, str): + source_dataset = dataset_ops.Dataset.list_files(files) + elif isinstance(files, dataset_ops.Dataset): + source_dataset = files + else: + raise ValueError('files was not a string or a dataset: %s' % files) + + if filename_shuffle_buffer_size: + source_dataset = source_dataset.shuffle( + buffer_size=filename_shuffle_buffer_size) + + # NOTE: We perform the `repeat` on the source dataset, because the output + # dataset does not currently have enough information to recreate an iterator + # over the source dataset when it reaches the end. + source_dataset = source_dataset.repeat(num_epochs) + + source_dataset = source_dataset.apply( + interleave_ops.parallel_interleave( + reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) + + if batch_transfer_size: + source_dataset = source_dataset.batch(batch_transfer_size) + + source_dataset = source_dataset.prefetch(1) + + source_iterator = source_dataset.make_one_shot_iterator() + source_handle = source_iterator.string_handle() + + @function.Defun(dtypes.string) + def LoadingFunc(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, source_dataset.output_types, source_dataset.output_shapes) + return remote_iterator.get_next() + + def MapFn(unused_input): + return functional_ops.remote_call( + args=[source_handle], + Tout=[dtypes.string], + f=LoadingFunc, + target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) + + with ops.device('/job:%s' % worker_job): + output_dataset = dataset_ops.Dataset.range(2).repeat().map( + MapFn, num_parallel_calls=4 if sloppy else None) + output_dataset = output_dataset.prefetch(1) + + if batch_transfer_size: + # Undo the batching used during the transfer. + output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1) + + return output_dataset diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py new file mode 100644 index 0000000000000000000000000000000000000000..918cf0ed8e513de0d4207f7d2aac61ad886c8288 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -0,0 +1,181 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU datasets tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.tpu.python.tpu import datasets +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.lib.io import python_io +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat + +_NUM_FILES = 10 +_NUM_ENTRIES = 20 + + +class DatasetsTest(test.TestCase): + + def setUp(self): + super(DatasetsTest, self).setUp() + self._coord = server_lib.Server.create_local_server() + self._worker = server_lib.Server.create_local_server() + + self._cluster_def = cluster_pb2.ClusterDef() + worker_job = self._cluster_def.job.add() + worker_job.name = 'worker' + worker_job.tasks[0] = self._worker.target[len('grpc://'):] + coord_job = self._cluster_def.job.add() + coord_job.name = 'coordinator' + coord_job.tasks[0] = self._coord.target[len('grpc://'):] + + session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def) + + self._sess = session.Session(self._worker.target, config=session_config) + + def testTextLineDataset(self): + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i) + contents = [] + for j in range(_NUM_ENTRIES): + contents.append(compat.as_bytes('%d: %d' % (i, j))) + with open(filename, 'wb') as f: + f.write(b'\n'.join(contents)) + all_contents.extend(contents) + + dataset = datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(4 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testTFRecordDataset(self): + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i) + writer = python_io.TFRecordWriter(filename) + for j in range(_NUM_ENTRIES): + record = compat.as_bytes('Record %d of file %d' % (j, i)) + writer.write(record) + all_contents.append(record) + writer.close() + + dataset = datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(4 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testTFRecordDatasetFromDataset(self): + filenames = [] + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i) + filenames.append(filename) + writer = python_io.TFRecordWriter(filename) + for j in range(_NUM_ENTRIES): + record = compat.as_bytes('Record %d of file %d' % (j, i)) + writer.write(record) + all_contents.append(record) + writer.close() + + filenames = dataset_ops.Dataset.from_tensor_slices(filenames) + + dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(4 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testArbitraryReaderFunc(self): + + def MakeRecord(i, j): + return compat.as_bytes('%04d-%04d' % (i, j)) + + record_bytes = len(MakeRecord(10, 200)) + + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i) + with open(filename, 'wb') as f: + for j in range(_NUM_ENTRIES): + record = MakeRecord(i, j) + f.write(record) + all_contents.append(record) + + def FixedLengthFile(filename): + return readers.FixedLengthRecordDataset(filename, record_bytes) + + dataset = datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), 'fixed_length*'), + filetype=FixedLengthFile) + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(4 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testUnexpectedFiletypeString(self): + with self.assertRaises(ValueError): + datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), '*'), filetype='foo') + + def testUnexpectedFiletypeType(self): + with self.assertRaises(ValueError): + datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), '*'), filetype=3) + + def testUnexpectedFilesType(self): + with self.assertRaises(ValueError): + datasets.StreamingFilesDataset(123, filetype='tfrecord') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index bdd9b88af55fa4fb483ddbdbe5c51d7076cce675..726b2d248e3086e1882004827076ed3e563d960d 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -191,9 +191,9 @@ class DeviceAssignment(object): logical_core: A tuple of three integers which represents a logical core. Returns: A sorted list of the replicas that are attached to that task and - loical_core. + logical_core. Raises: - ValueError: If no replica exisis in the task which contains the logical + ValueError: If no replica exists in the task which contains the logical core. """ try: diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index d5f54ff4fd278f0c84f79e0079bfb7a409dfba8d..3f2db548ace9e10df7844d8fb461670d27234670 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -201,7 +201,7 @@ def replicate(computation, `DeviceAssignment` may be omitted if each replica of the computation uses only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. - name: The name of the operator. + name: (Deprecated) Does nothing. Returns: A list of lists of output tensors, indexed by `[replica_num][output_num]`. Raises: @@ -209,8 +209,7 @@ def replicate(computation, ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. """ - if name is None: - name = "TPUReplicate" + del name inputs = [[]] if inputs is None else inputs metadata_kwargs = {} @@ -274,118 +273,117 @@ def replicate(computation, graph = ops.get_default_graph() - with ops.name_scope(name, "replicate"): - # Fan-in: Builds a TPUReplicatedInput node for each input. - computation_inputs = [] - for i in range(0, input_arity): - replicas = [inputs[replica][i] for replica in xrange(num_replicas)] - computation_inputs.append( - tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) + # Fan-in: Builds a TPUReplicatedInput node for each input. + computation_inputs = [] + for i in range(0, input_arity): + replicas = [inputs[replica][i] for replica in xrange(num_replicas)] + computation_inputs.append( + tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) + + context = TPUReplicateContext(name=graph.unique_name("cluster")) + try: + context.Enter() + + metadata = tpu_ops.tpu_replicate_metadata( + num_replicas=num_replicas, **metadata_kwargs) + + with tpu_function.tpu_shard_context( + num_replicas), ops.control_dependencies([metadata]): + + # The EncapsulateTPUComputations rewrite needs to identify the + # replicated arguments inside each computation. Adds identity operators + # tagged with an attribute _tpu_replicated_input to identify the + # replicated inputs. + # pylint: disable=protected-access + with graph._attr_scope({"_tpu_replicated_input": + attr_value_pb2.AttrValue(b=True)}): + computation_inputs = [ + array_ops.identity(x, name="replicated_input_{}".format(i)) + for i, x in enumerate(computation_inputs)] + # pylint: enable=protected-access + + # If there is an infeed queue, adds the dequeued values to the + # computation's inputs. + if infeed_queue is not None: + infeed_queue.set_number_of_shards(num_replicas) + for t in infeed_queue.generate_dequeue_op(): + computation_inputs.append(t) + + # Only resource variables work inside a TPU computation, so turn on + # resource variables for the computation. + # TODO(phawkins): consider removing this code. It will + # be less confusing to clients if they knowingly choose to use resource + # variables. + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + vscope.set_use_resource(True) + + outputs = computation(*computation_inputs) + + vscope.set_use_resource(saved_use_resource) + + # If the computation only returned one value, makes it a tuple. + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) - context = TPUReplicateContext(name=graph.unique_name("cluster")) try: - context.Enter() - - metadata = tpu_ops.tpu_replicate_metadata( - num_replicas=num_replicas, **metadata_kwargs) - - with tpu_function.tpu_shard_context( - num_replicas), ops.control_dependencies([metadata]): - - # The EncapsulateTPUComputations rewrite needs to identify the - # replicated arguments inside each computation. Adds identity operators - # tagged with an attribute _tpu_replicated_input to identify the - # replicated inputs. - # pylint: disable=protected-access - with graph._attr_scope({"_tpu_replicated_input": - attr_value_pb2.AttrValue(b=True)}): - computation_inputs = [ - array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs)] - # pylint: enable=protected-access - - # If there is an infeed queue, adds the dequeued values to the - # computation's inputs. - if infeed_queue is not None: - infeed_queue.set_number_of_shards(num_replicas) - for t in infeed_queue.generate_dequeue_op(): - computation_inputs.append(t) - - # Only resource variables work inside a TPU computation, so turn on - # resource variables for the computation. - # TODO(phawkins): consider removing this code. It will - # be less confusing to clients if they knowingly choose to use resource - # variables. - vscope = variable_scope.get_variable_scope() - saved_use_resource = vscope.use_resource - vscope.set_use_resource(True) - - outputs = computation(*computation_inputs) - - vscope.set_use_resource(saved_use_resource) - - # If the computation only returned one value, makes it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - try: - with ops.device(core(0)): - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - "convertible to Tensors. Got '%s'" % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs - if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU functions must return zero-or more Tensor values followed by " - "zero or more Operations.") - output_arity = len(output_tensors) - - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else core(0)): - new_output_tensors.append(array_ops.identity(t)) - output_tensors = new_output_tensors - finally: - context.report_unsupported_operations() - context.Exit() - - # Fan-out: Builds a TPUReplicatedOutput node for each output. - outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, - name="output{}".format(i)) - for i in xrange(output_arity)] - - with ops.control_dependencies(output_operations): - if output_arity == 0: - # Returns a list of NoOps dependent on the replication Op, indexed by - # [replica_num]. - return [ - control_flow_ops.no_op(name="%s_shard_%d" % (name, i)) - for i in range(num_replicas) - ] - else: - # Wraps the outputs in identity operators so the names of any possible - # `fetch` nodes are preserved by the replication rewrite. - return [ - [array_ops.identity(outputs[out][replica], - name="output_%d_shard_%d" % (out, replica)) - for out in xrange(output_arity)] - for replica in xrange(num_replicas) + with ops.device(core(0)): + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs ] + except Exception as e: + raise ValueError( + "TPU function return values must all either be Operations or " + "convertible to Tensors. Got '%s'" % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs + if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + "TPU functions must return zero-or more Tensor values followed by " + "zero or more Operations.") + output_arity = len(output_tensors) + + # Wraps outputs in Identity ops. Otherwise a replicated input copied + # straight to an output would bypass the replicate(). This would be bad + # because the TPUReplicatedInput/TPUReplicatedOutput operator would not + # be rewritten away, leading to a runtime error. + # TODO(phawkins): extend the rewrite to elide these nodes instead. + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else core(0)): + new_output_tensors.append(array_ops.identity(t)) + output_tensors = new_output_tensors + finally: + context.report_unsupported_operations() + context.Exit() + + # Fan-out: Builds a TPUReplicatedOutput node for each output. + outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, + name="output{}".format(i)) + for i in xrange(output_arity)] + + with ops.control_dependencies(output_operations): + if output_arity == 0: + # Returns a list of NoOps dependent on the replication Op, indexed by + # [replica_num]. + return [ + control_flow_ops.no_op(name="shard_%d" % i) + for i in range(num_replicas) + ] + else: + # Wraps the outputs in identity operators so the names of any possible + # `fetch` nodes are preserved by the replication rewrite. + return [ + [array_ops.identity(outputs[out][replica], + name="output_%d_shard_%d" % (out, replica)) + for out in xrange(output_arity)] + for replica in xrange(num_replicas) + ] def shard(computation, @@ -450,7 +448,7 @@ def shard(computation, `DeviceAssignment` may be omitted if each shard of the computation uses only one core, and there is either only one shard, or the number of shards is equal to the number of cores in the TPU system. - name: The name of the operator. + name: (Deprecated) Does nothing. Returns: A list of output tensors. Raises: @@ -579,7 +577,7 @@ def batch_parallel(computation, `DeviceAssignment` may be omitted if each shard of the computation uses only one core, and there is either only one shard, or the number of shards is equal to the number of cores in the TPU system. - name: The name of the operator. + name: (Deprecated) Does nothing. Returns: A list of output tensors. Raises: @@ -613,7 +611,7 @@ def rewrite(computation, mapping between logical cores in the computation with physical cores in the TPU topology. May be omitted for a single-core computation, in which case the core attached to task 0, TPU device 0 is used. - name: The name of the operator. + name: (Deprecated) Does nothing. Returns: A list of output tensors. """ diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 644070218214643923b9ca3ee138615ec568e8b5..38b5ea23103730630ae8e1cdd7b9180a501013c5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -26,6 +26,7 @@ import os import numpy as np from tensorflow.contrib.tpu.python.tpu import util as util_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import tf_logging as logging @@ -65,7 +66,7 @@ class TPUConfig( cores. This is required by model-parallelism which enables partitioning the model to multiple cores. For example, [2, 2, 1] means the model is partitioned across 4 cores which span two cores in both x and y - coordinates. Please refer to ${tf.contrib.tpu.TopologyProto} for the + coordinates. Please refer to @{tf.contrib.tpu.Topology} for the geometry of a TPU mesh. per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` @@ -140,6 +141,7 @@ class RunConfig(run_config_lib.RunConfig): tpu_config=None, evaluation_master=None, master=None, + cluster=None, **kwargs): """Constructs a RunConfig. @@ -148,15 +150,26 @@ class RunConfig(run_config_lib.RunConfig): evaluation_master: a string. The address of the master to use for eval. Defaults to master if not set. master: a string. The address of the master to use for training. + cluster: a ClusterResolver **kwargs: keyword config parameters. + + Raises: + ValueError: if cluster is not None and the provided session_config has a + cluster_def already. """ super(RunConfig, self).__init__(**kwargs) self._tpu_config = tpu_config or TPUConfig() + self._cluster = cluster - # If user sets master and/or evaluation_master explicilty, including empty + # If user sets master and/or evaluation_master explicitly, including empty # string '', take it. Otherwise, take the values set by parent class. if master is not None: + if cluster is not None: + raise ValueError('Both master and cluster are set.') self._master = master + else: + if cluster: + self._master = cluster.master() if evaluation_master is not None: self._evaluation_master = evaluation_master @@ -170,6 +183,20 @@ class RunConfig(run_config_lib.RunConfig): # evaluation_master to master, unless user overwrites it. self._evaluation_master = self._master + # Set the ClusterSpec to use + if cluster: + self._cluster_spec = cluster.cluster_spec() + + # Merge the cluster_def into the ConfigProto. + if self._session_config is None: # pylint: disable=access-member-before-definition + self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) + if self._session_config.HasField('cluster_def'): + raise ValueError( + 'You cannot provide a ClusterResolver and ' + 'session_config.cluster_def.') + self._session_config.cluster_def.CopyFrom( + self._cluster_spec.as_cluster_def()) + @property def evaluation_master(self): return self._evaluation_master @@ -182,6 +209,10 @@ class RunConfig(run_config_lib.RunConfig): def tpu_config(self): return self._tpu_config + @property + def cluster(self): + return self._cluster + def replace(self, **kwargs): if 'tpu_config' not in kwargs: return super(RunConfig, self).replace(**kwargs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index c5c46ea741ea64ca37089431f8ed66cad7bc31fb..3bac2db77e95520a6c9c4c17658267a9a6588d94 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -39,7 +39,7 @@ class _TPUContext(object): This immutable object holds TPUEstimator config, train/eval batch size, and `TPUEstimator.use_tpu`, which is expected to be passed around. It also - provides utility functions, basded on the current state, to determine other + provides utility functions, based on the current state, to determine other information commonly required by TPU computation, such as TPU device names, TPU hosts, shard batch size, etc. @@ -218,7 +218,7 @@ class _TPUContext(object): model, when mode == PREDICT. Only with this bool, we could tell whether user is calling the Estimator.predict or Estimator.export_savedmodel, which are running on TPU and CPU - respectively. Parent class Estimator does not distingush these two. + respectively. Parent class Estimator does not distinguish these two. Returns: bool, whether current input_fn or model_fn should be running on CPU. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 1b2eda1caa0fa2779834d65b5a49121d9cc0af56..152f8c8c69ef7344c1346885cbdf8059e0849db3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -25,6 +25,7 @@ import threading import time import traceback +import numpy as np import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin @@ -48,6 +49,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -61,6 +63,7 @@ from tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect _INITIAL_LOSS = 1e7 @@ -69,6 +72,7 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' +_ONE_GIGABYTE = 1024 * 1024 * 1024 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] @@ -133,7 +137,7 @@ def _increase_eval_step_op(iterations_per_loop): """Returns an op to increase the eval step for TPU evaluation. Args: - iterations_per_loop: Tensor. The number of eval steps runnining in TPU + iterations_per_loop: Tensor. The number of eval steps running in TPU system before returning to CPU host for each `Session.run`. Returns: @@ -605,17 +609,17 @@ class _StoppingPredictHook(session_run_hook.SessionRunHook): # batch. And we append one more batch to signal the system it should stop. # The data flow might look like # - # batch 0: images, labels, stop = 0 (user provideded) - # batch 1: images, labels, stop = 0 (user provideded) + # batch 0: images, labels, stop = 0 (user provided) + # batch 1: images, labels, stop = 0 (user provided) # ... - # batch 99: images, labels, stop = 0 (user provideded) + # batch 99: images, labels, stop = 0 (user provided) # batch 100: images, labels, stop = 1 (TPUEstimator appended) # # where the final batch (id = 100) is appended by TPUEstimator, so we # should drop it before returning the predictions to user. # To achieve that, we throw the OutOfRangeError in after_run. Once # Monitored Session sees this error in SessionRunHook.after_run, the - # "current" prediciton, i.e., batch with id=100, will be discarded + # "current" prediction, i.e., batch with id=100, will be discarded # immediately raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') @@ -676,8 +680,11 @@ def generate_per_host_enqueue_ops_fn_for_host( raise TypeError( 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' '`features` and `labels`.') + if batch_axis is not None: + raise TypeError('For mode PREDICT, batch_axis is not supported yet.') inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn) + dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn, + add_padding=True) if is_dataset: hooks.append(inputs.dataset_initializer_hook()) @@ -751,7 +758,7 @@ class _InputPipeline(object): 2. (features, labels) Internally, form 1 is reformed to `(features, None)` as features and labels - are passed separatedly to underlying methods. For TPU training, TPUEstimator + are passed separately to underlying methods. For TPU training, TPUEstimator may expect multiple `features` and `labels` tuples one for each core. TPUEstimator allows various different structures for inputs (namely `features` @@ -784,7 +791,8 @@ class _InputPipeline(object): def _extract_key_names(tensor_or_dict): if tensor_or_dict is None: return [] - return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] + return sorted(tensor_or_dict.keys()) if isinstance( + tensor_or_dict, dict) else [] # Extract structure. has_labels = labels is not None @@ -923,8 +931,7 @@ class _InputPipeline(object): # In the model-parallel case, both the host-side and device-side # computations must agree on the core on which infeed takes place. We # choose to perform infeed on logical core 0 of each replica. - with ops.device(tpu.core(0)): - values = self._infeed_queue.generate_dequeue_op() + values = self._infeed_queue.generate_dequeue_op(tpu_device=0) # The unflatten process uses the structure information recorded above. return self._inputs_structure_recorder.unflatten_features_and_labels( values) @@ -1036,8 +1043,8 @@ class _ModelFnWrapper(object): self._params = params self._ctx = ctx - def call_without_tpu(self, features, labels): - return self._call_model_fn(features, labels) + def call_without_tpu(self, features, labels, is_export_mode): + return self._call_model_fn(features, labels, is_export_mode=is_export_mode) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -1196,7 +1203,7 @@ class _ModelFnWrapper(object): return predict_step, host_calls, captured_scaffold_fn - def _call_model_fn(self, features, labels, is_export_mode=True): + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -1222,7 +1229,11 @@ class _ModelFnWrapper(object): 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(self._model_fn)) - batch_size_for_model_fn = self._ctx.batch_size_for_model_fn + if is_export_mode: + batch_size_for_model_fn = None + else: + batch_size_for_model_fn = self._ctx.batch_size_for_model_fn + if batch_size_for_model_fn is not None: params[_BATCH_SIZE_KEY] = batch_size_for_model_fn @@ -1516,14 +1527,20 @@ class TPUEstimator(estimator_lib.Estimator): size when calling the `input_fn` and `model_fn`. Users should specify global batch size in constructor, and then get the batch size for each shard in `input_fn` and `model_fn` by `params['batch_size']`. - For training, `model_fn` gets per-core batch size; `input_fn` may get - per-core or per-host batch size depending on - `per_host_input_for_training` in `TPUConfig`. - For evaluation, `model_fn` gets per-core batch size and `input_fn` get - per-host batch size. + + - For training, `model_fn` gets per-core batch size; `input_fn` may get + per-core or per-host batch size depending on `per_host_input_for_training` + in `TPUConfig` (See docstring for TPUConfig for details). + + - For evaluation and prediction, `model_fn` gets per-core batch size and + `input_fn` get per-host batch size. + + Evaluation + ========== `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` for TPU evaluation. + `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns @@ -1535,12 +1552,17 @@ class TPUEstimator(estimator_lib.Estimator): `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. Current limitations: + -------------------- - 1. TPU evaluation only works on single host. - 2. `input_fn` for evaluation should not throw OutOfRange error for all - evaluation steps and all batches should have the same size. + 1. TPU evaluation only works on a single host (one TPU worker). + + 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception + (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all + batches should have the same size. Example (MNIST): + ---------------- + ``` # The metric Fn which runs on CPU. def metric_fn(labels, logits): @@ -1576,8 +1598,83 @@ class TPUEstimator(estimator_lib.Estimator): })) ``` - Predict support on TPU is not yet implemented. So, `predict` and - `export_savedmodel` are executed on CPU, even if `use_tpu` is true. + Prediction + ========== + + Prediction on TPU is an experimental feature to support large batch inference. + It is not designed for latency-critical system. In addition, due to some + usability issues, for prediction with small dataset, CPU `.predict`, i.e., + creating a new `TPUEstimator` instance with `use_tpu=False`, might be more + convenient. + + Note: In contrast to TPU training/evaluation, the `input_fn` for prediction + *should* raise an end-of-input exception (`OutOfRangeError` or + `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be + precise, the ops created by `input_fn` produce one batch of the data. + The `predict()` API processes one batch at a time. When reaching the end of + the data source, an end-of-input exception should be raised by one of these + operations. The user usually does not need to do this manually. As long as the + dataset is not repeated forever, the `tf.data` API will raise an end-of-input + exception automatically after the last batch has been produced. + + Note: Estimator.predict returns a Python generator. Please consume all the + data from the generator so that TPUEstimator can shutdown the TPU system + properly for user. + + Current limitations: + -------------------- + 1. TPU prediction only works on a single host (one TPU worker). + + 2. `input_fn` must return a `Dataset` instance rather than `features`. In + fact, .train() and .evaluate() also support Dataset as return value. + + Example (MNIST): + ---------------- + ``` + height = 32 + width = 32 + total_examples = 100 + + def predict_input_fn(params): + batch_size = params['batch_size'] + + images = tf.random_uniform( + [total_examples, height, width, 3], minval=-1, maxval=1) + + dataset = tf.data.Dataset.from_tensor_slices(images) + dataset = dataset.map(lambda images: {'image': images}) + + dataset = dataset.batch(batch_size) + return dataset + + def model_fn(features, labels, params, mode): + # Generate predictions, called 'output', from features['image'] + + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, + predictions={ + 'predictions': output, + 'is_padding': features['is_padding'] + }) + + tpu_est = TPUEstimator( + model_fn=model_fn, + ..., + predict_batch_size=16) + + # Fully consume the generator so that TPUEstimator can shutdown the TPU + # system. + for item in tpu_est.predict(input_fn=input_fn): + # Filter out item if the `is_padding` is 1. + # Process the 'predictions' + ``` + + Exporting + ========= + + Exporting `SavedModel` support on TPU is not yet implemented. So, + `export_savedmodel` is executed on CPU, even if `use_tpu` is true. """ def __init__(self, @@ -1684,6 +1781,8 @@ class TPUEstimator(estimator_lib.Estimator): eval_batch_size, predict_batch_size, use_tpu) + self._is_input_fn_invoked = None + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1766,6 +1865,9 @@ class TPUEstimator(estimator_lib.Estimator): if 'mode' in input_fn_args: kwargs['mode'] = mode + # Records the fact input_fn has been invoked. + self._is_input_fn_invoked = True + with self._ctx.with_mode(mode) as ctx: # Setting the batch size in params first. This helps user to have same # input_fn for use_tpu=True/False. @@ -1794,6 +1896,17 @@ class TPUEstimator(estimator_lib.Estimator): return _input_fn + def _validate_features_in_predict_input(self, result): + """Skip the validation. + + For TPUEstimator, we do not need to check the result type. `_InputPipeline` + has stronger check. Parent class's check generates confusing warning msg. + + Args: + result: `features` returned by input_fn. + """ + pass + def _augment_model_fn(self, model_fn, batch_axis): """Returns a new model_fn, which wraps the TPU support.""" @@ -1802,15 +1915,24 @@ class TPUEstimator(estimator_lib.Estimator): with self._ctx.with_mode(mode) as ctx: model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - # For export_savedmodel, input_fn is never passed to Estimator. So, - # if features is callable, it means it is the input_fn passed by - # TPUEstimator._call_input_fn. Then we can know if the mode == PREDICT, - # it implies, it is the .predict API, not export_savedmodel API. - is_export_mode = not callable(features) + if mode != model_fn_lib.ModeKeys.PREDICT: + is_export_mode = False + else: + # For export_savedmodel, input_fn is never passed to Estimator. So, by + # checking the self._is_input_fn_invoked bit, we can know, given the + # mode == PREDICT, it is the .predict API, not export_savedmodel API. + if self._is_input_fn_invoked: + is_export_mode = False + else: + is_export_mode = True + + # Clear the bit. + self._is_input_fn_invoked = None if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu(features, labels) + return model_fn_wrapper.call_without_tpu( + features, labels, is_export_mode=is_export_mode) assert labels is None, '`labels` passed to `model_fn` must be `None`.' # TPUEstimator._call_input_fn passes `input_fn` as features to here. @@ -1948,12 +2070,18 @@ class TPUEstimator(estimator_lib.Estimator): host_ops = host_call_ret['host_call'] predictions = host_call_ret['predictions'] - stopping_signals = host_call_ret['signals'] + _verify_cross_hosts_transfer_size( + predictions, message=( + 'The estimated size for TPUEstimatorSpec.predictions is too ' + 'large.')) + signals = host_call_ret['signals'] with ops.control_dependencies(host_ops): host_ops = [] # Empty, we do do not need it anymore. scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( - stopping_signals) + signals) + predictions = _PaddingSignals.slice_tensor_or_dict( + predictions, signals) hooks = [ _StoppingPredictHook(scalar_stopping_signal), @@ -1980,8 +2108,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): def multi_tpu_eval_steps_on_single_shard(): return training_loop.repeat( iterations_per_loop_var, - single_tpu_eval_step, [_ZERO_LOSS], - name='loop') + single_tpu_eval_step, [_ZERO_LOSS]) (loss,) = tpu.shard( multi_tpu_eval_steps_on_single_shard, @@ -2004,8 +2131,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): def multi_tpu_train_steps_on_single_shard(): return training_loop.repeat( iterations_per_loop_var, - single_tpu_train_step, [_INITIAL_LOSS], - name=b'loop') + single_tpu_train_step, [_INITIAL_LOSS]) (loss,) = tpu.shard( multi_tpu_train_steps_on_single_shard, @@ -2248,20 +2374,19 @@ class _Inputs(object): return self._dataset -# TODO(xiejw): Extend this to support final partial batch. class _InputsWithStoppingSignals(_Inputs): """Inputs with `_StopSignals` inserted into the dataset.""" - def __init__(self, dataset, batch_size): + def __init__(self, dataset, batch_size, add_padding=False): assert dataset is not None user_provided_dataset = dataset.map( _InputsWithStoppingSignals.insert_stopping_signal( - stop=False, batch_size=batch_size)) + stop=False, batch_size=batch_size, add_padding=add_padding)) final_batch_dataset = dataset.take(1).map( _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size)) + stop=True, batch_size=batch_size, add_padding=add_padding)) dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) @@ -2291,7 +2416,7 @@ class _InputsWithStoppingSignals(_Inputs): return signals @staticmethod - def insert_stopping_signal(stop, batch_size): + def insert_stopping_signal(stop, batch_size, add_padding=False): """Inserts stopping_signal into dataset via _map_fn. Here we change the data structure in the dataset, such that the return value @@ -2302,19 +2427,39 @@ class _InputsWithStoppingSignals(_Inputs): Args: stop: bool, state of current stopping signals. batch_size: int, batch size. + add_padding: bool, whether to pad the tensor to full batch size. Returns: A map_fn passed to dataset.map API. """ def _map_fn(*args): + """The map fn to insert signals.""" + if len(args) == 1: + # Unpack the single Tensor/dict argument as features. This is required + # for the input_fn returns no labels. + args = args[0] features, labels = _Inputs._parse_inputs(args) new_input_dict = {} - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels + + if add_padding: + padding_mask, features, labels = ( + _PaddingSignals.pad_features_and_labels( + features, labels, batch_size)) + + new_input_dict['features'] = features + if labels is not None: + new_input_dict['labels'] = labels + + else: + new_input_dict['features'] = features + if labels is not None: + new_input_dict['labels'] = labels + padding_mask = None + new_input_dict['signals'] = _StopSignals( - stop=stop, batch_size=batch_size).as_dict() + stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict() + return new_input_dict return _map_fn @@ -2323,23 +2468,28 @@ class _InputsWithStoppingSignals(_Inputs): class _StopSignals(object): """Signals class holding all logic to handle TPU stopping condition.""" - NON_STOPPING_SIGNAL = 0.0 - STOPPING_SIGNAL = 1.0 + NON_STOPPING_SIGNAL = False + STOPPING_SIGNAL = True - def __init__(self, stop, batch_size): + def __init__(self, stop, batch_size, padding_mask=None): self._stop = stop self._batch_size = batch_size + self._padding_mask = padding_mask def as_dict(self): + """Returns the signals as Python dict.""" shape = [self._batch_size, 1] - dtype = dtypes.float32 + dtype = dtypes.bool if self._stop: stopping = array_ops.ones(shape=shape, dtype=dtype) else: stopping = array_ops.zeros(shape=shape, dtype=dtype) - return {'stopping': stopping} + signals = {'stopping': stopping} + if self._padding_mask is not None: + signals['padding_mask'] = self._padding_mask + return signals @staticmethod def as_scalar_stopping_signal(signals): @@ -2347,7 +2497,118 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): - return scalar_stopping_signal >= _StopSignals.STOPPING_SIGNAL + if isinstance(scalar_stopping_signal, ops.Tensor): + # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF + # way to express the bool check whether scalar_stopping_signal is True. + return math_ops.logical_and( + scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL) + else: + # For non Tensor case, it is used in SessionRunHook. So, we cannot modify + # the graph anymore. Here, we use pure Python. + return bool(scalar_stopping_signal) + + +class _PaddingSignals(object): + """Signals class holding all logic to handle padding.""" + + @staticmethod + def pad_features_and_labels(features, labels, batch_size): + """Pads out the batch dimension of features and labels.""" + real_batch_size = array_ops.shape( + _PaddingSignals._find_any_tensor(features))[0] + + batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) + + check_greater = check_ops.assert_greater_equal( + batch_size_tensor, real_batch_size, + data=(batch_size_tensor, real_batch_size), + message='The real batch size should not be greater than batch_size.') + + with ops.control_dependencies([check_greater]): + missing_count = batch_size_tensor - real_batch_size + + def pad_single_tensor(tensor): + """Pads out the batch dimension of a tensor to the complete batch_size.""" + rank = len(tensor.shape) + assert rank > 0 + padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) + padded_shape = (batch_size,) + tuple(tensor.shape[1:]) + padded_tensor = array_ops.pad(tensor, padding) + padded_tensor.set_shape(padded_shape) + return padded_tensor + + def nest_pad(tensor_or_dict): + return nest.map_structure(pad_single_tensor, tensor_or_dict) + + features = nest_pad(features) + if labels is not None: + labels = nest_pad(labels) + + padding_mask = _PaddingSignals._padding_mask( + real_batch_size, missing_count, batch_size) + + return padding_mask, features, labels + + @staticmethod + def slice_tensor_or_dict(tensor_or_dict, signals): + """Slice the real Tensors according to padding mask in signals.""" + + padding_mask = signals['padding_mask'] + batch_size = array_ops.shape(padding_mask)[0] + + def verify_batch_size(tensor): + check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) + with ops.control_dependencies([check_batch_size]): + return array_ops.identity(tensor) + + def slice_single_tensor(tensor): + rank = len(tensor.shape) + assert rank > 0 + real_batch_size = batch_size - math_ops.reduce_sum(padding_mask) + return verify_batch_size(tensor)[0:real_batch_size] + + # As we split the Tensors to all TPU cores and concat them back, it is + # important to ensure the real data is placed before padded ones, i.e., + # order is preserved. By that, the sliced padding mask should have all 0's. + # If this assertion failed, # the slice logic here would not hold. + sliced_padding_mask = slice_single_tensor(padding_mask) + assert_padding_mask = math_ops.equal( + math_ops.reduce_sum(sliced_padding_mask), 0) + + with ops.control_dependencies([assert_padding_mask]): + should_stop = _StopSignals.should_stop( + _StopSignals.as_scalar_stopping_signal(signals)) + + is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) + + def slice_fn(tensor): + # If the current batch is full batch or part of stopping signals, we do + # not need to slice to save performance. + return control_flow_ops.cond( + math_ops.logical_or(should_stop, is_full_batch), + (lambda: verify_batch_size(tensor)), + (lambda: slice_single_tensor(tensor))) + + return nest.map_structure(slice_fn, tensor_or_dict) + + @staticmethod + def _find_any_tensor(batch_features): + tensors = [x for x in nest.flatten(batch_features) + if isinstance(x, ops.Tensor)] + if not tensors: + raise ValueError('Cannot find any Tensor in features dict.') + return tensors[0] + + @staticmethod + def _padding_mask(real_batch_size, missing_count, batch_size): + padding_mask = array_ops.concat( + [ + array_ops.zeros((real_batch_size,), dtype=dtypes.int32), + array_ops.ones((missing_count,), dtype=dtypes.int32) + ], + axis=0) + padding_mask.set_shape((batch_size,)) + return padding_mask class _SignalsHelper(object): @@ -2368,3 +2629,21 @@ class _SignalsHelper(object): @staticmethod def as_tensor_list(signals): return [signals[key] for key in sorted(signals.iterkeys())] + + +def _verify_cross_hosts_transfer_size(tensor_dict, message): + total_size = 0 + tensor_structure = {} + for key, tensor in tensor_dict.items(): + shape = tensor.shape + size = np.product(shape) * tensor.dtype.size + tensor_structure[key] = shape + total_size += size + if total_size >= _ONE_GIGABYTE: + raise ValueError( + '{} The transfer size is larger than the protobuf limit. Please ' + 'consider to use Tensors with smaller shapes or reduce batch ' + 'size. Given:\n' + '{}'.format(message, '\n'.join([ + ' -- Key: {}, Shape: {}'.format(k, v) + for k, v in tensor_structure.items()]))) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e90957e6dea7ff1777dd3e26cdf1c6fdb340dd3 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -0,0 +1,291 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU Estimator Signalling Tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.python import data as dataset_lib +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +def make_input_fn(num_samples): + a = np.linspace(0, 100.0, num=num_samples) + b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1)) + + def input_fn(params): + batch_size = params['batch_size'] + da1 = dataset_lib.Dataset.from_tensor_slices(a) + da2 = dataset_lib.Dataset.from_tensor_slices(b) + + dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) + dataset = dataset.batch(batch_size) + return dataset + return input_fn, (a, b) + + +def make_input_fn_with_labels(num_samples): + a = np.linspace(0, 100.0, num=num_samples) + b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1)) + + def input_fn(params): + batch_size = params['batch_size'] + da1 = dataset_lib.Dataset.from_tensor_slices(a) + da2 = dataset_lib.Dataset.from_tensor_slices(b) + + dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) + dataset = dataset.batch(batch_size) + return dataset + return input_fn, (a, b) + + +class TPUEstimatorStoppingSignalsTest(test.TestCase): + + def test_normal_output_without_signals(self): + num_samples = 4 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + features = dataset.make_one_shot_iterator().get_next() + + # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. + self.assertIsNone(features['a'].shape.as_list()[0]) + + with session.Session() as sess: + result = sess.run(features) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + + # This run should work as num_samples / batch_size = 2. + result = sess.run(features) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + + with self.assertRaises(errors.OutOfRangeError): + # Given num_samples and batch_size, this run should fail. + sess.run(features) + + def test_output_with_stopping_signals(self): + num_samples = 4 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size) + hook = inputs.dataset_initializer_hook() + features, _ = inputs.features_and_labels() + signals = inputs.signals() + + # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. + self.assertIsNone(features['a'].shape.as_list()[0]) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This run should work as num_samples / batch_size = 2. + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(features) + + +class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase): + + def test_num_samples_divisible_by_batch_size(self): + num_samples = 4 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size, + add_padding=True) + hook = inputs.dataset_initializer_hook() + features, _ = inputs.features_and_labels() + signals = inputs.signals() + + # With padding, all shapes are static now. + self.assertEqual(batch_size, features['a'].shape.as_list()[0]) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This run should work as num_samples / batch_size = 2. + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(features) + + def test_num_samples_not_divisible_by_batch_size(self): + num_samples = 5 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size, + add_padding=True) + hook = inputs.dataset_initializer_hook() + features, labels = inputs.features_and_labels() + signals = inputs.signals() + + # With padding, all shapes are static. + self.assertEqual(batch_size, features['a'].shape.as_list()[0]) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + evaluated_features, evaluated_labels, evaluated_signals = ( + sess.run([features, labels, signals])) + self.assertAllEqual(a[:batch_size], evaluated_features['a']) + self.assertAllEqual(b[:batch_size], evaluated_labels) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This run should work as num_samples / batch_size >= 2. + evaluated_features, evaluated_labels, evaluated_signals = ( + sess.run([features, labels, signals])) + self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a']) + self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This is the final partial batch. + evaluated_features, evaluated_labels, evaluated_signals = ( + sess.run([features, labels, signals])) + real_batch_size = num_samples % batch_size + + # Assert the real part. + self.assertAllEqual(a[2*batch_size:num_samples], + evaluated_features['a'][:real_batch_size]) + self.assertAllEqual(b[2*batch_size:num_samples], + evaluated_labels[:real_batch_size]) + # Assert the padded part. + self.assertAllEqual([0.0] * (batch_size - real_batch_size), + evaluated_features['a'][real_batch_size:]) + self.assertAllEqual([[0.0]] * (batch_size - real_batch_size), + evaluated_labels[real_batch_size:]) + + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + padding = ([.0] * real_batch_size + + [1.] * (batch_size - real_batch_size)) + self.assertAllEqual(padding, evaluated_signals['padding_mask']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(features) + + def test_slice(self): + num_samples = 3 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size, + add_padding=True) + hook = inputs.dataset_initializer_hook() + features, _ = inputs.features_and_labels() + signals = inputs.signals() + + sliced_features = ( + tpu_estimator._PaddingSignals.slice_tensor_or_dict( + features, signals)) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + result, evaluated_signals = sess.run([sliced_features, signals]) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This is the final partial batch. + result, evaluated_signals = sess.run([sliced_features, signals]) + self.assertEqual(1, len(result['a'])) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([sliced_features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(sliced_features) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index 42ac6eb680437ec82287468bcba2b770ac0e5749..604e6600c81a4136a1f10e79a725a887a96f4d86 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -23,6 +23,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding from tensorflow.python.framework import dtypes @@ -368,13 +369,20 @@ class InfeedQueue(object): policy.freeze() self._validate() - def generate_dequeue_op(self): + def generate_dequeue_op(self, tpu_device=0): """Generates the device-side Op to dequeue a tuple from the queue. Implicitly freezes the queue configuration if it is not already frozen, which will raise errors if the shapes and types have not been fully specified. + Args: + tpu_device: The TPU device ordinal where the infeed instruction should be + placed. If None, no explicit placement will be performed, and it is up + to the user to call this API from within a proper TPU device scope. + The XLA code will fail if the TPU dequeue instruction is not bound to + any device. + Returns: A list of Outputs corresponding to a shard of infeed dequeued into XLA, suitable for use within a replicated block. @@ -392,8 +400,13 @@ class InfeedQueue(object): policy.get_sharded_shape(shape) for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) ] - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + if tpu_device is not None: + with ops.device(tpu.core(tpu_device)): + return tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + else: + return tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) def _generate_enqueue_op(self, inputs, diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index 3d7896127a99653167f164873331a2cc95f656e8..10a8bccf3b23add75188e16eb3591c32eb8621ee 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -44,7 +44,7 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): None (equivalent to an empty list). infeed_queue: if not None, the infeed queue from which to append a tuple of arguments as inputs to condition. - name: an optional name for the loop. + name: (Deprecated) Does nothing. Returns: The final values of the loop-carried tensors. @@ -52,7 +52,7 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): Raises: TypeError: if body or condition has the wrong signature. """ - + del name # Converts inputs to Tensors. inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs] @@ -166,11 +166,11 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): if input_arity == 0: inputs = [array_ops.constant(0)] return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs, - name=name) + name="") def repeat(n, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop that executes a fixed number of interations. + """Builds a training loop that executes a fixed number of iterations. The set of loop-carried tensors correspond to `inputs`. `body` must be a function that takes and returns the values of the @@ -183,7 +183,7 @@ def repeat(n, body, inputs=None, infeed_queue=None, name=None): None (equivalent to an empty list). infeed_queue: if not None, the infeed queue from which to append a tuple of arguments as inputs to condition. - name: an optional name for the loop. + name: (Deprecated) Does nothing. Returns: The final values of the loop-carried tensors. Raises: diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 6db373d2d5e20ea7da449530b2730403c3bb64cc..6ae2f382528c37ae647b73ea01a7f88c07580c78 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -324,7 +324,6 @@ tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), cc_api_version = 2, - go_api_version = 2, java_api_version = 2, visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 16397622edd382bc6dcb12870de5fa22130a2c2b..96eff86d8d48bb7f61b0fe9db2ccf2fe12c741bb 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -38,40 +38,60 @@ class HParamsTest(test.TestCase): self.assertFalse('bar' in hparams) def testSomeValues(self): - hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6') - self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values()) - expected_str = '[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\')]' + hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d') + self.assertDictEqual( + {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'}, + hparams.values()) + expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), ' + '(\'d\', \'/a/b=c/d\')]') self.assertEqual(expected_str, str(hparams.__str__())) self.assertEqual(expected_str, str(hparams)) self.assertEqual(1, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('aaa=12') self.assertDictEqual({ 'aaa': 12, 'b': 2.0, - 'c_c': 'relu6' + 'c_c': 'relu6', + 'd': '/a/b=c/d' }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=relu4, b=-2.0e10') self.assertDictEqual({ 'aaa': 12, 'b': -2.0e10, - 'c_c': 'relu4' + 'c_c': 'relu4', + 'd': '/a/b=c/d' }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(-2.0e10, hparams.b) self.assertEqual('relu4', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=,b=0,') - self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': ''}, hparams.values()) + self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'}, + hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(0.0, hparams.b) self.assertEqual('', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=2.3",b=+2,') self.assertEqual(2.0, hparams.b) self.assertEqual('2.3"', hparams.c_c) + hparams.parse('d=/a/b/c/d,aaa=11,') + self.assertEqual(11, hparams.aaa) + self.assertEqual(2.0, hparams.b) + self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a/b/c/d', hparams.d) + hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,') + self.assertEqual(10, hparams.aaa) + self.assertEqual(1.5, hparams.b) + self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a=b/c/d', hparams.d) with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'): hparams.parse('x=123') with self.assertRaisesRegexp(ValueError, 'Could not parse'): @@ -84,17 +104,19 @@ class HParamsTest(test.TestCase): hparams.parse('b=relu') with self.assertRaisesRegexp(ValueError, 'Must not pass a list'): hparams.parse('aaa=[123]') - self.assertEqual(12, hparams.aaa) - self.assertEqual(2.0, hparams.b) + self.assertEqual(10, hparams.aaa) + self.assertEqual(1.5, hparams.b) self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a=b/c/d', hparams.d) # Exports to proto. hparam_def = hparams.to_proto() # Imports from proto. hparams2 = hparam.HParams(hparam_def=hparam_def) # Verifies that all hparams are restored. - self.assertEqual(12, hparams2.aaa) - self.assertEqual(2.0, hparams2.b) + self.assertEqual(10, hparams2.aaa) + self.assertEqual(1.5, hparams2.b) self.assertEqual('2.3"', hparams2.c_c) + self.assertEqual('/a=b/c/d', hparams2.d) def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md index 58fed4e5cb4c24b0f21dfe9b99cf4c665d2591c7..4b6104a8b4d542b1d8a9cb3e48eeed4950d791cd 100644 --- a/tensorflow/contrib/verbs/README.md +++ b/tensorflow/contrib/verbs/README.md @@ -93,7 +93,7 @@ When the receiver receives the RDMA write, it will locate the relevant **RdmaTen 1. When the sender receives a tensor request, the source tensor may or may not be ready yet. The situation is handled through a process of tag matching: * If the request arrives before the tensor is ready, then a callback is put in a local table, and will be invoked once the tensor arrives. - * If the tensor is ready before the request arives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediately. + * If the tensor is ready before the request arrives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediately. In code it is done by calling **RecvLocalAsync()**, which receives the tensor's key, step-id, and the callback. 2. When the callback is invoked, the relevant tensor is removed from the tag matching table. In the case where we need to send the tensor's meta-data, the **RdmaTensorResponse** will store a copy of the tensor until the re-request arrives. 3. The sending of protocol messages (**RDMA_MESSAGE_TENSOR_REQUEST**, **RDMA_MESSAGE_META_DATA_RESPONSE** and **RDMA_MESSAGE_TENSOR_RE_REQUEST**) is done by the class **RdmaMessageBuffer**. All messages are sent using RDMA writes from/to fixed messages buffers. This implies that we cannot send on a specific channel more than one message at a time. In order to synchronize the messages, the **RdmaMessageBuffer** holds the a local and remote buffer statuses which can be either busy or idle. When a write is issued, both statuses will be changed to busy. When the write-complete event is received, the local status is changed to idle. When the write is received on the remote side, the remote side will parse the message, and return an ACK back to the sending side on which the sending side will update the remote status to idle. When both the local and remote statuses are idle, the next message can be sent. diff --git a/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md index 956b8f2147cf8154b6f1ade006d7bff194864c9b..da6fdd48e19e9d1503d1537926b1c464a0e77589 100644 --- a/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md +++ b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md @@ -64,7 +64,7 @@ The protocol messages themselves will remain mostly unchanged at the first stage * type - The message type. * request_index - Request index. * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data. -* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors. +* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-request after meta-data update and reallocation of result/proxy tensors. * type - The message type. * name (name_size) - Name of the requested tensor. * step_id - Step ID. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1893967cdd0034bcff52c84f4db0bf1e2e3334f4..1d11410332c76595fd1c3ac5e801c5c161570ca2 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -220,7 +220,6 @@ tf_proto_library( srcs = CORE_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, default_header = True, - go_api_version = 2, j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, @@ -314,6 +313,7 @@ cc_library( "lib/gtl/optional.h", "lib/gtl/priority_queue_util.h", "lib/hash/crc32c.h", + "lib/hash/hash.h", "lib/histogram/histogram.h", "lib/io/buffered_inputstream.h", "lib/io/compression.h", @@ -339,6 +339,7 @@ cc_library( "lib/strings/strcat.h", "lib/strings/stringprintf.h", "platform/abi.h", + "platform/context.h", "platform/cpu_feature_guard.h", "platform/cpu_info.h", "platform/dynamic_annotations.h", @@ -353,6 +354,7 @@ cc_library( "platform/mutex.h", "platform/net.h", "platform/notification.h", + "platform/null_file_system.h", "platform/prefetch.h", "platform/profile_utils/clock_cycle_profiler.h", "platform/profile_utils/cpu_utils.h", @@ -593,6 +595,7 @@ cc_library( "platform/prefetch.h", "platform/thread_annotations.h", "platform/types.h", + "platform/cpu_info.h", ] + if_windows(["platform/windows/integral_types.h"]), visibility = ["//visibility:public"], deps = @@ -632,6 +635,7 @@ tf_gen_op_libs( "random_ops", "remote_fused_graph_ops", "resource_variable_ops", + "scoped_allocator_ops", "sdca_ops", "set_ops", "script_ops", @@ -685,6 +689,34 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "cudnn_rnn_ops", + srcs = [ + "ops/cudnn_rnn_ops.cc", + ], + linkstatic = 1, + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:stream_executor", + "//tensorflow/core/kernels:bounds_check_lib", + "//third_party/eigen3", + "@farmhash_archive//:farmhash", + ], + alwayslink = 1, +) + +tf_gen_op_libs( + op_lib_names = [ + "cudnn_rnn_ops", + ], + deps = [ + ":lib", + ], +) + cc_library( name = "ops", visibility = ["//visibility:public"], @@ -697,6 +729,7 @@ cc_library( ":checkpoint_ops_op_lib", ":control_flow_ops_op_lib", ":ctc_ops_op_lib", + ":cudnn_rnn_ops_op_lib", ":data_flow_ops_op_lib", ":dataset_ops_op_lib", ":function_ops_op_lib", @@ -715,11 +748,13 @@ cc_library( ":random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", + ":scoped_allocator_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", ":set_ops_op_lib", ":sparse_ops_op_lib", + ":summary_ops_op_lib", ":spectral_ops_op_lib", ":state_ops_op_lib", ":stateless_random_ops_op_lib", @@ -835,6 +870,7 @@ cc_library( "//tensorflow/core/kernels:checkpoint_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:ctc_ops", + "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:fake_quant_ops", @@ -858,6 +894,7 @@ cc_library( "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", "//tensorflow/core/kernels:resource_variable_ops", + "//tensorflow/core/kernels:scoped_allocator_ops", "//tensorflow/core/kernels:sdca_ops", "//tensorflow/core/kernels:set_kernels", "//tensorflow/core/kernels:sparse", @@ -1034,6 +1071,7 @@ filegroup( "util/tensor_bundle/*.h", "util/tensor_bundle/*.cc", "common_runtime/gpu/**/*", + "common_runtime/eager/*", "common_runtime/gpu_device_factory.*", ], ), @@ -1059,6 +1097,7 @@ filegroup( "**/*testlib*", "**/*main.cc", "common_runtime/gpu/**/*", + "common_runtime/eager/*", "common_runtime/gpu_device_factory.*", "graph/dot.*", ], @@ -1402,6 +1441,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "protobuf/device_properties_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/device_properties.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "protobuf/meta_graph_pyclif", proto_lib = ":protos_all_cc", @@ -1410,9 +1456,9 @@ tf_pyclif_proto_library( ) tf_pyclif_proto_library( - name = "protobuf/device_properties_pyclif", + name = "protobuf/saved_model_pyclif", proto_lib = ":protos_all_cc", - proto_srcfile = "protobuf/device_properties.proto", + proto_srcfile = "protobuf/saved_model.proto", visibility = ["//visibility:public"], ) @@ -1518,6 +1564,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "lib/strings/base64.h", "lib/strings/ordered_code.h", "lib/strings/proto_text_util.h", + "lib/strings/proto_serialization.h", "lib/strings/scanner.h", "lib/wav/wav_io.h", "platform/demangle.h", @@ -1663,6 +1710,25 @@ cc_library( ], ) +cc_library( + name = "tflite_portable_logging", + srcs = [], + hdrs = [ + "lib/bfloat16/bfloat16.h", + "platform/default/integral_types.h", + "platform/default/logging.h", + "platform/logging.h", + "platform/macros.h", + "platform/platform.h", + "platform/types.h", + ], + copts = tf_copts(), + linkopts = ["-ldl"], + deps = [ + "//tensorflow/core/platform/default/build_config:logging", + ], +) + cc_library( name = "android_jpeg_internal", srcs = if_android([ @@ -1854,6 +1920,13 @@ cc_header_only_library( ], ) +cc_header_only_library( + name = "core_cpu_headers_lib", + deps = [ + ":core_cpu_lib", + ], +) + tf_cuda_library( name = "framework_internal_impl", srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + [ @@ -1919,7 +1992,7 @@ tf_cuda_library( ) + if_mkl( [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn//:mkl_dnn", + "@mkl_dnn", ], ), alwayslink = 1, @@ -1931,7 +2004,6 @@ cc_header_only_library( deps = [ ":framework", ":reader_base", - "@nsync//:nsync_headers", ], ) @@ -2038,14 +2110,19 @@ tf_cuda_library( CORE_CPU_BASE_HDRS = GRAPH_HDRS + [ "common_runtime/device.h", + "common_runtime/device_mgr.h", + "common_runtime/eval_const_tensor.h", "common_runtime/graph_runner.h", "common_runtime/shape_refiner.h", "framework/versions.h", + "common_runtime/process_function_library_runtime.h", + "common_runtime/function.h", ] tf_cuda_library( name = "core_cpu_base", srcs = [ + "common_runtime/eval_const_tensor.cc", "common_runtime/shape_refiner.cc", "common_runtime/shape_refiner.h", "framework/versions.h", @@ -2086,24 +2163,23 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/costmodel_manager.h", "common_runtime/debugger_state_interface.h", "common_runtime/device_factory.h", - "common_runtime/device_mgr.h", "common_runtime/device_set.h", "common_runtime/dma_helper.h", "common_runtime/eigen_thread_pool.h", "common_runtime/executor.h", - "common_runtime/function.h", "common_runtime/graph_optimizer.h", "common_runtime/local_device.h", "common_runtime/memory_types.h", "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", - "common_runtime/process_function_library_runtime.h", "common_runtime/process_util.h", "common_runtime/profile_handler.h", "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", + "common_runtime/scoped_allocator.h", + "common_runtime/scoped_allocator_mgr.h", "common_runtime/session_factory.h", "common_runtime/placer.h", "common_runtime/stats_publisher_interface.h", @@ -2134,6 +2210,7 @@ tf_cuda_library( "common_runtime/graph_runner.cc", "common_runtime/local_device.cc", "common_runtime/memory_types.cc", + "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", "common_runtime/placer.cc", @@ -2142,6 +2219,8 @@ tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/scoped_allocator.cc", + "common_runtime/scoped_allocator_mgr.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", "common_runtime/session_options.cc", @@ -2173,6 +2252,7 @@ tf_cuda_library( ] + if_mkl( [ "//third_party/mkl:intel_binary_blob", + "@mkl_dnn", ], ), alwayslink = 1, @@ -2217,14 +2297,12 @@ tf_cuda_library( ] + if_mkl( [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn//:mkl_dnn", + "@mkl_dnn", ], ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]), alwayslink = 1, ) -# This library is deprecated and no longer publicly available. -# Do not add more uses of it. cc_library( name = "regexp_internal", hdrs = [ @@ -2867,6 +2945,23 @@ tf_cc_tests( ], ) +tf_cc_test( + name = "cudnn_rnn_ops_test_cc", + size = "small", + srcs = [ + "ops/cudnn_rnn_ops_test.cc", + ], + deps = [ + ":cudnn_rnn_ops", + "//tensorflow/core", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test_mkl( name = "mkl_runtime_tests", size = "small", @@ -3127,6 +3222,7 @@ tf_cc_test( ":core_cpu", ":core_cpu_internal", ":framework", + ":lib", ":test", ":test_main", ":testlib", @@ -3175,6 +3271,7 @@ tf_cc_test( "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:fifo_queue_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:ops_util", @@ -3217,6 +3314,7 @@ tf_cc_test( "//tensorflow/core/kernels:fifo_queue_op", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:queue_ops", @@ -3291,6 +3389,43 @@ tf_cc_test( size = "small", srcs = ["common_runtime/function_test.cc"], linkstatic = tf_kernel_tests_linkstatic(), + tags = [ + "manual", + "no_oss", + ], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:shape_ops", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "common_runtime_function_threadpool_test", + size = "small", + srcs = ["common_runtime/function_threadpool_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":core", ":core_cpu", @@ -3319,6 +3454,21 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_scoped_allocator_mgr_test", + size = "small", + srcs = ["common_runtime/scoped_allocator_mgr_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":lib", + ":test", + ":test_main", + ], +) + tf_cc_test_gpu( name = "gpu_allocator_retry_test", size = "medium", @@ -3623,6 +3773,13 @@ filegroup( "lib/gif/testdata/optimized.gif", # BMP data "lib/bmp/testdata/lena.bmp", + # SSIM, PSNR data + "lib/ssim/testdata/checkerboard1.png", + "lib/ssim/testdata/checkerboard2.png", + "lib/ssim/testdata/checkerboard3.png", + "lib/psnr/testdata/cat_q20.jpg", + "lib/psnr/testdata/cat_q72.jpg", + "lib/psnr/testdata/cat_q95.jpg", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/api_def/base_api/api_def_CloseSummaryWriter.pbtxt b/tensorflow/core/api_def/base_api/api_def_CloseSummaryWriter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f6fd7d93169306fdf5ca62d27635e1f86f37bc4d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CloseSummaryWriter.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "CloseSummaryWriter" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_CreateSummaryDbWriter.pbtxt b/tensorflow/core/api_def/base_api/api_def_CreateSummaryDbWriter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..28da46a0f8e452f65d06a13c4b0d0b03b2a75757 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CreateSummaryDbWriter.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "CreateSummaryDbWriter" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_CreateSummaryFileWriter.pbtxt b/tensorflow/core/api_def/base_api/api_def_CreateSummaryFileWriter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2ce2c4d37e5001681ffa733bf4726c6bea652029 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CreateSummaryFileWriter.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "CreateSummaryFileWriter" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..daeb5fe9a223d7d1254725325921a28a7d165902 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "CudnnRNN" + summary: "A RNN backed by cuDNN." + description: <

diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt index 771cf0b591367e18f007e91bf66bc1cfd02ab459..8e99718c7e3751c1bf4ef4d03e558be3c0ada51e 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt @@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions divide. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7b52dad4a163643af659320f324ce6558fcffcd8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt @@ -0,0 +1,60 @@ +op { + graph_op_name: "ScatterMax" + in_arg { + name: "ref" + description: < + +
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..721ac0ff35f934583e227317515b0ba3298de747 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt @@ -0,0 +1,60 @@ +op { + graph_op_name: "ScatterMin" + in_arg { + name: "ref" + description: < + + +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt index a51f571b00d7fc68a24dbfc4a0104522f8c0f559..b9e293ba9efba10de9ccd774111899adf4342c90 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt @@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions multiply. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt index c0d3a4a1337ee1e1a32114adc51c930e014bc268..d12b3e68c25c22825349bf7affbb09de8fdf98ac 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt @@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their (negated) contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt index c44dbbd2332828242792d9cdd4a218e7457c7d2b..4804908afc61356db76391a4d425b0857c52412d 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt @@ -54,7 +54,7 @@ If values in `ref` is to be updated more than once, because there are duplicate entries in `indices`, the order at which the updates happen for each value is undefined. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SdcaOptimizer.pbtxt b/tensorflow/core/api_def/base_api/api_def_SdcaOptimizer.pbtxt index b0b58ac00e6709922ed517ad2c9efebbedf450a3..9da0e124ebe02f1cfb6450b96471d7d9d146bd20 100644 --- a/tensorflow/core/api_def/base_api/api_def_SdcaOptimizer.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SdcaOptimizer.pbtxt @@ -97,8 +97,11 @@ END } attr { name: "adaptative" + default_value { + b: True + } description: < ## Validate your installation @@ -505,7 +481,7 @@ If you installed through Docker, start a Docker container from which you can run bash. For example:
-$ docker run -it gcr.io/tensorflow/tensorflow bash
+$ docker run -it tensorflow/tensorflow bash
 
@@ -530,11 +506,18 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. - If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). +If you are new to machine learning, we recommend the following: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) +* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} + +If you are experienced with machine learning but new to TensorFlow, see +@{$get_started/premade_estimators$Getting Started with TensorFlow}. + + ## Common installation problems We are relying on Stack Overflow to document TensorFlow installation problems @@ -647,14 +630,14 @@ This section documents the relevant values for Linux installations. CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.7.0rc1-cp27-none-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.7.0rc1-cp27-none-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -666,14 +649,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.7.0rc1-cp34-cp34m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.7.0rc1-cp34-cp34m-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -685,14 +668,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.7.0rc1-cp35-cp35m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.7.0rc1-cp35-cp35m-linux_x86_64.whl
 
@@ -704,14 +687,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.7.0rc1-cp36-cp36m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.7.0rc1-cp36-cp36m-linux_x86_64.whl
 
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 5be38ae1ef9fe994e5da18449fd0b3dc7335146b..7060ef43da3e978a87250cacf916b4a792274a47 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv: TensorFlow in the active Virtualenv is as follows:
 $ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.7.0rc1-py3-none-any.whl If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -238,11 +238,11 @@ take the following steps: operating system and Python version. Find the appropriate value for tfBinaryURL [here](#the_url_of_the_tensorflow_python_package). For example, if - you are installing TensorFlow for Mac OS and Python 2.7 + you are installing TensorFlow for macOS and Python 2.7 issue the following command:
 $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl 
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.7.0rc1-py3-none-any.whl If the preceding command fails, see [installation problems](#common-installation-problems). @@ -292,24 +292,23 @@ where: to 6006. * TensorFlowImage is required. It identifies the Docker container. You must specify one of the following values: - * gcr.io/tensorflow/tensorflow: TensorFlow binary image. - * gcr.io/tensorflow/tensorflow:latest-devel: TensorFlow + * tensorflow/tensorflow: TensorFlow binary image. + * tensorflow/tensorflow:latest-devel: TensorFlow Binary image plus source code. -gcr.io is the Google Container Registry. Note that some -TensorFlow images are also available at +The TensorFlow images are available at [dockerhub](https://hub.docker.com/r/tensorflow/tensorflow/). For example, the following command launches a TensorFlow CPU binary image in a Docker container from which you can run TensorFlow programs in a shell: -
$ docker run -it gcr.io/tensorflow/tensorflow bash
+
$ docker run -it tensorflow/tensorflow bash
The following command also launches a TensorFlow CPU binary image in a Docker container. However, in this Docker container, you can run TensorFlow programs in a Jupyter notebook: -
$ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow
+
$ docker run -it -p 8888:8888 tensorflow/tensorflow
Docker will download the TensorFlow binary image the first time you launch it. @@ -351,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
 (targetDirectory)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.7.0rc1-py2-none-any.whl @@ -376,7 +375,7 @@ do the following: If you installed through Docker, start a Docker container that runs bash. For example: -
$ docker run -it gcr.io/tensorflow/tensorflow bash
+
$ docker run -it tensorflow/tensorflow bash
@@ -401,12 +400,18 @@ writing TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see -@{$get_started/premade_estimators$Getting Started with TensorFlow}. - If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). +If you are new to machine learning, we recommend the following: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) +* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} + +If you are experienced with machine learning but new to TensorFlow, see +@{$get_started/premade_estimators$Getting Started with TensorFlow}. + + ## Common installation problems We are relying on Stack Overflow to document TensorFlow installation problems @@ -513,18 +518,13 @@ RuntimeError: Broken toolchain: cannot link a simple C program ## The URL of the TensorFlow Python package A few installation mechanisms require the URL of the TensorFlow Python package. -The value you specify depends on three factors: - - * operating system - * Python version - -This section documents the relevant values for Mac OS installations. +The value you specify depends on your Python version. ### Python 2.7
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.7.0rc1-py2-none-any.whl
 
@@ -532,5 +532,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-a
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.7.0rc1-py3-none-any.whl
 
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 8d83e9f1190ed307ca99d81168df7dfab51e4507..148f80efe25f12cfaef9df8a8edfaa700782dacd 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -133,30 +133,21 @@ The following NVIDIA hardware must be installed on your system: The following NVIDIA software must be installed on your system: - * NVIDIA's Cuda Toolkit (>= 7.0). We recommend version 9.0. + * [CUDA Toolkit](http://nvidia.com/cuda) (>= 7.0). We recommend version 9.0. For details, see - [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/#axzz4VZnqTJ2A). - Ensure that you append the relevant Cuda pathnames to the + [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/). + Ensure that you append the relevant CUDA pathnames to the `LD_LIBRARY_PATH` environment variable as described in the NVIDIA documentation. - * The NVIDIA drivers associated with NVIDIA's Cuda Toolkit. - * cuDNN (>= v3). We recommend version 6.0. For details, see - [NVIDIA's documentation](https://developer.nvidia.com/cudnn), - particularly the description of appending the appropriate pathname - to your `LD_LIBRARY_PATH` environment variable. - -Finally, you must also install `libcupti` which for Cuda Toolkit >= 8.0 you do via - -
 $ sudo apt-get install cuda-command-line-tools 
- -and add its path to your `LD_LIBRARY_PATH` environment variable: - -
 $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64 
- -For Cuda Toolkit <= 7.5, you install `libcupti-dev` by invoking the following command: - -
 $ sudo apt-get install libcupti-dev 
+ * [GPU drivers](http://nvidia.com/driver) supporting your version of the CUDA + Toolkit. + * [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= v3). We recommend version 7.0. For details, see + [NVIDIA's documentation](http://docs.nvidia.com/deeplearning/sdk/cudnn-install/). + * [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but + you also need to append its path to the `LD_LIBRARY_PATH` environment + variable: +
 $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64 
### Next @@ -240,8 +231,8 @@ such as compiler flags. You must run this script *prior* to creating the pip package and installing TensorFlow. If you wish to build TensorFlow with GPU, `configure` will ask -you to specify the version numbers of Cuda and cuDNN. If several -versions of Cuda or cuDNN are installed on your system, explicitly select +you to specify the version numbers of CUDA and cuDNN. If several +versions of CUDA or cuDNN are installed on your system, explicitly select the desired version instead of relying on the default. One of the questions that `configure` will ask is as follows: @@ -289,12 +280,12 @@ Do you wish to build TensorFlow with CUDA support? [y/N] Y CUDA support will be enabled for TensorFlow Do you want to use clang as CUDA compiler? [y/N] nvcc will be used as CUDA compiler -Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]: 9.0 +Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]: 9.0 Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]: Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7 Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: -Please specify a list of comma-separated Cuda compute capabilities you want to build with. +Please specify a list of comma-separated CUDA compute capabilities you want to build with. You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. Please note that each additional compute capability significantly increases your build time and binary size. [Default is: "3.5,5.2"]: 3.0 @@ -304,14 +295,14 @@ Configuration finished If you told `configure` to build for GPU support, then `configure` -will create a canonical set of symbolic links to the Cuda libraries -on your system. Therefore, every time you change the Cuda library paths, +will create a canonical set of symbolic links to the CUDA libraries +on your system. Therefore, every time you change the CUDA library paths, you must rerun the `configure` script before re-invoking the bazel build command. Note the following: - * Although it is possible to build both Cuda and non-Cuda configs + * Although it is possible to build both CUDA and non-CUDA configs under the same source tree, we recommend running `bazel clean` when switching between these two configurations in the same source tree. * If you don't run the `configure` script *before* running the @@ -359,10 +350,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl` file depends on your platform. For example, the following command will install the pip package -for TensorFlow 1.6.0rc1 on Linux: +for TensorFlow 1.7.0rc1 on Linux:
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc1-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.7.0rc1-py2-none-any.whl
 
## Validate your installation @@ -393,8 +384,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with -TensorFlow}. +If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). @@ -460,6 +450,8 @@ Stack Overflow and specify the `tensorflow` tag. **Linux** + + @@ -479,6 +471,7 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.7.0rc1CPU2.7, 3.3-3.6GCC 4.8Bazel 0.10.0N/AN/A
tensorflow_gpu-1.7.0rc1GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.6.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.0N/AN/A
tensorflow_gpu-1.6.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.5.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.0N/AN/A
+ @@ -493,6 +486,8 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.7.0rc1CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
tensorflow-1.6.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.5.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.4.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.5.4N/AN/A
+ + diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index dedf485f93d6fd6a8ce7b4465548cc998d307daa..86add74da15005a56bf0fd88c775139cd030c243 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -17,7 +17,7 @@ You must choose one of the following types of TensorFlow to install: NVIDIA® GPU, you must install this version. Note that this version of TensorFlow is typically much easier to install (typically, in 5 or 10 minutes), so even if you have an NVIDIA GPU, we recommend - installing this version first. + installing this version first. Prebuilt binaries will use AVX instructions. * **TensorFlow with GPU support**. TensorFlow programs typically run significantly faster on a GPU than on a CPU. Therefore, if your system has a NVIDIA® GPU meeting the prerequisites shown below @@ -41,7 +41,8 @@ installed on your system: Note that cuDNN is typically installed in a different location from the other CUDA DLLs. Ensure that you add the directory where you installed the cuDNN DLL to your `%PATH%` environment variable. - * GPU card with CUDA Compute Capability 3.0 or higher. See + * GPU card with CUDA Compute Capability 3.0 or higher for building + from source and 3.5 or higher for our binaries. See [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of supported GPU cards. @@ -153,14 +154,17 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with -TensorFlow}. - If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -There is also a helpful [script](https://gist.github.com/mrry/ee5dbcfdd045fa48a27d56664411d41c) -for Windows TensorFlow installation issues. +If you are new to machine learning, we recommend the following: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) +* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} + +If you are experienced with machine learning but new to TensorFlow, see +@{$get_started/premade_estimators$Getting Started with TensorFlow}. + ## Common installation problems diff --git a/tensorflow/docs_src/mobile/android_build.md b/tensorflow/docs_src/mobile/android_build.md index b5a1d5d7d1bf3b456ab24165e273969bdbd7bfca..08a5fbe41c87c88399682208c38bf7a892d8fc1a 100644 --- a/tensorflow/docs_src/mobile/android_build.md +++ b/tensorflow/docs_src/mobile/android_build.md @@ -90,8 +90,8 @@ using [ADB](https://developer.android.com/studio/command-line/adb.html). This requires some knowledge of build systems and Android developer tools, but we'll guide you through the basics here. -- First, follow our instructions for @{$install/install_sources$installing from - sources}. This will also guide you through installing Bazel and cloning the +- First, follow our instructions for @{$install/install_sources$installing from sources}. + This will also guide you through installing Bazel and cloning the TensorFlow code. - Download the Android [SDK](https://developer.android.com/studio/index.html) diff --git a/tensorflow/docs_src/mobile/leftnav_files b/tensorflow/docs_src/mobile/leftnav_files index ac50f528ba468d8a830c059539d3399f413f39c8..4cf134cc3c2c323405d769a5ced5d5a68f188203 100644 --- a/tensorflow/docs_src/mobile/leftnav_files +++ b/tensorflow/docs_src/mobile/leftnav_files @@ -2,6 +2,7 @@ index.md ### TensorFlow Lite tflite/index.md tflite/demo_android.md +tflite/demo_ios.md >>> ### TensorFlow Mobile mobile_intro.md diff --git a/tensorflow/docs_src/mobile/optimizing.md b/tensorflow/docs_src/mobile/optimizing.md index 44cacff5dbbcb0685044c342184464b47a8ed090..778e4d3a6233c3bec70b830bc998013745a1f0ba 100644 --- a/tensorflow/docs_src/mobile/optimizing.md +++ b/tensorflow/docs_src/mobile/optimizing.md @@ -233,6 +233,8 @@ order by how long they took. From left to right, the columns are: - The cumulative total time of this and the previous ops in the table. This is handy for understanding what the distribution of work is across the layers, to see if just a few of the nodes are taking up most of the time. + +- The amount of memory consumed by outputs of this type of op. - Name of the node. @@ -290,8 +292,8 @@ run it on a 64-bit ARM device: You can interpret the results in exactly the same way as the desktop version above. If you have any trouble figuring out what the right input and output -names and types are, take a look at the @{$mobile/prepare_models$Preparing -models} page for details about detecting these for your model, and look at the +names and types are, take a look at the @{$mobile/prepare_models$Preparing models} +page for details about detecting these for your model, and look at the `summarize_graph` tool which may give you helpful information. diff --git a/tensorflow/docs_src/mobile/prepare_models.md b/tensorflow/docs_src/mobile/prepare_models.md index 360ee302aa96bc3a0b65eab7b39c3dacf56b42c0..8b22c04d872f18607c485775cb8f096f0a361995 100644 --- a/tensorflow/docs_src/mobile/prepare_models.md +++ b/tensorflow/docs_src/mobile/prepare_models.md @@ -60,7 +60,7 @@ and serialized as protocol buffers: the `NodeDef`, so if all the `Variable` weights are converted to `Const` nodes, then we only need a single `GraphDef` file to hold the model architecture and the weights. Freezing the graph handles the process of loading the - checkpoints, and then converts all Consts to Variables. You can then load the + checkpoints, and then converts all Variables to Consts. You can then load the resulting file in a single call, without having to restore variable values from checkpoints. One thing to watch out for with `GraphDef` files is that sometimes they’re stored in text format for easy inspection. These versions diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/docs_src/mobile/tflite/demo_android.md index 79b567897cb8a38ed2e27e73aa7e8fee95f718b8..c94b5597a673a7e68aed517b325b9719b3b73bbd 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_android.md +++ b/tensorflow/docs_src/mobile/tflite/demo_android.md @@ -8,6 +8,9 @@ You'll need an Android device running Android 5.0 or higher to run the demo. To get you started working with TensorFlow Lite on Android, we'll walk you through building and deploying our TensorFlow demo app in Android Studio. +Note: For a more detailed guide see the +[TFLite Codelab](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/index.html#0) + It's also possible to build the demo app with Bazel, but we only recommend this for advanced users who are very familiar with the Bazel build environment. For more information on that, see our page [on Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite#building-tensorflow-lite-and-the-demo-app-from-source). diff --git a/tensorflow/docs_src/mobile/tflite/demo_ios.md b/tensorflow/docs_src/mobile/tflite/demo_ios.md new file mode 100644 index 0000000000000000000000000000000000000000..3ee9b1cbca6cfef98616bd33bbf91b756b4efa15 --- /dev/null +++ b/tensorflow/docs_src/mobile/tflite/demo_ios.md @@ -0,0 +1,68 @@ +# TensorFlow Lite Demo for iOS + +The TensorFlow Lite demo is a camera app that continuously classifies whatever +it sees from your device's back camera, using a quantized MobileNet model. These +instructions walk you through building and running the demo on an iOS device. + +## Prerequisites + +* You must have [Xcode](https://developer.apple.com/xcode/) installed and have a + valid Apple Developer ID, and have an iOS device set up and linked to your + developer account with all of the appropriate certificates. For these + instructions, we assume that you have already been able to build and deploy an + app to an iOS device with your current developer environment. + +* The demo app requires a camera and must be executed on a real iOS device. You + can build it and run with the iPhone Simulator but it won't have any camera + information to classify. + +* You don't need to build the entire TensorFlow library to run the demo, but you + will need to clone the TensorFlow repository if you haven't already: + + git clone https://github.com/tensorflow/tensorflow + +* You'll also need the Xcode command-line tools: + + xcode-select --install + + If this is a new install, you will need to run the Xcode application once to + agree to the license before continuing. + +## Building the iOS Demo App + +1. Install CocoaPods if you don't have it: + + sudo gem install cocoapods + +2. Download the model files used by the demo app (this is done from inside the + cloned directory): + + sh tensorflow/contrib/lite/examples/ios/download_models.sh + +3. Install the pod to generate the workspace file: + + cd tensorflow/contrib/lite/examples/ios/camera + pod install + + If you have installed this pod before and that command doesn't work, try + + pod update + + At the end of this step you should have a file called + `tflite_camera_example.xcworkspace`. + +4. Open the project in Xcode by typing this on the command line: + + open tflite_camera_example.xcworkspace + + This launches Xcode if it isn't open already and opens the + `tflite_camera_example` project. + +5. Build and run the app in Xcode. + + Note that as mentioned earlier, you must already have a device set up and + linked to your Apple Developer account in order to deploy the app on a + device. + +You'll have to grant permissions for the app to use the device's camera. Point +the camera at various objects and enjoy seeing how the model classifies things! diff --git a/tensorflow/docs_src/performance/leftnav_files b/tensorflow/docs_src/performance/leftnav_files index 316f023f43dcfe781c7819d1681335267ddd5f76..d11a7e5d07c3e6cfa092e7ac11189ce6c272c1ad 100644 --- a/tensorflow/docs_src/performance/leftnav_files +++ b/tensorflow/docs_src/performance/leftnav_files @@ -2,6 +2,7 @@ performance_guide.md datasets_performance.md performance_models.md benchmarks.md +quantization.md ### XLA xla/index.md @@ -11,6 +12,3 @@ xla/jit.md xla/operation_semantics.md xla/shapes.md xla/tfcompile.md - -### Quantization -quantization.md diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index cd47fc2803bc1429d28bd0ae4c2ad68e632a6f03..580a899ac4e4f5c3d97ce023f25083168fe00d01 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -78,7 +78,7 @@ training CIFAR-10 illustrates the use of the `tf.data` API along with The `tf.data` API utilizes C++ multi-threading and has a much lower overhead than the Python-based `queue_runner` that is limited by Python's multi-threading performance. A detailed performance guide for the `tf.data` API can be found -[here](#datasets_performance). +[here](@{$datasets_performance}). While feeding data using a `feed_dict` offers a high level of flexibility, in general `feed_dict` does not provide a scalable solution. If only a single GPU diff --git a/tensorflow/docs_src/performance/quantization.md b/tensorflow/docs_src/performance/quantization.md index 544274cab68934419e8601a4d9714d80335fca28..411889cb1c616130f809e6228cc692ba3f951d48 100644 --- a/tensorflow/docs_src/performance/quantization.md +++ b/tensorflow/docs_src/performance/quantization.md @@ -1,226 +1,253 @@ -# How to Quantize Neural Networks with TensorFlow - -When modern neural networks were being developed, the biggest challenge was -getting them to work at all! That meant that accuracy and speed during training -were the top priorities. Using floating point arithmetic was the easiest way to -preserve accuracy, and GPUs were well-equipped to accelerate those calculations, -so it's natural that not much attention was paid to other numerical formats. - -These days, we actually have a lot of models being deployed in commercial -applications. The computation demands of training grow with the number of -researchers, but the cycles needed for inference expand in proportion to users. -That means pure inference efficiency has become a burning issue for a lot of -teams. - -That is where quantization comes in. It's an umbrella term that covers a lot of -different techniques to store numbers and perform calculations on them in more -compact formats than 32-bit floating point. I am going to focus on eight-bit -fixed point, for reasons I'll go into more detail on later. - -[TOC] - -## Why does Quantization Work? - -Training neural networks is done by applying many tiny nudges to the weights, -and these small increments typically need floating point precision to work -(though there are research efforts to use quantized representations here too). - -Taking a pre-trained model and running inference is very different. One of the -magical qualities of deep networks is that they tend to cope very well with high -levels of noise in their inputs. If you think about recognizing an object in a -photo you've just taken, the network has to ignore all the CCD noise, lighting -changes, and other non-essential differences between it and the training -examples it's seen before, and focus on the important similarities instead. This -ability means that they seem to treat low-precision calculations as just another -source of noise, and still produce accurate results even with numerical formats -that hold less information. - -## Why Quantize? - -Neural network models can take up a lot of space on disk, with the original -AlexNet being over 200 MB in float format for example. Almost all of that size -is taken up with the weights for the neural connections, since there are often -many millions of these in a single model. Because they're all slightly different -floating point numbers, simple compression formats like zip don't compress them -well. They are arranged in large layers though, and within each layer the -weights tend to be normally distributed within a certain range, for example -3.0 -to 6.0. - -The simplest motivation for quantization is to shrink file sizes by storing the -min and max for each layer, and then compressing each float value to an -eight-bit integer representing the closest real number in a linear set of 256 -within the range. For example with the -3.0 to 6.0 range, a 0 byte would -represent -3.0, a 255 would stand for 6.0, and 128 would represent about 1.5. -I'll go into the exact calculations later, since there's some subtleties, but -this means you can get the benefit of a file on disk that's shrunk by 75%, and -then convert back to float after loading so that your existing floating-point -code can work without any changes. - -Another reason to quantize is to reduce the computational resources you need to -do the inference calculations, by running them entirely with eight-bit inputs -and outputs. This is a lot more difficult since it requires changes everywhere -you do calculations, but offers a lot of potential rewards. Fetching eight-bit -values only requires 25% of the memory bandwidth of floats, so you'll make much -better use of caches and avoid bottlenecking on RAM access. You can also -typically use SIMD operations that do many more operations per clock cycle. In -some case you'll have a DSP chip available that can accelerate eight-bit -calculations too, which can offer a lot of advantages. - -Moving calculations over to eight bit will help you run your models faster, and -use less power (which is especially important on mobile devices). It also opens -the door to a lot of embedded systems that can't run floating point code -efficiently, so it can enable a lot of applications in the IoT world. - -## Why Not Train in Lower Precision Directly? - -There have been some experiments training at lower bit depths, but the results -seem to indicate that you need higher than eight bit to handle the back -propagation and gradients. That makes implementing the training more -complicated, and so starting with inference made sense. We also already have a -lot of float models already that we use and know well, so being able to convert -them directly is very convenient. - -## How Can You Quantize Your Models? - -TensorFlow has production-grade support for eight-bit calculations built in. It -also has a process for converting many models trained in floating-point over to -equivalent graphs using quantized calculations for inference. For example, -here's how you can translate the latest GoogLeNet model into a version that uses -eight-bit computations: - -```sh -curl -L "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz" | - tar -C tensorflow/examples/label_image/data -xz -bazel build tensorflow/tools/graph_transforms:transform_graph -bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ - --in_graph=tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb \ - --out_graph=/tmp/quantized_graph.pb \ - --inputs=input \ - --outputs=InceptionV3/Predictions/Reshape_1 \ - --transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,299,299,3") - remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) - fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes - strip_unused_nodes sort_by_execution_order' +# Fixed Point Quantization + +Quantization techniques store and calculate numbers in more compact formats. +[TensorFlow Lite](/mobile/tflite/) adds quantization that uses an 8-bit fixed +point representation. + +Since a challenge for modern neural networks is optimizing for high accuracy, the +priority has been improving accuracy and speed during training. Using floating +point arithmetic is an easy way to preserve accuracy and GPUs are designed to +accelerate these calculations. + +However, as more machine learning models are deployed to mobile devices, +inference efficiency has become a critical issue. Where the computational demand +for *training* grows with the amount of models trained on different +architectures, the computational demand for *inference* grows in proportion to +the amount of users. + +## Quantization benefits + + +Using 8-bit calculations help your models run faster and use less power. This is +especially important for mobile devices and embedded applications that can't run +floating point code efficiently, for example, Internet of Things (IoT) and +robotics devices. There are additional opportunities to extend this support to +more backends and research lower precision networks. + +### Smaller file sizes {: .hide-from-toc} + +Neural network models require a lot of space on disk. For example, the original +AlexNet requires over 200 MB for the float format—almost all of that for the +model's millions of weights. Because the weights are slightly different +floating point numbers, simple compression formats perform poorly (like zip). + +Weights fall in large layers of numerical values. For each layer, weights tend to +be normally distributed within a range. Quantization can shrink file sizes by +storing the minimum and maximum weight for each layer, then compress each +weight's float value to an 8-bit integer representing the closest real number in +a linear set of 256 within the range. + +### Faster inference {: .hide-from-toc} + +Since calculations are run entirely on 8-bit inputs and outputs, quantization +reduces the computational resources needed for inference calculations. This is +more involved, requiring changes to all floating point calculations, but results +in a large speed-up for inference time. + +### Memory efficiency {: .hide-from-toc} + +Since fetching 8-bit values only requires 25% of the memory bandwidth of floats, +more efficient caches avoid bottlenecks for RAM access. In many cases, the power +consumption for running a neural network is dominated by memory access. The +savings from using fixed-point 8-bit weights and activations are significant. + +Typically, SIMD operations are available that run more operations per clock +cycle. In some cases, a DSP chip is available that accelerates 8-bit calculations +resulting in a massive speedup. + +## Fixed point quantization techniques + +The goal is to use the same precision for weights and activations during both +training and inference. But an important difference is that training consists of +a forward pass and a backward pass, while inference only uses a forward pass. +When we train the model with quantization in the loop, we ensure that the forward +pass matches precision for both training and inference. + +To minimize the loss in accuracy for fully fixed point models (weights and +activations), train the model with quantization in the loop. This simulates +quantization in the forward pass of a model so weights tend towards values that +perform better during quantized inference. The backward pass uses quantized +weights and activations and models quantization as a straight through estimator. +(See Bengio et al., [2013](https://arxiv.org/abs/1308.3432)) + +Additionally, the minimum and maximum values for activations are determined +during training. This allows a model trained with quantization in the loop to be +converted to a fixed point inference model with little effort, eliminating the +need for a separate calibration step. + +## Quantization training with TensorFlow + +TensorFlow can train models with quantization in the loop. Because training +requires small gradient adjustments, floating point values are still used. To +keep models as floating point while adding the quantization error in the training +loop, @{$array_ops#Fake_quantization$fake quantization} nodes simulate the +effect of quantization in the forward and backward passes. + +Since it's difficult to add these fake quantization operations to all the +required locations in the model, there's a function available that rewrites the +training graph. To create a fake quantized training graph: + +``` +# Build forward pass of model. +loss = tf.losses.get_total_loss() + +# Call the training rewrite which rewrites the graph in-place with +# FakeQuantization nodes and folds batchnorm for training. It is +# often needed to fine tune a floating point model for quantization +# with this training tool. When training from scratch, quant_delay +# can be used to activate quantization after training to converge +# with the float graph, effectively fine-tuning the model. +tf.contrib.quantize.create_training_graph(quant_delay=2000000) + +# Call backward pass optimizer as usual. +optimizer = tf.train.GradientDescentOptimizer(learning_rate) +optimizer.minimize(loss) ``` -This will produce a new model that runs the same operations as the original, but -with eight bit calculations internally, and all weights quantized as well. If -you look at the file size, you'll see it's about a quarter of the original (23MB -versus 91MB). You can still run this model using exactly the same inputs and -outputs though, and you should get equivalent results. Here's an example: +The rewritten *eval graph* is non-trivially different from the *training graph* +since the quantization ops affect the batch normalization step. Because of this, +we've added a separate rewrite for the *eval graph*: -```sh -bazel build tensorflow/examples/label_image:label_image -bazel-bin/tensorflow/examples/label_image/label_image \ ---graph=/tmp/quantized_graph.pb \ +``` +# Build eval model +logits = tf.nn.softmax_cross_entropy_with_logits(...) + +# Call the eval rewrite which rewrites the graph in-place with +# FakeQuantization nodes and fold batchnorm for eval. +tf.contrib.quantize.create_eval_graph() + +# Save the checkpoint and eval graph proto to disk for freezing +# and providing to TFLite. +with open(eval_graph_file, ‘w’) as f: + f.write(str(g.as_graph_def())) +saver = tf.train.Saver() +saver.save(sess, checkpoint_name) +``` + +Methods to rewrite the training and eval graphs are an active area of research +and experimentation. Although rewrites and quantized training might not work or +improve performance for all models, we are working to generalize these +techniques. + +## Generating fully quantized models + +The previously demonstrated after-rewrite eval graph only *simulates* +quantization. To generate real fixed point computations from a trained +quantization model, convert it to a fixed point kernel. Tensorflow Lite supports +this conversion from the graph resulting from `create_eval_graph`. + +First, create a frozen graph that will be the input for the TensorFlow Lite +toolchain: + +``` +bazel build tensorflow/python/tools:freeze_graph && \ + bazel-bin/tensorflow/python/tools/freeze_graph \ + --input_graph=eval_graph_def.pb \ + --input_checkpoint=checkpoint \ + --output_graph=frozen_eval_graph.pb --output_node_names=outputs ``` -You'll see that this runs the newly-quantized graph, and outputs a very similar -answer to the original. - -You can run the same process on your own models saved out as GraphDefs, with the -input and output names adapted to those your network requires. I recommend that -you run them through the freeze_graph script first, to convert checkpoints into -constants stored in the file. - -## How Does the Quantization Process Work? - -We've implemented quantization by writing equivalent eight-bit versions of -operations that are commonly used during inference. These include convolution, -matrix multiplication, activation functions, pooling operations and -concatenation. The conversion script first replaces all the individual ops it -knows about with quantized equivalents. These are small sub-graphs that have -conversion functions before and after to move the data between float and -eight-bit. Below is an example of what they look like. First here's the original -Relu operation, with float inputs and outputs: - -![Relu Diagram](https://www.tensorflow.org/images/quantization0.png) - -Then, this is the equivalent converted subgraph, still with float inputs and -outputs, but with internal conversions so the calculations are done in eight -bit. - -![Converted Diagram](https://www.tensorflow.org/images/quantization1.png) - -The min and max operations actually look at the values in the input float -tensor, and then feeds them into the Dequantize operation that converts the -tensor into eight-bits. There are more details on how the quantized representation -works later on. - -Once the individual operations have been converted, the next stage is to remove -unnecessary conversions to and from float. If there are consecutive sequences of -operations that all have float equivalents, then there will be a lot of adjacent -Dequantize/Quantize ops. This stage spots that pattern, recognizes that they -cancel each other out, and removes them, like this: - -![Stripping Diagram](https://www.tensorflow.org/images/quantization2.png) - -Applied on a large scale to models where all of the operations have quantized -equivalents, this gives a graph where all of the tensor calculations are done in -eight bit, without having to convert to float. - -## What Representation is Used for Quantized Tensors? - -We approach converting floating-point arrays of numbers into eight-bit -representations as a compression problem. We know that the weights and -activation tensors in trained neural network models tend to have values that are -distributed across comparatively small ranges (for example you might have -15 to -+15 for weights, -500 to 1000 for activations on an image model, though the -exact numbers will vary). We also know from experiment that neural nets tend to -be very robust in the face of noise, and so the noise-like error produced by -quantizing down to a small set of values will not hurt the precision of the -overall results very much. We also want to pick a representation that's easy to -perform calculations on, especially the large matrix multiplications that form -the bulk of the work that's needed to run a model. - -These led us to pick a representation that has two floats to store the overall -minimum and maximum values that are represented by the lowest and highest -quantized value. Each entry in the quantized array represents a float value in -that range, distributed linearly between the minimum and maximum. For example, -if we have minimum = -10.0, and maximum = 30.0f, and an eight-bit array, here's -what the quantized values represent: +Provide this to the TensorFlow Lite Optimizing Converter (TOCO) to get a fully +quantized TensorFLow Lite model: ``` -Quantized | Float ---------- | ----- -0 | -10.0 -255 | 30.0 -128 | 10.0 +bazel build tensorflow/contrib/lite/toco:toco && \ + ./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ + --input_file=frozen_eval_graph.pb \ + --output_file=tflite_model.tflite \ + --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ + --inference_type=QUANTIZED_UINT8 \ + --input_shape="1,224, 224,3" \ + --input_array=input \ + --output_array=outputs \ + --std_value=127.5 --mean_value=127.5 ``` -The advantages of this format are that it can represent arbitrary magnitudes of -ranges, they don't have to be symmetrical, it can represent signed and unsigned -values, and the linear spread makes doing multiplications straightforward. There -are alternatives like [Song Han's code books](http://arxiv.org/pdf/1510.00149.pdf) -that can use lower bit depths by non-linearly distributing the float values -across the representation, but these tend to be more expensive to calculate on. - -The advantage of having a strong and clear definition of the quantized format is -that it's always possible to convert back and forth from float for operations -that aren't quantization-ready, or to inspect the tensors for debugging -purposes. One implementation detail in TensorFlow that we're hoping to improve -in the future is that the minimum and maximum float values need to be passed as -separate tensors to the one holding the quantized values, so graphs can get a -bit dense! - -The nice thing about the minimum and maximum ranges is that they can often be -pre-calculated. Weight parameters are constants known at load time, so their -ranges can also be stored as constants. We often know the ranges for inputs (for -examples images are usually RGB values in the range 0.0 to 255.0), and many -activation functions have known ranges too. This can avoid having to analyze the -outputs of an operation to determine the range, which we need to do for math ops -like convolution or matrix multiplication which produce 32-bit accumulated -results from 8-bit inputs. - -## What's Next? - -We've found that we can get extremely good performance on mobile and embedded -devices by using eight-bit arithmetic rather than floating-point. You can see -the framework we use to optimize matrix multiplications at -[gemmlowp](https://github.com/google/gemmlowp). We still need to apply all the -lessons we've learned to the TensorFlow ops to get maximum performance on -mobile, but we're actively working on that. Right now, this quantized -implementation is a reasonably fast and accurate reference implementation that -we're hoping will enable wider support for our eight-bit models on a wider -variety of devices. We also hope that this demonstration will encourage the -community to explore what's possible with low-precision neural networks. +See the documentation for @{tf.contrib.quantize} and +[TensorFlow Lite](/mobile/tflite/). + +## Quantized accuracy + +Fixed point [MobileNet](https://arxiv.org/abs/1704.0486) models are released with +8-bit weights and activations. Using the rewriters, these models achieve the +Top-1 accuracies listed in Table 1. For comparison, the floating point accuracies +are listed for the same models. The code used to generate these models +[is available](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md) +along with links to all of the pretrained mobilenet_v1 models. + +
+
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.7.0rc1CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.7.0rc1GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.6.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.6.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.5.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
+ + + + + + + + + + + + + + + + + + + + + + +
Image SizeDepthTop-1 Accuracy:
Floating point
Top-1 Accuracy:
Fixed point: 8 bit weights and activations
1280.250.4150.399
1280.50.5630.549
1280.750.6210.598
12810.6520.64
1600.250.4550.435
1600.50.5910.577
1600.750.6530.639
16010.680.673
1920.250.4770.458
1920.50.6170.604
1920.750.6720.662
19210.70.69
2240.250.4980.482
2240.50.6330.622
2240.750.6840.679
22410.7090.697
+
+ Table 1: MobileNet Top-1 accuracy on Imagenet Validation dataset. +
+ + +## Representation for quantized tensors + +TensorFlow approaches the conversion of floating-point arrays of numbers into +8-bit representations as a compression problem. Since the weights and activation +tensors in trained neural network models tend to have values that are distributed +across comparatively small ranges (for example, -15 to +15 for weights or -500 to +1000 for image model activations). And since neural nets tend to be robust +handling noise, the error introduced by quantizing to a small set of values +maintains the precision of the overall results within an acceptable threshold. A +chosen representation must perform fast calculations, especially the large matrix +multiplications that comprise the bulk of the computations while running a model. + +This is represented with two floats that store the overall minimum and maximum +values corresponding to the lowest and highest quantized value. Each entry in the +quantized array represents a float value in that range, distributed linearly +between the minimum and maximum. For example, with a minimum of -10.0 and maximum +of 30.0f, and an 8-bit array, the quantized values represent the following: + +
+ + + + + +
QuantizedFloat
0-10.0
25530.0
12810.0
+
+ Table 2: Example quantized value range +
+
+ +The advantages of this representation format are: + +* It efficiently represents an arbitrary magnitude of ranges. +* The values don't have to be symmetrical. +* The format represents both signed and unsigned values. +* The linear spread makes multiplications straightforward. + +Alternative techniques use lower bit depths by non-linearly distributing the +float values across the representation, but currently are more expensive in terms +of computation time. (See Han et al., +[2016](https://arxiv.org/abs/1510.00149).) + +The advantage of having a clear definition of the quantized format is that it's +always possible to convert back and forth from fixed-point to floating-point for +operations that aren't quantization-ready, or to inspect the tensors for +debugging. diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md index d4dc3e57c8fb5ec2a979b6ba7ebe2a3b6c3a5f94..d9a979ccbd31773b9d227ff946486706844a8f81 100644 --- a/tensorflow/docs_src/performance/xla/jit.md +++ b/tensorflow/docs_src/performance/xla/jit.md @@ -157,7 +157,7 @@ to fuse Ops is visible by starting at `hlo_graph_0.dot` and viewing each diagram in succession. To Render the .dot file into a png, install -[GraphViz](http://www.graphviz.org/Download..php) and run: +[GraphViz](https://www.graphviz.org/download/) and run: ```shell dot -Tpng hlo_graph_80.dot -o hlo_graph_80.png diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index b0abf5fdd2e0d8c3c20ae4bcd8f185124028df04..5e39e710a0dba74dfd68a04367ce402362520590 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -45,27 +45,30 @@ feature dimension in `operand`), the operation calculates the gradients with respect to `operand`, `offset` and `scale` across all the other dimensions. The `feature_index` must be a valid index for the feature dimension in `operand`. -The three gradients are defined by the following formulas: +The three gradients are defined by the following formulas (Assuming a +4-dimensional tensor as `operand` and (l) is the index for feature dimension): -\\( \nabla x = \nabla y * \gamma * \sqrt{\sigma^2+\epsilon} \\) +\\( coef_l = \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (\nabla y_{ijkl} * (x_{ijkl} - \mu_l) / (\sigma^2_{l}+\epsilon)) \\) -\\( \nabla \gamma = sum(\nabla y * (x - \mu) * \sqrt{\sigma^2 + \epsilon}) \\) +\\( \nabla x_{ijkl} = \gamma_{l} * (1/\sqrt{\sigma^2_{l}+\epsilon}) * [\nabla y_{ijkl} - mean(\nabla y) - (x_{ijkl} - \mu_{l}) * coef_l] \\) -\\( \nabla \beta = sum(\nabla y) \\) +\\( \nabla \beta_l = \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\) + +\\( \nabla \gamma_l = \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} * ((x_{ijkl} - \mu_l) / \sqrt{\sigma^2_{l}+\epsilon}) \\) The inputs `mean` and `variance` represents moments value across batch and spatial dimensions. The output type is a tuple of three handles: -|Outputs | Type | Semantics | -|------------- | ----------------------- | ------------------------------------| -|`grad_operand`| `ComputationDataHandle` | gradient with respect to input | -: : : `operand` : -|`grad_scale` | `ComputationDataHandle` | gradient with respect to input | -: : : `scale` : -|`grad_offset` | `ComputationDataHandle` | gradient with respect to input | -: : : `offset` : +|Outputs | Type | Semantics | +|------------- | ----------------------- | ------------------------------------ | +|`grad_operand`| `ComputationDataHandle` | gradient with respect to input | +: : : `operand` (\\( \nabla x\\)) : +|`grad_scale` | `ComputationDataHandle` | gradient with respect to input | +: : : `scale` (\\( \nabla \gamma\\)) : +|`grad_offset` | `ComputationDataHandle` | gradient with respect to input | +: : : `offset`(\\( \nabla \beta\\)) : ## BatchNormInference @@ -119,7 +122,7 @@ Normalizes an array across batch and spatial dimensions. | Arguments | Type | Semantics | | --------------- | ----------------------- | -------------------------------- | | `operand` | `ComputationDataHandle` | n dimensional array to be | -: : : normalized : +: : : normalized (x) : | `scale` | `ComputationDataHandle` | 1 dimensional array | : : : (\\(\gamma\\)) : | `offset` | `ComputationDataHandle` | 1 dimensional array | @@ -254,7 +257,7 @@ the range between the minimum and maximum, else returns the minimum value if the operand is below this range or the maximum value if the operand is above this range. That is, `clamp(a, x, b) = min(max(a, x), b)`. -All three arrays must be the same shape. Alternately, as a restricted form of +All three arrays must be the same shape. Alternatively, as a restricted form of [broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`. Example with scalar `min` and `max`: @@ -1050,6 +1053,9 @@ For a more intuitive description, see the "Informal Description" section below. : : : indices of the slices we're : : : : we're stitching together into : : : : the output tensor. : +|`index_vector_dim` | `int64` | The dimension in | +: : : `gather_indices` that contains : +: : : the starting indices. : |`output_window_dims` | `ArraySlice` | The set of dimensions in the | : : : output shape that are _window : : : : dimensions_ (defined below). : @@ -1066,22 +1072,20 @@ For a more intuitive description, see the "Informal Description" section below. : : : `output_window_dims`) and the window : : : : dimensions that are elided (via : : : : `elided_window_dims`). : -|`gather_dims_to_operand_dims` | `ArraySlice` | A dimension map (the | +|`gather_dims_to_operand_dims` | `ArraySlice` | A dimension map (the | : : : array is interpreted as mapping `i` to : : : : `gather_dims_to_operand_dims[i]`) from : : : : the gather indices in `gather_indices` to : : : : the operand index space. It has to be : : : : one-to-one and total. : -If `gather_indices` is a vector with `N` elements then we implicitly reshape it -to a tensor of shape `[N,1]` before proceeding. - For every index `Out` in the output tensor, we compute two things (more precisely described later): - - An index into the first `gather_indices.rank` - `1` dimensions of - `gather_indices`, which gives us a starting index of a slice, _operand - slice_, in the operand tensor. + - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`, + which gives us a starting index of a slice, _operand slice_, in the operand + tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions + in `gather_indices` except `index_vector_dim`. - A _window index_ that has the same rank as the operand. This index is composed of the values in `Out` at dimensions `output_window_dims`, embedded @@ -1093,29 +1097,42 @@ should be present in the output at index `Out`. The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank` - `1`. Additionally, as a shorthand, we define `output_gather_dims` of type `ArraySlice` as the set of dimensions in the output shape but not in -`output_window_dims`, in ascending order. E.g. if the output tensor has rank 5, -`output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, `3`} +`output_window_dims`, in ascending order. E.g. if the output tensor has rank +`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, +`3`} + +If `index_vector_dim` is equal to `gather_indices.rank` we implicitly +consider `gather_indices` to have a trailing `1` dimension (i.e. if +`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then +we implicitly consider the shape of `gather_indices` to be `[6,7,1]`). The bounds for the output tensor along dimension `i` is computed as follows: 1. If `i` is present in `output_gather_dims` (i.e. is equal to - `output_gather_dims[k]` for some `k`) then we pick the corresponding - dimension bounds out of `gather_indices.shape` (i.e. pick - `gather_indices.shape.dims[k]`). + `output_gather_dims[k]` for some `k`) then we pick the corresponding + dimension bounds out of `gather_indices.shape`, skipping + `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k` + < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`] + otherwise). 2. If `i` is present in `output_window_dims` (i.e. equal to - `output_window_dims[k]` for some `k`) then we pick the corresponding bound - out of `window_bounds` after accounting for `elided_window_dims` (i.e. we - pick `adjusted_window_bounds[k]` where `adjusted_window_bounds` is - `window_bounds` with the bounds at indices `elided_window_dims` removed). + `output_window_dims`[`k`] for some `k`) then we pick the corresponding + bound out of `window_bounds` after accounting for `elided_window_dims` + (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds` + is `window_bounds` with the bounds at indices `elided_window_dims` + removed). The operand index `In` corresponding to an output index `Out` is computed as follows: 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice - out vector `S` such that `S`[`i`] = `gather_indices`[`G`, `i`]. - 2. Create an index, `S``in`, into `operand` using `S` by scattering - `S` using the `gather_dims_to_operand_dims` map (`S``in` is the - starting indices for _operand slice_ mentioned above.). More precisely: + out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)] + where Combine(A, b) inserts b at position `index_vector_dim` into A. + Note that this is well defined even if `G` is empty -- if `G` is empty then + `S` = `gather_indices`. + 2. Create an index, `S``in`, into `operand` using `S` by + scattering `S` using the `gather_dims_to_operand_dims` map + (`S``in` is the starting indices for _operand slice_ mentioned + above). More precisely: 1. `S``in`[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` < `gather_dims_to_operand_dims.size`. 2. `S``in`[`_`] = `0` otherwise. @@ -1136,7 +1153,12 @@ follows: `operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. -### Informal Description +### Informal Description and Examples + +`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the +examples that follow. More interesting values for `index_vector_dim` +does not change the operation fundamentally, but makes the visual representation +more cumbersome. To get an intuition on how all of the above fits together, let's look at an example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md index d19200e80cdfe6620789ddd273647660c10b2a60..9ccdbde627e6b2415835f7c0771eca1afa04f7f8 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/programmers_guide/datasets.md @@ -18,11 +18,11 @@ The `tf.data` API introduces two new abstractions to TensorFlow: tensors representing the image data and a label. There are two distinct ways to create a dataset: - * Creating a **source** (e.g. `Dataset.from_tensor_slices()`) constructs a + * Creating a **source** (e.g. `Dataset.from_tensor_slices()`) constructs a dataset from one or more `tf.Tensor` objects. - * Applying a **transformation** (e.g. `Dataset.batch()`) constructs a dataset + * Applying a **transformation** (e.g. `Dataset.batch()`) constructs a dataset from one or more `tf.data.Dataset` objects. * A `tf.data.Iterator` provides the main way to extract elements from a @@ -327,6 +327,35 @@ same op/node (created by `Iterator.get_next()`). Therefore, evaluating *any* of these tensors will advance the iterator for all components. A typical consumer of an iterator will include all components in a single expression. +### Saving iterator state + +The @{tf.contrib.data.make_saveable_from_iterator} function creates a +`SaveableObject` from an iterator, which can be used to save and +restore the current state of the iterator (and, effectively, the whole input +pipeline). A saveable object thus created can be added to @{tf.train.Saver} +variables list or the `tf.GraphKeys.SAVEABLE_OBJECTS` collection for saving and +restoring in the same manner as a @{tf.Variable}. Refer to +@{$saved_model$Saving and Restoring} for details on how to save and restore +variables. + +```python +# Create saveable object from iterator. +saveable = tf.contrib.data.make_saveable_from_iterator(iterator) + +# Save the iterator state by adding it to the saveable objects collection. +tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable) +saver = tf.train.Saver() + +with tf.Session() as sess: + + if should_checkpoint: + saver.save(path_to_checkpoint) + +# Restore the iterator state. +with tf.Session() as sess: + saver.restore(sess, path_to_checkpoint) +``` + ## Reading input data ### Consuming NumPy arrays diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index c8fdae6f60c33776b6d9a8c1a33666ce4ddb1cb2..d1cd7e7c06e525abd9fadf24d5e706780bb316fc 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -23,8 +23,13 @@ debuggers such as Python's `pdb` due to TensorFlow's computation-graph paradigm. > installed using `pip install .whl`, however curses on Windows > may not work as reliably as curses on Linux or Mac. -This tutorial demonstrates how to use the **tfdbg** command-line interface -(CLI) to debug the appearance of [`nan`s](https://en.wikipedia.org/wiki/NaN) +> NOTE: This guide focuses on the command-line interface (CLI) of tfdbg. For +> guide on how to use the graphical user interface (GUI) of tfdbg, i.e., the +> **TensorBoard Debugger Plugin**, please visit +> [its README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). + +This tutorial demonstrates how to use the **tfdbg** CLI to debug the appearance +of [`nan`s](https://en.wikipedia.org/wiki/NaN) and [`inf`s](https://en.wikipedia.org/wiki/Infinity), a frequently-encountered type of bug in TensorFlow model development. The following example is for users who use the low-level @@ -150,6 +155,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at | | `-n ` | List dumped tensors with names matching given regular-expression pattern. | `lt -n Softmax.*` | | | `-t ` | List dumped tensors with op types matching given regular-expression pattern. | `lt -t MatMul` | | | `-f ` | List only the tensors that pass a registered tensor filter. | `lt -f has_inf_or_nan` | +| | `-f -fenn ` | List only the tensors that pass a registered tensor filter, excluding nodes with names matching the regular expression. | `lt -f has_inf_or_nan` `-fenn .*Sqrt.*` | | | `-s ` | Sort the output by given `sort_key`, whose possible values are `timestamp` (default), `dump_size`, `op_type` and `tensor_name`. | `lt -s dump_size` | | | `-r` | Sort in reverse order. | `lt -r -s dump_size` | | **`pt`** | | **Print value of a dumped tensor.** | | @@ -195,6 +201,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at | | `-n` | Execute through the next `Session.run` without debugging, and drop to CLI right before the run after that. | `run -n` | | | `-t ` | Execute `Session.run` `T - 1` times without debugging, followed by a run with debugging. Then drop to CLI right after the debugged run. | `run -t 10` | | | `-f ` | Continue executing `Session.run` until any intermediate tensor triggers the specified Tensor filter (causes the filter to return `True`). | `run -f has_inf_or_nan` | +| | `-f -fenn ` | Continue executing `Session.run` until any intermediate tensor whose node names doesn't match the regular expression triggers the specified Tensor filter (causes the filter to return `True`). | `run -f has_inf_or_nan -fenn .*Sqrt.*` | | | `--node_name_filter ` | Execute the next `Session.run`, watching only nodes with names matching the given regular-expression pattern. | `run --node_name_filter Softmax.*` | | | `--op_type_filter ` | Execute the next `Session.run`, watching only nodes with op types matching the given regular-expression pattern. | `run --op_type_filter Variable.*` | | | `--tensor_dtype_filter ` | Execute the next `Session.run`, dumping only Tensors with data types (`dtype`s) matching the given regular-expression pattern. | `run --tensor_dtype_filter int.*` | @@ -454,7 +461,7 @@ accuracy_score = classifier.evaluate(x=test_set.data, [debug_tflearn_iris.py](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py), -based on {$tflearn$tf-learn's iris tutorial}, contains a full example of how to +based on [tf-learn's iris tutorial](https://www.tensorflow.org/versions/r1.2/get_started/tflearn), contains a full example of how to use the tfdbg with `Estimator`s. To run this example, do: ```none @@ -748,6 +755,7 @@ There are three possible workarounds or solutions: # For LocalCLIDebugHook hooks = [tf_debug.LocalCLIDebugHook(dump_root="/with/lots/of/space")] ``` + Make sure that the directory pointed to by dump_root is empty or nonexistent. tfdbg cleans up the dump directories before exiting. * Reduce the batch size used during the runs. @@ -806,3 +814,27 @@ sess.run(b) the constant-folding would not occur and `tfdbg` should show the intermediate tensor dumps. + + +**Q**: I am debugging a model that generates unwanted infinities or NaNs. But + there are some nodes in my model that are known to generate infinities + or NaNs in their output tensors even under completely normal conditions. + How can I skip those nodes during my `run -f has_inf_or_nan` actions? + +**A**: Use the `--filter_exclude_node_names` (`-fenn` for short) flag. For + example, if you known you have a node with name matching the regular + expression `.*Sqrt.*` that generates infinities or NaNs regardless + of whether the model is behaving correctly, you can exclude the nodes + from the infinity/NaN-finding runs with the command + `run -f has_inf_or_nan -fenn .*Sqrt.*`. + + +**Q**: Is there a GUI for tfdbg? + +**A**: Yes, the **TensorBoard Debugger Plugin** is the GUI of tfdbg. + It offers features such as inspection of the computation graph, + real-time visualization of tensor values, continuation to tensor + and conditional breakpoints, and tying tensors to their + graph-construction source code, all in the browser environment. + To get started, please visit + [its README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md index e8027fc12b368ddfbc51cc47441478901d7caec7..d5703e07375b1f68f4e22476288f1ed57d340c5b 100644 --- a/tensorflow/docs_src/programmers_guide/embedding.md +++ b/tensorflow/docs_src/programmers_guide/embedding.md @@ -7,6 +7,9 @@ with the TensorBoard Embedding Projector newcomers to machine learning or TensorFlow, and the Embedding Projector how-to is for users at all levels. +An alternative tutorial on these concepts is available in the +[Embeddings section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/embeddings/video-lecture). + [TOC] An **embedding** is a mapping from discrete objects, such as words, to vectors diff --git a/tensorflow/docs_src/programmers_guide/faq.md b/tensorflow/docs_src/programmers_guide/faq.md index 70931f2862de98cb1e934f85919d558a3b36304a..392ac6f7f12532c3efce5bec1917691f55c7bee5 100644 --- a/tensorflow/docs_src/programmers_guide/faq.md +++ b/tensorflow/docs_src/programmers_guide/faq.md @@ -159,8 +159,7 @@ available. These operations allow you to build sophisticated @{$reading_data$input pipelines}, at the cost of making the TensorFlow computation somewhat more complicated. See the how-to documentation for -@{$reading_data#creating-threads-to-prefetch-using-queuerunner-objects$using -`QueueRunner` objects to drive queues and readers} +@{$reading_data#creating_threads_to_prefetch_using_queuerunner_objects$using `QueueRunner` objects to drive queues and readers} for more information on how to use them. ## Variables @@ -273,7 +272,7 @@ Prefer predefined TensorFlow operations such as @{tf.decode_raw}, If your data is not easily parsable with the built-in TensorFlow operations, consider converting it, offline, to a format that is easily parsable, such -as ${tf.python_io.TFRecordWriter$`TFRecord`} format. +as @{tf.python_io.TFRecordWriter$`TFRecord`} format. The more efficient method to customize the parsing behavior is to @{$adding_an_op$add a new op written in C++} that parses your diff --git a/tensorflow/docs_src/programmers_guide/graphs.md b/tensorflow/docs_src/programmers_guide/graphs.md index 9049a5a9f3d44e255188c6c41cdb12a619464379..e69b717432e6a8fab0085eb419dcbc0991cd9d28 100644 --- a/tensorflow/docs_src/programmers_guide/graphs.md +++ b/tensorflow/docs_src/programmers_guide/graphs.md @@ -210,9 +210,8 @@ with tf.device("/device:GPU:0"): # Operations created in this context will be pinned to the GPU. result = tf.matmul(weights, img) ``` - -If you are deploying TensorFlow in a @{$deploy/distributed$typical distributed -configuration}, you might specify the job name and task ID to place variables on +If you are deploying TensorFlow in a @{$deploy/distributed$typical distributed configuration}, +you might specify the job name and task ID to place variables on a task in the parameter server job (`"/job:ps"`), and the other operations on task in the worker job (`"/job:worker"`): @@ -336,20 +335,20 @@ described below. controls the behavior of the session. For example, some of the configuration options include: - * `allow_soft_placement`. Set this to `True` to enable a "soft" device + * `allow_soft_placement`. Set this to `True` to enable a "soft" device placement algorithm, which ignores @{tf.device} annotations that attempt to place CPU-only operations on a GPU device, and places them on the CPU instead. - * `cluster_def`. When using distributed TensorFlow, this option allows you + * `cluster_def`. When using distributed TensorFlow, this option allows you to specify what machines to use in the computation, and provide a mapping between job names, task indices, and network addresses. See @{tf.train.ClusterSpec.as_cluster_def} for details. - * `graph_options.optimizer_options`. Provides control over the optimizations + * `graph_options.optimizer_options`. Provides control over the optimizations that TensorFlow performs on your graph before executing it. - * `gpu_options.allow_growth`. Set this to `True` to change the GPU memory + * `gpu_options.allow_growth`. Set this to `True` to change the GPU memory allocator so that it gradually increases the amount of memory allocated, rather than allocating most of the memory at startup. diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index 7a5e90081d9145ca934929f0af11f2a40cb2dcae..e8c2fa6990c8ecfca1cfe76b3f813b4ae6917742 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -30,8 +30,12 @@ works. The units are as follows: can still be helpful. * @{$programmers_guide/saved_model}, which explains how to save and restore variables and models. + +## Accelerators + * @{$using_gpu} explains how TensorFlow assigns operations to devices and how you can change the arrangement manually. + * @{$using_tpu} explains how to modify `Estimator` programs to run on a TPU. ## ML Concepts diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index f18d50b282400810b4869f78ba7f536ad5ea4798..55ee42dd6405db6bd34b064d71deaeb94839b0fa 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -1,38 +1,33 @@ -# Saving and Restoring +# Save and Restore -This document explains how to save and restore -@{$variables$variables} and models. +The @{tf.train.Saver} class provides methods to save and restore models. The +@{tf.saved_model.simple_save} function is an easy way to build a +@{tf.saved_model$saved model} suitable for serving. +[Estimators](@{$programmers_guide/estimators}) automatically save and restore +variables in the `model_dir`. -Important: TensorFlow model files are code. Be careful with untrusted code. -See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/SECURITY.md) -for details. - -## Saving and restoring variables - -A TensorFlow variable provides the best way to represent shared, persistent -state manipulated by your program. (See @{$variables$Variables} for details.) -This section explains how to save and restore variables. -Note that Estimators automatically saves and restores variables -(in the `model_dir`). +## Save and restore variables -The `tf.train.Saver` class provides methods for saving and restoring models. -The `tf.train.Saver` constructor adds `save` and `restore` ops to the graph -for all, or a specified list, of the variables in the graph. The `Saver` -object provides methods to run these ops, specifying paths for the checkpoint -files to write to or read from. +TensorFlow @{$variables} are the best way to represent shared, persistent state +manipulated by your program. The `tf.train.Saver` constructor adds `save` and +`restore` ops to the graph for all, or a specified list, of the variables in the +graph. The `Saver` object provides methods to run these ops, specifying paths +for the checkpoint files to write to or read from. -The saver will restore all variables already defined in your model. If you're +`Saver` restores all variables already defined in your model. If you're loading a model without knowing how to build its graph (for example, if you're writing a generic program to load models), then read the [Overview of saving and restoring models](#models) section later in this document. -TensorFlow saves variables in binary **checkpoint files** that, -roughly speaking, map variable names to tensor values. - +TensorFlow saves variables in binary *checkpoint files* that map variable +names to tensor values. +Caution: TensorFlow model files are code. Be careful with untrusted code. +See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) +for details. -### Saving variables +### Save variables Create a `Saver` with `tf.train.Saver()` to manage all variables in the model. For example, the following snippet demonstrates how to call the @@ -64,9 +59,7 @@ with tf.Session() as sess: print("Model saved in path: %s" % save_path) ``` - - -### Restoring variables +### Restore variables The `tf.train.Saver` object not only saves variables to checkpoint files, it also restores variables. Note that when you restore variables you do not have @@ -95,14 +88,11 @@ with tf.Session() as sess: print("v2 : %s" % v2.eval()) ``` -Notes: - -* There is not a physical file called "/tmp/model.ckpt". It is the **prefix** - of filenames created for the checkpoint. Users only interact with the - prefix instead of physical checkpoint files. +Note: There is not a physical file called `/tmp/model.ckpt`. It is the *prefix* of +filenames created for the checkpoint. Users only interact with the prefix +instead of physical checkpoint files. - -### Choosing which variables to save and restore +### Choose variables to save and restore If you do not pass any arguments to `tf.train.Saver()`, the saver handles all variables in the graph. Each variable is saved under the name that was passed @@ -201,29 +191,42 @@ chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_t -## Overview of saving and restoring models +## Save and restore models -When you want to save and load variables, the graph, and the -graph's metadata--basically, when you want to save or restore -your model--we recommend using SavedModel. -**SavedModel** is a language-neutral, recoverable, hermetic -serialization format. SavedModel enables higher-level systems -and tools to produce, consume, and transform TensorFlow models. -TensorFlow provides several mechanisms for interacting with -SavedModel, including tf.saved_model APIs, Estimator APIs and a CLI. +Use `SavedModel` to save and load your model—variables, the graph, and the +graph's metadata. This is a language-neutral, recoverable, hermetic +serialization format that enables higher-level systems and tools to produce, +consume, and transform TensorFlow models. TensorFlow provides several ways to +interact with `SavedModel`, including the @{tf.saved_model} APIs, +@{tf.estimator.Estimator}, and a command-line interface. -## APIs to build and load a SavedModel +## Build and load a SavedModel -This section focuses on the APIs for building and loading a SavedModel, -particularly when using lower-level TensorFlow APIs. +### Simple save +The easiest way to create a `SavedModel` is to use the @{tf.saved_model.simple_save} +function: -### Building a SavedModel +```python +simple_save(session, + export_dir, + inputs={"x": x, "y": y}, + outputs={"z": z}) +``` -We provide a Python implementation of the SavedModel -@{tf.saved_model.builder$builder}. -The `SavedModelBuilder` class provides functionality to +This configures the `SavedModel` so it can be loaded by +[TensorFlow serving](/serving/serving_basic) and supports the +[Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). +To access the classify, regress, or multi-inference APIs, use the manual +`SavedModel` builder APIs or an @{tf.estimator.Estimator}. + +### Manually build a SavedModel + +If your use case isn't covered by @{tf.saved_model.simple_save}, use the manual +@{tf.saved_model.builder$builder APIs} to create a `SavedModel`. + +The @{tf.saved_model.builder.SavedModelBuilder} class provides functionality to save multiple `MetaGraphDef`s. A **MetaGraph** is a dataflow graph, plus its associated variables, assets, and signatures. A **`MetaGraphDef`** is the protocol buffer representation of a MetaGraph. A **signature** is @@ -253,16 +256,51 @@ with tf.Session(graph=tf.Graph()) as sess: builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING], signature_def_map=foo_signatures, - assets_collection=foo_assets) + assets_collection=foo_assets, + strip_default_attrs=True) ... # Add a second MetaGraphDef for inference. with tf.Session(graph=tf.Graph()) as sess: ... - builder.add_meta_graph([tag_constants.SERVING]) + builder.add_meta_graph([tag_constants.SERVING], strip_default_attrs=True) ... builder.save() ``` + +#### Forward compatibility via `strip_default_attrs=True` + +Following the guidance below gives you forward compatibility only if the set of +Ops has not changed. + +The @{tf.saved_model.builder.SavedModelBuilder$`SavedModelBuilder`} class allows +users to control whether default-valued attributes must be stripped from the +@{$extend/tool_developers#nodes$`NodeDefs`} +while adding a meta graph to the SavedModel bundle. Both +@{tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables$`SavedModelBuilder.add_meta_graph_and_variables`} +and @{tf.saved_model.builder.SavedModelBuilder.add_meta_graph$`SavedModelBuilder.add_meta_graph`} +methods accept a Boolean flag `strip_default_attrs` that controls this behavior. + +If `strip_default_attrs` is `False`, the exported @{tf.MetaGraphDef} will have +the default valued attributes in all its @{tf.NodeDef} instances. +This can break forward compatibility with a sequence of events such as the +following: + +* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a + default (`bool`) at version 101. +* A model producer such as a "trainer binary" picks up this change (version 101) + to the `OpDef` and re-exports an existing model that uses Op `Foo`. +* A model consumer (such as [Tensorflow Serving](/serving)) running an older + binary (version 100) doesn't have attribute `T` for Op `Foo`, but tries to + import this model. The model consumer doesn't recognize attribute `T` in a + `NodeDef` that uses Op `Foo` and therefore fails to load the model. +* By setting `strip_default_attrs` to True, the model producers can strip away + any default valued attributes in the `NodeDefs`. This helps ensure that newly + added attributes with defaults don't cause older model consumers to fail + loading models regenerated with newer training binaries. + +See [compatibility guidance](https://www.tensorflow.org/programmers_guide/version_compat) +for more information. ### Loading a SavedModel in Python @@ -288,7 +326,7 @@ with tf.Session(graph=tf.Graph()) as sess: ``` -### Loading a SavedModel in C++ +### Load a SavedModel in C++ The C++ version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h) @@ -306,7 +344,7 @@ LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain}, &bundle); ``` -### Loading and Serving a SavedModel in TensorFlow Serving +### Load and serve a SavedModel in TensorFlow serving You can easily load and serve a SavedModel with the TensorFlow Serving Model Server binary. See [instructions](https://www.tensorflow.org/serving/setup#installing_using_apt-get) @@ -362,7 +400,7 @@ defined in: After training an `Estimator` model, you may want to create a service from that model that takes requests and returns a result. You can run such a -service locally on your machine or deploy it scalably in the cloud. +service locally on your machine or deploy it in the cloud. To prepare a trained Estimator for serving, you must export it in the standard SavedModel format. This section explains how to: @@ -374,7 +412,7 @@ SavedModel format. This section explains how to: * Serve the model from a local server and request predictions. -### Preparing serving inputs +### Prepare serving inputs During training, an @{$premade_estimators#input_fn$`input_fn()`} ingests data and prepares it for use by the model. At serving time, similarly, a @@ -448,14 +486,15 @@ to expect and how to map them to your model's expected inputs. By contrast, the *output* portion of the signature is determined by the model. -### Performing the export +### Perform the export To export your trained Estimator, call @{tf.estimator.Estimator.export_savedmodel} with the export base path and the `serving_input_receiver_fn`. ```py -estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn) +estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn, + strip_default_attrs=True) ``` This method builds a new graph by first calling the @@ -471,7 +510,7 @@ Session. > Note: It is your responsibility to garbage-collect old exports. > Otherwise, successive exports will accumulate under `export_dir_base`. -### Specifying the outputs of a custom model +### Specify the outputs of a custom model When writing a custom `model_fn`, you must populate the `export_outputs` element of the @{tf.estimator.EstimatorSpec} return value. This is a dict of @@ -503,7 +542,7 @@ indicating which `SignatureDef` will be served when an inference request does not specify one. -### Serving the exported model locally +### Serve the exported model locally For local deployment, you can serve your model using [TensorFlow Serving](https://github.com/tensorflow/serving), an open-source project that loads a @@ -522,7 +561,7 @@ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 - Now you have a server listening for inference requests via gRPC on port 9000! -### Requesting predictions from a local server +### Request predictions from a local server The server responds to gRPC requests according to the [PredictionService](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto#L15) @@ -615,7 +654,7 @@ passing in sample inputs in various formats (for example, Python expressions) and then fetching the output. -### Installing the SavedModel CLI +### Install the SavedModel CLI Broadly speaking, you can install TensorFlow in either of the following two ways: @@ -697,15 +736,15 @@ executing the computation graph later. For example: $ saved_model_cli show --dir \ /tmp/saved_model_dir --tag_set serve --signature_def serving_default The given SavedModel SignatureDef contains the following input(s): -inputs['x'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x:0 + inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 The given SavedModel SignatureDef contains the following output(s): -outputs['y'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 + outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 Method name is: tensorflow/serving/predict ``` @@ -717,32 +756,32 @@ $ saved_model_cli show --dir /tmp/saved_model_dir --all MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: signature_def['classify_x2_to_y3']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x2:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['scores'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y3:0 -Method name is: tensorflow/serving/classify + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 + Method name is: tensorflow/serving/classify ... signature_def['serving_default']: -The given SavedModel SignatureDef contains the following input(s): -inputs['x'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['y'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/predict + The given SavedModel SignatureDef contains the following input(s): + inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/predict ``` @@ -842,7 +881,7 @@ For example: `=[{"age":[22,24],"education":["BS","MS"]}]` ``` -#### Save Output +#### Save output By default, the SavedModel CLI writes output to stdout. If a directory is passed to `--outdir` option, the outputs will be saved as npy files named after @@ -851,7 +890,7 @@ output tensor keys under the given directory. Use `--overwrite` to overwrite existing output files. -#### TensorFlow Debugger (tfdbg) Integration +#### TensorFlow debugger (tfdbg) integration If `--tf_debug` option is set, the SavedModel CLI will use the TensorFlow Debugger (tfdbg) to watch the intermediate Tensors and runtime @@ -958,6 +997,3 @@ of checkpoints and assets: Each graph is associated with a specific set of tags, which enables identification during a load or restore operation. - - - diff --git a/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md b/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md index 05dfdfdc4d2257fc680e7fa99b666ef86e3bef09..fadfa03e78349801d69e0045991a8fa9a0a59df9 100644 --- a/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md +++ b/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md @@ -16,10 +16,17 @@ TensorBoard is fully configured, it looks like this:
-This tutorial is intended to get you started with simple TensorBoard usage. -There are other resources available as well! The [TensorBoard's GitHub](https://github.com/tensorflow/tensorboard) -has a lot more information on TensorBoard usage, including tips & tricks, and -debugging information. +This 30-minute tutorial is intended to get you started with simple TensorBoard +usage. It assumes a basic understanding of TensorFlow. + +There are other resources available as well! The [TensorBoard GitHub](https://github.com/tensorflow/tensorboard) +has a lot more information on using individual dashboards within TensorBoard +including tips & tricks and debugging information. + +## Setup + +[Install TensorFlow](https://www.tensorflow.org/install/). Installing TensorFlow +via pip should also automatically install TensorBoard. ## Serializing the data @@ -76,7 +83,7 @@ data than you need, though. Instead, consider running the merged summary op every `n` steps. The code example below is a modification of the -@{$layers$simple MNIST tutorial}, +[simple MNIST tutorial](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py), in which we have added some summary ops, and run them every ten steps. If you run this and then launch `tensorboard --logdir=/tmp/tensorflow/mnist`, you'll be able to visualize statistics, such as how the weights or accuracy varied during @@ -214,4 +221,5 @@ corner. Each tab represents a set of serialized data that can be visualized. For in depth information on how to use the *graph* tab to visualize your graph, see @{$graph_viz$TensorBoard: Graph Visualization}. -For more usage information on TensorBoard in general, see the [TensorBoard's GitHub](https://github.com/tensorflow/tensorboard). +For more usage information on TensorBoard in general, see the +[TensorBoard GitHub](https://github.com/tensorflow/tensorboard). diff --git a/tensorflow/docs_src/programmers_guide/using_tpu.md b/tensorflow/docs_src/programmers_guide/using_tpu.md index d74d7f3181c9cf44e6c97e13742db682858f4694..a9c2cb3e33d4817b9a35400dcce9227ddd635ff4 100644 --- a/tensorflow/docs_src/programmers_guide/using_tpu.md +++ b/tensorflow/docs_src/programmers_guide/using_tpu.md @@ -129,10 +129,9 @@ my_tpu_estimator = tf.contrib.tpu.TPUEstimator( Typically the `FLAGS` would be set by command line arguments. To switch from training locally to training on a cloud TPU you would need to: - 1) Set `FLAGS.use_tpu` to `True` - 1) Set `FLAGS.tpu_name` so the - `tf.contrib.cluster_resolver.TPUClusterResolver` can find it - 1) Set `FLAGS.model_dir` to a Google Cloud Storage bucket url (`gs://`). +* Set `FLAGS.use_tpu` to `True` +* Set `FLAGS.tpu_name` so the `tf.contrib.cluster_resolver.TPUClusterResolver` can find it +* Set `FLAGS.model_dir` to a Google Cloud Storage bucket url (`gs://`). ## Optimizer diff --git a/tensorflow/docs_src/programmers_guide/version_compat.md b/tensorflow/docs_src/programmers_guide/version_compat.md index e6613cc69f8aedf344fa25b6564889e34cd9bf53..72e427c5f8f0f6581d528f4ead18699736eafd04 100644 --- a/tensorflow/docs_src/programmers_guide/version_compat.md +++ b/tensorflow/docs_src/programmers_guide/version_compat.md @@ -183,7 +183,7 @@ Our versioning scheme has three requirements: * **Forward compatibility** to support scenarios where the producer of a graph or checkpoint is upgraded to a newer version of TensorFlow before the consumer. -* Enable evolving TensorFlow in incompatible ways. For example, removing Ops, +* Enable evolving TensorFlow in incompatible ways. For example, removing ops, adding attributes, and removing attributes. Note that while the `GraphDef` version mechanism is separate from the TensorFlow @@ -245,32 +245,51 @@ contains a main data version which is treated as either `producer` or `TF_CHECKPOINT_VERSION_MIN_CONSUMER`, and `TF_CHECKPOINT_VERSION_MIN_PRODUCER`. +### Add a new attribute with default to an existing op + +Following the guidance below gives you forward compatibility only if the set of +ops has not changed: + +1. If forward compatibility is desired, set `strip_default_attrs` to `True` + while exporting the model using either the + @{tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables$`add_meta_graph_and_variables`} + and @{tf.saved_model.builder.SavedModelBuilder.add_meta_graph$`add_meta_graph`} + methods of the `SavedModelBuilder` class, or + @{tf.estimator.Estimator.export_savedmodel$`Estimator.export_savedmodel`} +2. This strips off the default valued attributes at the time of + producing/exporting the models. This makes sure that the exported + @{tf.MetaGraphDef} does not contain the new op-attribute when the default + value is used. +3. Having this control could allow out-of-date consumers (for example, serving + binaries that lag behind training binaries) to continue loading the models + and prevent interruptions in model serving. + ### Evolving GraphDef versions This section explains how to use this versioning mechanism to make different types of changes to the `GraphDef` format. -#### Add an Op +#### Add an op -Add the new Op to both consumers and producers at the same time, and do not +Add the new op to both consumers and producers at the same time, and do not change any `GraphDef` versions. This type of change is automatically backward compatible, and does not impact forward compatibility plan since existing producer scripts will not suddenly use the new functionality. -#### Add an Op and switch existing Python wrappers to use it +#### Add an op and switch existing Python wrappers to use it 1. Implement new consumer functionality and increment the `GraphDef` version. 2. If it is possible to make the wrappers use the new functionality only in cases that did not work before, the wrappers can be updated now. 3. Change Python wrappers to use the new functionality. Do not increment - `min_consumer`, since models that do not use this Op should not break. + `min_consumer`, since models that do not use this op should not break. -#### Remove or restrict an Op's functionality +#### Remove or restrict an op's functionality -1. Fix all producer scripts (not TensorFlow itself) to not use the banned Op or +1. Fix all producer scripts (not TensorFlow itself) to not use the banned op or functionality. 2. Increment the `GraphDef` version and implement new consumer functionality - that bans the removed Op or functionality for GraphDefs at the new version + that bans the removed op or functionality for GraphDefs at the new version and above. If possible, make TensorFlow stop producing `GraphDefs` with the banned functionality. To do so, add the [`REGISTER_OP(...).Deprecated(deprecated_at_version, @@ -279,15 +298,15 @@ existing producer scripts will not suddenly use the new functionality. 4. Increase `min_producer` to the GraphDef version from (2) and remove the functionality entirely. -#### Change an Op's functionality +#### Change an op's functionality -1. Add a new similar Op named `SomethingV2` or similar and go through the +1. Add a new similar op named `SomethingV2` or similar and go through the process of adding it and switching existing Python wrappers to use it, which may take three weeks if forward compatibility is desired. -2. Remove the old Op (Can only take place with a major version change due to +2. Remove the old op (Can only take place with a major version change due to backward compatibility). -3. Increase `min_consumer` to rule out consumers with the old Op, add back the - old Op as an alias for `SomethingV2`, and go through the process to switch +3. Increase `min_consumer` to rule out consumers with the old op, add back the + old op as an alias for `SomethingV2`, and go through the process to switch existing Python wrappers to use it. 4. Go through the process to remove `SomethingV2`. @@ -295,6 +314,6 @@ existing producer scripts will not suddenly use the new functionality. 1. Bump the `GraphDef` version and add the bad version to `bad_consumers` for all new GraphDefs. If possible, add to `bad_consumers` only for GraphDefs - which contain a certain Op or similar. + which contain a certain op or similar. 2. If existing consumers have the bad version, push them out as soon as possible. diff --git a/tensorflow/docs_src/tutorials/deep_cnn.md b/tensorflow/docs_src/tutorials/deep_cnn.md index 679754020470dddfcffa76e62ca8f55a439ec4f5..6a4c9a9b0727208a158b1b57d13ca70290961ec2 100644 --- a/tensorflow/docs_src/tutorials/deep_cnn.md +++ b/tensorflow/docs_src/tutorials/deep_cnn.md @@ -268,7 +268,7 @@ in `cifar10_input.py`. `cifar10_train.py` periodically @{tf.train.Saver$saves} all model parameters in -@{$variables#saving-and-restoring$checkpoint files} +@{$programmers_guide/saved_model$checkpoint files} but it does *not* evaluate the model. The checkpoint file will be used by `cifar10_eval.py` to measure the predictive performance (see [Evaluating a Model](#evaluating-a-model) below). diff --git a/tensorflow/docs_src/tutorials/image_retraining.md b/tensorflow/docs_src/tutorials/image_retraining.md index df15bc0a9c3763aa51c2fc8cf36ce9fc3544ae68..93d7c86e42aa90d145d27b56edc0abfec7034686 100644 --- a/tensorflow/docs_src/tutorials/image_retraining.md +++ b/tensorflow/docs_src/tutorials/image_retraining.md @@ -115,7 +115,7 @@ process is progressing. The training's objective is to make the loss as small as possible, so you can tell if the learning is working by keeping an eye on whether the loss keeps trending downwards, ignoring the short-term noise. -By default this script will run 4,000 training steps. Each step chooses ten +By default this script will run 4,000 training steps. Each step chooses 100 images at random from the training set, finds their bottlenecks from the cache, and feeds them into the final layer to get predictions. Those predictions are then compared against the actual labels to update the final layer's weights @@ -349,31 +349,32 @@ results, but if you intend to deploy your model on mobile devices or other resource-constrained environments you may want to trade off a little accuracy for much smaller file sizes or faster speeds. To help with that, the [retrain.py script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py) -supports 32 different variations on the [Mobilenet architecture](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html). +supports different variations on the [Mobilenet architecture](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html). These are a little less precise than Inception v3, but can result in far -smaller file sizes (down to less than a megabyte) and can be many times faster +smaller file sizes (a few megabytes) and can be many times faster to run. To train with one of these models, pass in the `--architecture` flag, for example: ``` python tensorflow/examples/image_retraining/retrain.py \ - --image_dir ~/flower_photos --architecture mobilenet_0.25_128_quantized + --image_dir ~/flower_photos --architecture mobilenet_0.25_128 ``` -This will create a 941KB model file in `/tmp/output_graph.pb`, with 25% of the -parameters of the full Mobilenet, taking 128x128 sized input images, and with -its weights quantized down to eight bits on disk. You can choose '1.0', '0.75', -'0.50', or '0.25' to control the number of weight parameters, and so the file -size (and to some extent the speed), '224', '192', '160', or '128' for the input -image size, with smaller sizes giving faster speeds, and an optional -'_quantized' at the end to indicate whether the file should contain 8-bit or -32-bit float weights. +This will create a 1.9MB model file in `/tmp/output_graph.pb`, with only 25% of +the number of neurons of the full Mobilenet, and trained to take 128x128 sized +input images. + +You can choose '1.0', '0.75', '0.50', or '0.25' to control the number of +neurons (activations of hidden layers); the number of weights (and hence to +some extent the file size and speed) shrinks like the square of that fraction. +You can choose '224', '192', '160', or '128' for the input image size, +with smaller sizes giving faster speeds. The speed and size advantages come at a loss to accuracy of course, but for many purposes this isn't critical. They can also be somewhat offset with improved training data. For example, training with distortions allows me to get above 80% -accuracy on the flower data set even with the 0.25/128/quantized graph above. +accuracy on the flower data set even with the 0.25/128 graph above. If you're going to be using the Mobilenet models in label_image or your own programs, you'll need to feed in an image of the specified size converted to a @@ -395,3 +396,9 @@ python tensorflow/examples/label_image/label_image.py \ --input_mean=128 --input_std=128 \ --image=$HOME/flower_photos/daisy/21652746_cc379e0eea_m.jpg ``` + +For more information on deploying the retrained model to a mobile device, see +the [codelab version](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) +of this tutorial, especially [part 2](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/#0), which describes +[TensorFlow Lite](/mobile/tflite/) and the additional optimizations it offers +(including quantization of model weights). diff --git a/tensorflow/docs_src/tutorials/index.md b/tensorflow/docs_src/tutorials/index.md index 8c697e48e550c4e425db33bab7257532d209ac7a..af01d3eaa12157f82c981de005708509f6652cca 100644 --- a/tensorflow/docs_src/tutorials/index.md +++ b/tensorflow/docs_src/tutorials/index.md @@ -10,7 +10,7 @@ these tutorials. These tutorials cover different aspects of image recognition: - * @{$layers}, which introduces convolutional neural networks (CNNs) and + * @{$layers$MNIST}, which introduces convolutional neural networks (CNNs) and demonstrates how to build a CNN in TensorFlow. * @{$image_recognition}, which introduces the field of image recognition and uses a pre-trained model (Inception) for recognizing images. diff --git a/tensorflow/docs_src/tutorials/kernel_methods.md b/tensorflow/docs_src/tutorials/kernel_methods.md index 63f408c2ca304d6345ffff459b799b011f8d8035..73e5c5105784ddc9729b8cea6cd31921572837e1 100644 --- a/tensorflow/docs_src/tutorials/kernel_methods.md +++ b/tensorflow/docs_src/tutorials/kernel_methods.md @@ -1,9 +1,9 @@ # Improving Linear Models Using Explicit Kernel Methods -Note: This document uses a deprecated version of ${tf.estimator}, -which has a ${tf.contrib.learn.estimator$different interface}. +Note: This document uses a deprecated version of @{tf.estimator}, +which has a @{tf.contrib.learn.Estimator$different interface}. It also uses other `contrib` methods whose -${$version_compat#not_covered$API may not be stable}. +@{$version_compat#not_covered$API may not be stable}. In this tutorial, we demonstrate how combining (explicit) kernel methods with linear models can drastically increase the latters' quality of predictions @@ -53,7 +53,7 @@ In order to feed data to a `tf.contrib.learn Estimator`, it is helpful to conver it to Tensors. For this, we will use an `input function` which adds Ops to the TensorFlow graph that, when executed, create mini-batches of Tensors to be used downstream. For more background on input functions, check -@{$get_started/premade_estimators#input_fn$this section on input functions}. +@{$get_started/premade_estimators#create_input_functions$this section on input functions}. In this example, we will use the `tf.train.shuffle_batch` Op which, besides converting numpy arrays to Tensors, allows us to specify the batch_size and whether to randomize the input every time the input_fn Ops are executed diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md index 5111b16247e2b5c3410e69dcdf08318a35b18c2f..cadaec391d8970faf5847c9b9e39bccb31f885ed 100644 --- a/tensorflow/docs_src/tutorials/layers.md +++ b/tensorflow/docs_src/tutorials/layers.md @@ -193,22 +193,28 @@ to calculate loss, configure the training op, and generate predictions. If you're already experienced with CNNs and @{$get_started/custom_estimators$TensorFlow `Estimator`s}, and find the above code intuitive, you may want to skim these sections or just skip ahead to ["Training and Evaluating the CNN MNIST -Classifier"](#training-and-evaluating-the-cnn-mnist-classifier). +Classifier"](#training_and_evaluating_the_cnn_mnist_classifier). ### Input Layer The methods in the `layers` module for creating convolutional and pooling layers for two-dimensional image data expect input tensors to have a shape of -[batch_size, image_width, image_height, -channels], defined as follows: +[batch_size, image_height, image_width, +channels] by default. This behavior can be changed using the data_format parameter; defined as follows: + * _`batch_size`_. Size of the subset of examples to use when performing gradient descent during training. -* _`image_width`_. Width of the example images. * _`image_height`_. Height of the example images. +* _`image_width`_. Width of the example images. * _`channels`_. Number of color channels in the example images. For color images, the number of channels is 3 (red, green, blue). For monochrome images, there is just 1 channel (black). +* _`image_height`_. Height of the example images. +* _`data_format`_. A string, one of `channels_last` (default) or `channels_first`. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. Here, our MNIST dataset is composed of monochrome 28x28 pixel images, so the desired shape for our input layer is [batch_size, 28, 28, @@ -247,28 +253,27 @@ conv1 = tf.layers.conv2d( ``` The `inputs` argument specifies our input tensor, which must have the shape -[batch_size, image_width, image_height, +[batch_size, image_height, image_width, channels]. Here, we're connecting our first convolutional layer to `input_layer`, which has the shape [batch_size, 28, 28, 1]. > Note: conv2d() will instead accept a shape of -> [channels, batch_size, image_width, -> image_height] when passed the argument +> [batch_size, channels, image_height, image_width] when passed the argument > data_format=channels_first. The `filters` argument specifies the number of filters to apply (here, 32), and -`kernel_size` specifies the dimensions of the filters as [width, -height] (here, [5, 5]). +`kernel_size` specifies the dimensions of the filters as [height, +width] (here, [5, 5]). -

TIP: If filter width and height have the same value, you can instead specify a +

TIP: If filter height and width have the same value, you can instead specify a single integer for kernel_size—e.g., kernel_size=5.

The `padding` argument specifies one of two enumerated values (case-insensitive): `valid` (default value) or `same`. To specify that the -output tensor should have the same width and height values as the input tensor, +output tensor should have the same height and width values as the input tensor, we set `padding=same` here, which instructs TensorFlow to add 0 values to the -edges of the input tensor to preserve width and height of 28. (Without padding, +edges of the input tensor to preserve height and width of 28. (Without padding, a 5x5 convolution over a 28x28 tensor will produce a 24x24 tensor, as there are 24x24 locations to extract a 5x5 tile from a 28x28 grid.) @@ -277,7 +282,7 @@ output of the convolution. Here, we specify ReLU activation with @{tf.nn.relu}. Our output tensor produced by `conv2d()` has a shape of -[batch_size, 28, 28, 32]: the same width and height +[batch_size, 28, 28, 32]: the same height and width dimensions as the input, but now with 32 channels holding the output from each of the filters. @@ -292,31 +297,30 @@ pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) ``` Again, `inputs` specifies the input tensor, with a shape of -[batch_size, image_width, image_height, +[batch_size, image_height, image_width, channels]. Here, our input tensor is `conv1`, the output from the first convolutional layer, which has a shape of [batch_size, 28, 28, 32]. > Note: As with conv2d(), max_pooling2d() will instead -> accept a shape of [channels, batch_size, -> image_width, image_height] when passed the argument +> accept a shape of [batch_size, channels, +> image_height, image_width] when passed the argument > data_format=channels_first. The `pool_size` argument specifies the size of the max pooling filter as -[width, height] (here, `[2, 2]`). If both +[height, width] (here, `[2, 2]`). If both dimensions have the same value, you can instead specify a single integer (e.g., `pool_size=2`). The `strides` argument specifies the size of the stride. Here, we set a stride of 2, which indicates that the subregions extracted by the filter should be -separated by 2 pixels in both the width and height dimensions (for a 2x2 filter, +separated by 2 pixels in both the height and width dimensions (for a 2x2 filter, this means that none of the regions extracted will overlap). If you want to set -different stride values for width and height, you can instead specify a tuple or +different stride values for height and width, you can instead specify a tuple or list (e.g., `stride=[3, 6]`). Our output tensor produced by `max_pooling2d()` (`pool1`) has a shape of -[batch_size, 14, 14, 32]: the 2x2 filter reduces width and -height by 50% each. +[batch_size, 14, 14, 32]: the 2x2 filter reduces height and width by 50% each. ### Convolutional Layer #2 and Pooling Layer #2 @@ -338,13 +342,11 @@ pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) Note that convolutional layer #2 takes the output tensor of our first pooling layer (`pool1`) as input, and produces the tensor `conv2` as output. `conv2` -has a shape of [batch_size, 14, 14, 64], the same width -and height as `pool1` (due to `padding="same"`), and 64 channels for the 64 +has a shape of [batch_size, 14, 14, 64], the same height and width as `pool1` (due to `padding="same"`), and 64 channels for the 64 filters applied. Pooling layer #2 takes `conv2` as input, producing `pool2` as output. `pool2` -has shape [batch_size, 7, 7, 64] (50% reduction of width -and height from `conv2`). +has shape [batch_size, 7, 7, 64] (50% reduction of height and width from `conv2`). ### Dense Layer @@ -360,7 +362,7 @@ pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) In the `reshape()` operation above, the `-1` signifies that the *`batch_size`* dimension will be dynamically calculated based on the number of examples in our -input data. Each example has 7 (`pool2` width) * 7 (`pool2` height) * 64 +input data. Each example has 7 (`pool2` height) * 7 (`pool2` width) * 64 (`pool2` channels) features, so we want the `features` dimension to have a value of 7 * 7 * 64 (3136 in total). The output tensor, `pool2_flat`, has shape [batch_size, 3136]. @@ -446,7 +448,7 @@ tf.nn.softmax(logits, name="softmax_tensor") > Note: We use the `name` argument to explicitly name this operation > `softmax_tensor`, so we can reference it later. (We'll set up logging for the -> softmax values in ["Set Up a Logging Hook"](#set-up-a-logging-hook). +> softmax values in ["Set Up a Logging Hook"](#set-up-a-logging-hook)). We compile our predictions in a dict, and return an `EstimatorSpec` object: @@ -534,9 +536,8 @@ if mode == tf.estimator.ModeKeys.TRAIN: ``` > Note: For a more in-depth look at configuring training ops for Estimator model -> functions, see @{$get_started/custom_estimators#defining-the-training-op-for-the-model$"Defining -> the training op for the model"} in the @{$get_started/custom_estimators$"Creating Estimations in -> tf.estimator"} tutorial. +> functions, see @{$get_started/custom_estimators#defining_the_training_op_for_the_model$"Defining the training op for the model"} +> in the @{$get_started/custom_estimators$"Creating Estimators in tf.estimator."} tutorial. ### Add evaluation metrics @@ -625,8 +626,8 @@ operation earlier when we generated the probabilities in `cnn_model_fn`. > Note: If you don't explicitly assign a name to an operation via the `name` > argument, TensorFlow will assign a default name. A couple easy ways to > discover the names applied to operations are to visualize your graph on -> @{$graph_viz$TensorBoard}) or to enable the @{$debugger$TensorFlow Debugger -> (tfdbg)}. +> @{$graph_viz$TensorBoard}) or to enable the +> @{$programmers_guide/debugger$TensorFlow Debugger (tfdbg)}. Next, we create the `LoggingTensorHook`, passing `tensors_to_log` to the `tensors` argument. We set `every_n_iter=50`, which specifies that probabilities diff --git a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md b/tensorflow/docs_src/tutorials/recurrent_quickdraw.md index e22536adb6f0b893602ff79612cfb01e10586a18..5d83fbe2a3709c0834f448cbc316453f80428dd1 100644 --- a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md +++ b/tensorflow/docs_src/tutorials/recurrent_quickdraw.md @@ -38,8 +38,8 @@ To try the code for this tutorial: 1. [Download the data](#download-the-data) in `TFRecord` format from [here](http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz) and unzip it. More details about [how to obtain the original Quick, Draw! - data](#optional-download-the-full-quick-draw-data) and [how to convert that - to `TFRecord` files](#optional-converting-the-data) is available below. + data](#optional_download_the_full_quick_draw_data) and [how to convert that + to `TFRecord` files](#optional_converting_the_data) is available below. 1. Execute the tutorial code with the following command to train the RNN-based model described in this tutorial. Make sure to adjust the paths to point to @@ -108,8 +108,9 @@ This download will take a while and download a bit more than 23GB of data. ### Optional: Converting the data To convert the `ndjson` files to -@{$python/python_io#tfrecords_format_details$TFRecord} files containing -${tf.train.Example} protos run the following command. +@{$python/python_io#TFRecords_Format_Details$TFRecord} files containing +[`tf.train.Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto) +protos run the following command. ```shell python create_dataset.py --ndjson_path rnn_tutorial_data \ @@ -117,7 +118,7 @@ ${tf.train.Example} protos run the following command. ``` This will store the data in 10 shards of -@{$python/python_io#tfrecords_format_details$TFRecord} files with 10000 items +@{$python/python_io#TFRecords_Format_Details$TFRecord} files with 10000 items per class for the training data and 1000 items per class as eval data. This conversion process is described in more detail in the following. diff --git a/tensorflow/docs_src/tutorials/wide.md b/tensorflow/docs_src/tutorials/wide.md index 005dc020f94f666da295f4ff0342fae858121012..27ce75a30dd2acd5925702611042270e767b0c73 100644 --- a/tensorflow/docs_src/tutorials/wide.md +++ b/tensorflow/docs_src/tutorials/wide.md @@ -74,8 +74,8 @@ Here's a list of columns available in the Census Income dataset: | relationship | Categorical | Wife, Own-child, Husband, | : : : Not-in-family, Other-relative, : : : : Unmarried. : -| race | Categorical | White, Asian-Pac-Islander, | -: : : Amer-Indian-Eskimo, Other, Black. : +| race | Categorical | Amer-Indian-Eskimo, Asian-Pac- | +: : : Islander, Black, White, Other. : | gender | Categorical | Female, Male. | | capital_gain | Continuous | Capital gains recorded. | | capital_loss | Continuous | Capital Losses recorded. | @@ -247,7 +247,7 @@ hours_per_week = tf.feature_column.numeric_column('hours_per_week') ### Making Continuous Features Categorical through Bucketization Sometimes the relationship between a continuous feature and the label is not -linear. As an hypothetical example, a person's income may grow with age in the +linear. As a hypothetical example, a person's income may grow with age in the early stage of one's career, then the growth may slow at some point, and finally the income decreases after retirement. In this scenario, using the raw `age` as a real-valued feature column might not be a good choice because the model can @@ -361,6 +361,16 @@ The first line of the final output should be something like `accuracy: 0.83557522`, which means the accuracy is 83.6%. Feel free to try more features and transformations and see if you can do even better! +After the model is evaluated, we can use the model to predict whether an individual has an annual income of over +50,000 dollars given an individual's information input. +```python + pred_iter = model.predict(input_fn=lambda: input_fn(FLAGS.test_data, 1, False, 1)) + for pred in pred_iter: + print(pred['classes']) +``` + +The model prediction output would be like `[b'1']` or `[b'0']` which means whether corresponding individual has an annual income of over 50,000 dollars or not. + If you'd like to see a working end-to-end example, you can download our [example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) and set the `model_type` flag to `wide`. diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml index bb75431a1f8bab2951299520903aa6e043f8415e..5c47ce6b673e4c9d635b867c1ccdc679f67c6ae5 100644 --- a/tensorflow/examples/android/AndroidManifest.xml +++ b/tensorflow/examples/android/AndroidManifest.xml @@ -40,6 +40,7 @@ + @@ -49,6 +50,7 @@ + @@ -58,6 +60,7 @@ + @@ -67,6 +70,7 @@ + diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java index 8bd4abb154a8f8c74f2195d4acbb99d3d5d498ea..429138abe5338e63d602ef6005e7607d21e1e357 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java @@ -351,6 +351,10 @@ public abstract class CameraActivity extends Activity protected void setFragment() { String cameraId = chooseCamera(); + if (cameraId == null) { + Toast.makeText(this, "No Camera Detected", Toast.LENGTH_SHORT).show(); + finish(); + } Fragment fragment; if (useCamera2API) { @@ -416,7 +420,8 @@ public abstract class CameraActivity extends Activity @Override public boolean onKeyDown(final int keyCode, final KeyEvent event) { - if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP) { + if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP + || keyCode == KeyEvent.KEYCODE_BUTTON_L1 || keyCode == KeyEvent.KEYCODE_DPAD_CENTER) { debug = !debug; requestRender(); onSetDebug(debug); diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java index 6a66ec3927be62f1f996eb18bb6c04ea66f43152..33ec65e9f73a1d04bcafdc09d1618b32e03b1dc0 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java @@ -16,8 +16,10 @@ package org.tensorflow.demo; +import android.app.UiModeManager; import android.content.Context; import android.content.res.AssetManager; +import android.content.res.Configuration; import android.graphics.Bitmap; import android.graphics.Bitmap.Config; import android.graphics.BitmapFactory; @@ -31,9 +33,11 @@ import android.graphics.Typeface; import android.media.ImageReader.OnImageAvailableListener; import android.os.Bundle; import android.os.SystemClock; +import android.util.DisplayMetrics; import android.util.Size; import android.util.TypedValue; import android.view.Display; +import android.view.KeyEvent; import android.view.MotionEvent; import android.view.View; import android.view.View.OnClickListener; @@ -43,6 +47,7 @@ import android.widget.BaseAdapter; import android.widget.Button; import android.widget.GridView; import android.widget.ImageView; +import android.widget.RelativeLayout; import android.widget.Toast; import java.io.IOException; import java.io.InputStream; @@ -381,6 +386,27 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL grid = (GridView) findViewById(R.id.grid_layout); grid.setAdapter(adapter); grid.setOnTouchListener(gridTouchAdapter); + + // Change UI on Android TV + UiModeManager uiModeManager = (UiModeManager) getSystemService(UI_MODE_SERVICE); + if (uiModeManager.getCurrentModeType() == Configuration.UI_MODE_TYPE_TELEVISION) { + DisplayMetrics displayMetrics = new DisplayMetrics(); + getWindowManager().getDefaultDisplay().getMetrics(displayMetrics); + int styleSelectorHeight = displayMetrics.heightPixels; + int styleSelectorWidth = displayMetrics.widthPixels - styleSelectorHeight; + RelativeLayout.LayoutParams layoutParams = new RelativeLayout.LayoutParams(styleSelectorWidth, ViewGroup.LayoutParams.MATCH_PARENT); + + // Calculate number of style in a row, so all the style can show up without scrolling + int numOfStylePerRow = 3; + while (styleSelectorWidth / numOfStylePerRow * Math.ceil((float) (adapter.getCount() - 2) / numOfStylePerRow) > styleSelectorHeight) { + numOfStylePerRow++; + } + grid.setNumColumns(numOfStylePerRow); + layoutParams.addRule(RelativeLayout.ALIGN_PARENT_RIGHT); + grid.setLayoutParams(layoutParams); + adapter.buttons.clear(); + } + setStyle(adapter.items[0], 1.0f); } @@ -602,4 +628,38 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); } + + @Override + public boolean onKeyDown(int keyCode, KeyEvent event) { + int moveOffset = 0; + switch (keyCode) { + case KeyEvent.KEYCODE_DPAD_LEFT: + moveOffset = -1; + break; + case KeyEvent.KEYCODE_DPAD_RIGHT: + moveOffset = 1; + break; + case KeyEvent.KEYCODE_DPAD_UP: + moveOffset = -1 * grid.getNumColumns(); + break; + case KeyEvent.KEYCODE_DPAD_DOWN: + moveOffset = grid.getNumColumns(); + break; + default: + return super.onKeyDown(keyCode, event); + } + + // get the highest selected style + int currentSelect = 0; + float highestValue = 0; + for (int i = 0; i < adapter.getCount(); i++) { + if (adapter.items[i].value > highestValue) { + currentSelect = i; + highestValue = adapter.items[i].value; + } + } + setStyle(adapter.items[(currentSelect + moveOffset + adapter.getCount()) % adapter.getCount()], 1); + + return true; + } } diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 25e09fecbfd093e97899807b82a03f1116dbe5ff..99a71206acbd533ec8bc5a9644435eacad564cd4 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -75,13 +75,16 @@ python tensorflow/examples/image_retraining/retrain.py \ --image_dir ~/flower_photos --architecture mobilenet_1.0_224 ``` -Run quantized version of mobilenet: +Run mobilenet, instrumented for quantization: ```bash python tensorflow/examples/image_retraining/retrain.py \ - --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized + --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quant ``` +These instrumented models can be converted to fully quantized mobile models via +TensorFlow Lite. + There are 32 different Mobilenet models to choose from, with a variety of file size and latency options. The first number can be '1.0', '0.75', '0.50', or '0.25' to control the size, and the second controls the input image size, either @@ -121,7 +124,6 @@ import numpy as np from six.moves import urllib import tensorflow as tf -from tensorflow.contrib.quantize.python import quant_ops from tensorflow.python.framework import graph_util from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import gfile @@ -135,6 +137,9 @@ FLAGS = None # need to update these to reflect the values in the network you're using. MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M +# The location where variable checkpoints will be stored. +CHECKPOINT_NAME = '/tmp/_retrain_checkpoint' + def create_image_lists(image_dir, testing_percentage, validation_percentage): """Builds a list of training images from the file system. @@ -745,9 +750,9 @@ def variable_summaries(var): tf.summary.histogram('histogram', var) -def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, - bottleneck_tensor_size, quantize_layer): - """Adds a new softmax and fully-connected layer for training. +def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor, + bottleneck_tensor_size, quantize_layer, is_training): + """Adds a new softmax and fully-connected layer for training and eval. We need to retrain the top layer to identify our new classes, so this function adds the right operations to the graph, along with some variables to hold the @@ -763,7 +768,9 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, bottleneck_tensor: The output of the main CNN graph. bottleneck_tensor_size: How many entries in the bottleneck vector. quantize_layer: Boolean, specifying whether the newly added layer should be - quantized. + instrumented for quantized. + is_training: Boolean, specifying whether the newly add layer is for training + or eval. Returns: The tensors for the training and cross entropy results, and tensors for the @@ -778,50 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, ground_truth_input = tf.placeholder( tf.int64, [None], name='GroundTruthInput') - # Organizing the following ops as `final_training_ops` so they're easier - # to see in TensorBoard - layer_name = 'final_training_ops' + # Organizing the following ops so they are easier to see in TensorBoard. + layer_name = 'final_retrain_ops' with tf.name_scope(layer_name): with tf.name_scope('weights'): initial_value = tf.truncated_normal( [bottleneck_tensor_size, class_count], stddev=0.001) layer_weights = tf.Variable(initial_value, name='final_weights') - if quantize_layer: - quantized_layer_weights = quant_ops.MovingAvgQuantize( - layer_weights, is_training=True) - variable_summaries(quantized_layer_weights) - variable_summaries(layer_weights) + with tf.name_scope('biases'): layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') - if quantize_layer: - quantized_layer_biases = quant_ops.MovingAvgQuantize( - layer_biases, is_training=True) - variable_summaries(quantized_layer_biases) - variable_summaries(layer_biases) with tf.name_scope('Wx_plus_b'): - if quantize_layer: - logits = tf.matmul(bottleneck_input, - quantized_layer_weights) + quantized_layer_biases - logits = quant_ops.MovingAvgQuantize( - logits, - init_min=-32.0, - init_max=32.0, - is_training=True, - num_bits=8, - narrow_range=False, - ema_decay=0.5) - tf.summary.histogram('pre_activations', logits) - else: - logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases - tf.summary.histogram('pre_activations', logits) + logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases + tf.summary.histogram('pre_activations', logits) final_tensor = tf.nn.softmax(logits, name=final_tensor_name) + # The tf.contrib.quantize functions rewrite the graph in place for + # quantization. The imported model graph has already been rewritten, so upon + # calling these rewrites, only the newly added final layer will be + # transformed. + if quantize_layer: + if is_training: + tf.contrib.quantize.create_training_graph() + else: + tf.contrib.quantize.create_eval_graph() + tf.summary.histogram('activations', final_tensor) + # If this is an eval graph, we don't need to add loss ops or an optimizer. + if not is_training: + return None, None, bottleneck_input, ground_truth_input, final_tensor + with tf.name_scope('cross_entropy'): cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy( labels=ground_truth_input, logits=logits) @@ -857,13 +855,91 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): return evaluation_step, prediction -def save_graph_to_file(sess, graph, graph_file_name): +def run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, + bottleneck_tensor): + """Runs a final evaluation on an eval graph using the test data set. + + Args: + sess: Session for the train graph. + model_info: Model info dictionary from create_model_info() + class_count: Number of classes + image_lists: Dictionary of training images for each label. + jpeg_data_tensor: The layer to feed jpeg image data into. + decoded_image_tensor: The output of decoding and resizing the image. + resized_image_tensor: The input node of the recognition graph. + bottleneck_tensor: The bottleneck output layer of the CNN graph. + """ + (sess, bottleneck_input, ground_truth_input, evaluation_step, + prediction) = build_eval_session(model_info, class_count) + + test_bottlenecks, test_ground_truth, test_filenames = ( + get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size, + 'testing', FLAGS.bottleneck_dir, + FLAGS.image_dir, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, + bottleneck_tensor, FLAGS.architecture)) + test_accuracy, predictions = sess.run( + [evaluation_step, prediction], + feed_dict={ + bottleneck_input: test_bottlenecks, + ground_truth_input: test_ground_truth + }) + tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % + (test_accuracy * 100, len(test_bottlenecks))) + + if FLAGS.print_misclassified_test_images: + tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===') + for i, test_filename in enumerate(test_filenames): + if predictions[i] != test_ground_truth[i]: + tf.logging.info('%70s %s' % (test_filename, + list(image_lists.keys())[predictions[i]])) + + +def build_eval_session(model_info, class_count): + """Builds an restored eval session without train operations for exporting. + + Args: + model_info: Model info dictionary from create_model_info() + class_count: Number of classes + + Returns: + Eval session containing the restored eval graph. + The bottleneck input, ground truth, eval step, and prediction tensors. + """ + # If quantized, we need to create the correct eval graph for exporting. + eval_graph, bottleneck_tensor, _ = create_model_graph(model_info) + + eval_sess = tf.Session(graph=eval_graph) + with eval_graph.as_default(): + # Add the new layer for exporting. + (_, _, bottleneck_input, + ground_truth_input, final_tensor) = add_final_retrain_ops( + class_count, FLAGS.final_tensor_name, bottleneck_tensor, + model_info['bottleneck_tensor_size'], model_info['quantize_layer'], + False) + + # Now we need to restore the values from the training graph to the eval + # graph. + tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME) + + evaluation_step, prediction = add_evaluation_step(final_tensor, + ground_truth_input) + + return (eval_sess, bottleneck_input, ground_truth_input, evaluation_step, + prediction) + + +def save_graph_to_file(graph, graph_file_name, model_info, class_count): + """Saves an graph to file, creating a valid quantized one if necessary.""" + sess, _, _, _, _ = build_eval_session(model_info, class_count) + graph = sess.graph + output_graph_def = graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) with gfile.FastGFile(graph_file_name, 'wb') as f: f.write(output_graph_def.SerializeToString()) - return def prepare_file_system(): @@ -916,11 +992,10 @@ def create_model_info(architecture): return None version_string = parts[1] if (version_string != '1.0' and version_string != '0.75' and - version_string != '0.50' and version_string != '0.25'): + version_string != '0.5' and version_string != '0.25'): tf.logging.error( - """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25', - but found '%s' for architecture '%s'""", - version_string, architecture) + """"The Mobilenet version should be '1.0', '0.75', '0.5', or '0.25', + but found '%s' for architecture '%s'""", version_string, architecture) return None size_string = parts[2] if (size_string != '224' and size_string != '192' and @@ -933,35 +1008,26 @@ def create_model_info(architecture): if len(parts) == 3: is_quantized = False else: - if parts[3] != 'quantized': + if parts[3] != 'quant': tf.logging.error( "Couldn't understand architecture suffix '%s' for '%s'", parts[3], architecture) return None is_quantized = True + data_url = 'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/' + model_name = 'mobilenet_v1_' + version_string + '_' + size_string if is_quantized: - data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' - data_url += version_string + '_' + size_string + '_quantized_frozen.tgz' - bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' - resized_input_tensor_name = 'Placeholder:0' - model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string + - '_quantized_frozen') - model_base_name = 'quantized_frozen_graph.pb' - - else: - data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' - data_url += version_string + '_' + size_string + '_frozen.tgz' - bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' - resized_input_tensor_name = 'input:0' - model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string - model_base_name = 'frozen_graph.pb' + model_name += '_quant' + data_url += model_name + '.tgz' + bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' + resized_input_tensor_name = 'input:0' + model_file_name = model_name + '_frozen.pb' bottleneck_tensor_size = 1001 input_width = int(size_string) input_height = int(size_string) input_depth = 3 - model_file_name = os.path.join(model_dir_name, model_base_name) input_mean = 127.5 input_std = 127.5 else: @@ -1011,43 +1077,45 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean, return jpeg_data, mul_image -def export_model(sess, architecture, saved_model_dir): +def export_model(model_info, class_count, saved_model_dir): """Exports model for serving. Args: - sess: Current active TensorFlow Session. - architecture: Model architecture. + model_info: The modelinfo for the current model. + class_count: The number of classes. saved_model_dir: Directory in which to save exported model and variables. """ - if architecture == 'inception_v3': - input_tensor = 'DecodeJpeg/contents:0' - elif architecture.startswith('mobilenet_'): - input_tensor = 'input:0' - else: - raise ValueError('Unknown architecture', architecture) - in_image = sess.graph.get_tensor_by_name(input_tensor) - inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)} - - out_classes = sess.graph.get_tensor_by_name('final_result:0') - outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)} + # The SavedModel should hold the eval graph. + sess, _, _, _, _ = build_eval_session(model_info, class_count) + graph = sess.graph + with graph.as_default(): + input_tensor = model_info['resized_input_tensor_name'] + in_image = sess.graph.get_tensor_by_name(input_tensor) + inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)} + + out_classes = sess.graph.get_tensor_by_name('final_result:0') + outputs = { + 'prediction': tf.saved_model.utils.build_tensor_info(out_classes) + } - signature = tf.saved_model.signature_def_utils.build_signature_def( - inputs=inputs, - outputs=outputs, - method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) + signature = tf.saved_model.signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) - legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') + legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') - # Save out the SavedModel. - builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) - builder.add_meta_graph_and_variables( - sess, [tf.saved_model.tag_constants.SERVING], - signature_def_map={ - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - signature - }, - legacy_init_op=legacy_init_op) - builder.save() + # Save out the SavedModel. + builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map={ + tf.saved_model.signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY: + signature + }, + legacy_init_op=legacy_init_op) + builder.save() def main(_): @@ -1064,11 +1132,6 @@ def main(_): tf.logging.error('Did not recognize architecture flag') return -1 - # Set up the pre-trained graph. - maybe_download_and_extract(model_info['data_url']) - graph, bottleneck_tensor, resized_image_tensor = ( - create_model_graph(model_info)) - # Look at the folder structure, and create lists of all the images. image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage, FLAGS.validation_percentage) @@ -1087,6 +1150,19 @@ def main(_): FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale, FLAGS.random_brightness) + # Set up the pre-trained graph. + maybe_download_and_extract(model_info['data_url']) + graph, bottleneck_tensor, resized_image_tensor = ( + create_model_graph(model_info)) + + # Add the new layer that we'll be training. + with graph.as_default(): + (train_step, cross_entropy, bottleneck_input, + ground_truth_input, final_tensor) = add_final_retrain_ops( + class_count, FLAGS.final_tensor_name, bottleneck_tensor, + model_info['bottleneck_tensor_size'], model_info['quantize_layer'], + True) + with tf.Session(graph=graph) as sess: # Set up the image decoding sub-graph. jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding( @@ -1110,15 +1186,8 @@ def main(_): decoded_image_tensor, resized_image_tensor, bottleneck_tensor, FLAGS.architecture) - # Add the new layer that we'll be training. - (train_step, cross_entropy, bottleneck_input, ground_truth_input, - final_tensor) = add_final_training_ops( - len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor, - model_info['bottleneck_tensor_size'], model_info['quantize_layer']) - # Create the operations we need to evaluate the accuracy of our new layer. - evaluation_step, prediction = add_evaluation_step( - final_tensor, ground_truth_input) + evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input) # Merge all the summaries and write them out to the summaries_dir merged = tf.summary.merge_all() @@ -1128,6 +1197,10 @@ def main(_): validation_writer = tf.summary.FileWriter( FLAGS.summaries_dir + '/validation') + # Create a train saver that is used to restore values into an eval graph + # when exporting models. + train_saver = tf.train.Saver() + # Set up all our weights to their initial default values. init = tf.global_variables_initializer() sess.run(init) @@ -1168,6 +1241,9 @@ def main(_): (datetime.now(), i, train_accuracy * 100)) tf.logging.info('%s: Step %d: Cross entropy = %f' % (datetime.now(), i, cross_entropy_value)) + # TODO(suharshs): Make this use an eval graph, to avoid quantization + # moving averages being updated by the validation set, though in + # practice this makes a negligable difference. validation_bottlenecks, validation_ground_truth, _ = ( get_random_cached_bottlenecks( sess, image_lists, FLAGS.validation_batch_size, 'validation', @@ -1190,42 +1266,32 @@ def main(_): if (intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0): + # If we want to do an intermediate save, save a checkpoint of the train + # graph, to restore into the eval graph. + train_saver.save(sess, CHECKPOINT_NAME) intermediate_file_name = (FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb') tf.logging.info('Save intermediate result to : ' + intermediate_file_name) - save_graph_to_file(sess, graph, intermediate_file_name) + save_graph_to_file(graph, intermediate_file_name, model_info, + class_count) + + # After training is complete, force one last save of the train checkpoint. + train_saver.save(sess, CHECKPOINT_NAME) # We've completed all our training, so run a final test evaluation on # some new images we haven't used before. - test_bottlenecks, test_ground_truth, test_filenames = ( - get_random_cached_bottlenecks( - sess, image_lists, FLAGS.test_batch_size, 'testing', - FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, - decoded_image_tensor, resized_image_tensor, bottleneck_tensor, - FLAGS.architecture)) - test_accuracy, predictions = sess.run( - [evaluation_step, prediction], - feed_dict={bottleneck_input: test_bottlenecks, - ground_truth_input: test_ground_truth}) - tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % - (test_accuracy * 100, len(test_bottlenecks))) - - if FLAGS.print_misclassified_test_images: - tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===') - for i, test_filename in enumerate(test_filenames): - if predictions[i] != test_ground_truth[i]: - tf.logging.info('%70s %s' % - (test_filename, - list(image_lists.keys())[predictions[i]])) + run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, + bottleneck_tensor) # Write out the trained graph and labels with the weights stored as # constants. - save_graph_to_file(sess, graph, FLAGS.output_graph) + save_graph_to_file(graph, FLAGS.output_graph, model_info, class_count) with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') - export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir) + export_model(model_info, class_count, FLAGS.saved_model_dir) if __name__ == '__main__': @@ -1406,8 +1472,9 @@ if __name__ == '__main__': form 'mobilenet__[_quantized]'. For example, 'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224 pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much - less accurate, but smaller and faster network that's 920 KB on disk and - takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html + smaller and less accurate model, taking 128x128 images, and instrumented + for eventual quantization via TensorFlow Lite. + See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html for more information on Mobilenet.\ """) parser.add_argument( diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index 8b8dd45fd72e3d29bdb7f6291cc53b912adf3644..fb7324c58ac1be60baad840207f31a61ec6182be 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -67,22 +67,52 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0')) @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01) - def testAddFinalTrainingOps(self, flags_mock): + def testAddFinalRetrainOps(self, flags_mock): with tf.Graph().as_default(): with tf.Session() as sess: bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') - # Test creating final training op with quantization - retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False) + # Test creating final training op with quantization. + retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, False, + False) self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01) - def testAddFinalTrainingOpsQuantized(self, flags_mock): - with tf.Graph().as_default(): + def testAddFinalRetrainOpsQuantized(self, flags_mock): + # Ensure that the training and eval graph for quantized models are correctly + # created. + with tf.Graph().as_default() as g: + with tf.Session() as sess: + bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') + # Test creating final training op with quantization, set is_training to + # true. + retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, True) + self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) + found_fake_quant = 0 + for op in g.get_operations(): + if op.type == 'FakeQuantWithMinMaxVars': + found_fake_quant += 1 + # Ensure that the inputs of each FakeQuant operations has 2 Assign + # operations in the training graph (Assign[Min,Max]Last, + # Assign[Min,Max]Ema) + self.assertEqual(2, + len([i for i in op.inputs if 'Assign' in i.name])) + self.assertEqual(found_fake_quant, 2) + with tf.Graph().as_default() as g: with tf.Session() as sess: bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') - # Test creating final training op with quantization - retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True) + # Test creating final training op with quantization, set is_training to + # false. + retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, False) self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) + found_fake_quant = 0 + for op in g.get_operations(): + if op.type == 'FakeQuantWithMinMaxVars': + found_fake_quant += 1 + for i in op.inputs: + # Ensure that no operations are Assign operation since this is the + # evaluation graph. + self.assertTrue('Assign' not in i.name) + self.assertEqual(found_fake_quant, 2) def testAddEvaluationStep(self): with tf.Graph().as_default(): diff --git a/tensorflow/examples/ios/README.md b/tensorflow/examples/ios/README.md index 5bdaeb43ce143e36e78cfe301fd9b59e8b85b034..5d7bd36837b2a2c33ab4bc311a582c174666dcd5 100644 --- a/tensorflow/examples/ios/README.md +++ b/tensorflow/examples/ios/README.md @@ -119,11 +119,13 @@ rundown: `tensorflow/contrib/makefile/gen/lib` to the Library Search Paths setting. - You'll also need to add `libprotobuf.a` and `libprotobuf-lite.a` from - `tensorflow/contrib/makefile/gen/protobuf_ios/lib` to your _Build Stages_ and - _Library Search Paths_. + `tensorflow/contrib/makefile/gen/protobuf_ios/lib` + and `nsync.a` from `tensorflow/contrib/makefile/downloads/nsync/builds/lipo.ios.c++11` + to your _Build Stages_ and _Library Search Paths_. - The _Header Search_ paths needs to contain: - the root folder of tensorflow, + - `tensorflow/contrib/makefile/downloads/nsync/public` - `tensorflow/contrib/makefile/downloads/protobuf/src` - `tensorflow/contrib/makefile/downloads`, - `tensorflow/contrib/makefile/downloads/eigen`, and diff --git a/tensorflow/examples/learn/mnist.py b/tensorflow/examples/learn/mnist.py index 98819b20bfea5021d52e2c50b004bccdaf1f25e7..3ead8614b68959b95ccad43623d4df4a5c4665bd 100644 --- a/tensorflow/examples/learn/mnist.py +++ b/tensorflow/examples/learn/mnist.py @@ -61,8 +61,10 @@ def conv_model(features, labels, mode): # Densely connected layer with 1024 neurons. h_fc1 = tf.layers.dense(h_pool2_flat, 1024, activation=tf.nn.relu) - if mode == tf.estimator.ModeKeys.TRAIN: - h_fc1 = tf.layers.dropout(h_fc1, rate=0.5) + h_fc1 = tf.layers.dropout( + h_fc1, + rate=0.5, + training=(mode == tf.estimator.ModeKeys.TRAIN)) # Compute logits (1 per class) and compute loss. logits = tf.layers.dense(h_fc1, N_DIGITS, activation=None) diff --git a/tensorflow/examples/learn/resnet.py b/tensorflow/examples/learn/resnet.py index 9542e552504580a6614f8bd2f43c38dfa795750f..c00de932a8707ad5717aaf1251cf5c88464a28b0 100755 --- a/tensorflow/examples/learn/resnet.py +++ b/tensorflow/examples/learn/resnet.py @@ -53,6 +53,8 @@ def res_net_model(features, labels, mode): ndim = int(sqrt(input_shape[1])) x = tf.reshape(x, [-1, ndim, ndim, 1]) + training = (mode == tf.estimator.ModeKeys.TRAIN) + # First convolution expands to 64 channels with tf.variable_scope('conv_layer1'): net = tf.layers.conv2d( @@ -60,7 +62,7 @@ def res_net_model(features, labels, mode): filters=64, kernel_size=7, activation=tf.nn.relu) - net = tf.layers.batch_normalization(net) + net = tf.layers.batch_normalization(net, training=training) # Max pool net = tf.layers.max_pooling2d( @@ -88,7 +90,7 @@ def res_net_model(features, labels, mode): kernel_size=1, padding='valid', activation=tf.nn.relu) - conv = tf.layers.batch_normalization(conv) + conv = tf.layers.batch_normalization(conv, training=training) with tf.variable_scope(name + '/conv_bottleneck'): conv = tf.layers.conv2d( @@ -97,7 +99,7 @@ def res_net_model(features, labels, mode): kernel_size=3, padding='same', activation=tf.nn.relu) - conv = tf.layers.batch_normalization(conv) + conv = tf.layers.batch_normalization(conv, training=training) # 1x1 convolution responsible for restoring dimension with tf.variable_scope(name + '/conv_out'): @@ -108,7 +110,7 @@ def res_net_model(features, labels, mode): kernel_size=1, padding='valid', activation=tf.nn.relu) - conv = tf.layers.batch_normalization(conv) + conv = tf.layers.batch_normalization(conv, training=training) # shortcut connections that turn the network into its counterpart # residual function (identity shortcut) @@ -154,7 +156,7 @@ def res_net_model(features, labels, mode): loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) # Create training op. - if mode == tf.estimator.ModeKeys.TRAIN: + if training: optimizer = tf.train.AdagradOptimizer(learning_rate=0.01) train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) diff --git a/tensorflow/examples/tutorials/word2vec/BUILD b/tensorflow/examples/tutorials/word2vec/BUILD index 42d6355b4f06258a3c22d0ef324bb31880f2d9a3..bfcf4592690a1692db67090c9b6d4e1e4832c45f 100644 --- a/tensorflow/examples/tutorials/word2vec/BUILD +++ b/tensorflow/examples/tutorials/word2vec/BUILD @@ -13,6 +13,9 @@ py_binary( "word2vec_basic.py", ], srcs_version = "PY2AND3", + tags = [ + "no-internal-py3", + ], deps = [ "//tensorflow:tensorflow_py", "//third_party/py/numpy", diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/experimental_api.py similarity index 54% rename from tensorflow/contrib/bayesflow/python/ops/optimizers.py rename to tensorflow/experimental_api.py index fb70628d1083836281e9327e83e109493276c64f..63a8aa9cb1dc130a7999c3b248815633998c4cd0 100644 --- a/tensorflow/contrib/bayesflow/python/ops/optimizers.py +++ b/tensorflow/experimental_api.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,25 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Probabilistic optimizer modules. -See ${python/contrib.bayesflow.optimizers}. -""" +# Bring in all of the public TensorFlow interface into this +# module. from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * -from tensorflow.contrib.bayesflow.python.ops.variational_sgd_optimizer import * +from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin # pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ - 'SGLDOptimizer', - 'VariationalSGDOptimizer', -] +from tensorflow.python.util.lazy_loader import LazyLoader +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader -remove_undocumented(__name__, _allowed_symbols) +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function diff --git a/tensorflow/go/genop/internal/api_def_map.go b/tensorflow/go/genop/internal/api_def_map.go index 07b689dbba23a3aa991983f3b373fa8445c673e1..8600452b476dee49292cbffe630026cf6077e22b 100644 --- a/tensorflow/go/genop/internal/api_def_map.go +++ b/tensorflow/go/genop/internal/api_def_map.go @@ -31,7 +31,7 @@ import ( "unsafe" "github.com/golang/protobuf/proto" - pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework" + pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework_go_proto" ) // Encapsulates a collection of API definitions. diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go index 82f7510f2ed947e0a87e4d88cfce1ecaaa6362f8..fb8163121850cee36e1fcc652ca258b1fe2d42ff 100644 --- a/tensorflow/go/genop/internal/genop.go +++ b/tensorflow/go/genop/internal/genop.go @@ -47,7 +47,7 @@ import ( "unsafe" "github.com/golang/protobuf/proto" - pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework" + pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework_go_proto" ) // GenerateFunctionsForRegisteredOps writes a Go source code file to w @@ -359,13 +359,13 @@ type attrWrapper struct { api *pb.ApiDef_Attr } -func (a *attrWrapper) Name() string { return a.api.Name } -func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } -func (a *attrWrapper) Description() string { return a.api.Description } -func (a *attrWrapper) Type() string { return a.op.Type } -func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } -func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } -func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } +func (a *attrWrapper) Name() string { return a.api.Name } +func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } +func (a *attrWrapper) Description() string { return a.api.Description } +func (a *attrWrapper) Type() string { return a.op.Type } +func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } +func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } +func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue } type argWrapper struct { diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go index b3a23dff102a690b1f7f08b675219929355f139f..d20d22e0c1502f92ade7ef5aa40985dce73b7552 100644 --- a/tensorflow/go/genop/internal/genop_test.go +++ b/tensorflow/go/genop/internal/genop_test.go @@ -22,7 +22,7 @@ import ( "testing" "github.com/golang/protobuf/proto" - pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework" + pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework_go_proto" ) // Creates an ApiDef based on opdef and applies overrides diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d9e684a661f2690c9352baec0649fbf42fc79255..838f4f230193b871dfd62b5c19943e2f9fa0fc89 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -38,188 +38,6 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in return list, start + size, nil } -// WriteImageSummaryAttr is an optional argument to WriteImageSummary. -type WriteImageSummaryAttr func(optionalAttr) - -// WriteImageSummaryMaxImages sets the optional max_images attribute to value. -// -// value: Max number of batch elements to generate images for. -// If not specified, defaults to 3 -// -// REQUIRES: value >= 1 -func WriteImageSummaryMaxImages(value int64) WriteImageSummaryAttr { - return func(m optionalAttr) { - m["max_images"] = value - } -} - -// Writes a `Summary` protocol buffer with images. -// -// The summary has up to `max_images` summary values containing images. The -// images are built from `tensor` which must be 4-D with shape `[batch_size, -// height, width, channels]` and where `channels` can be: -// -// * 1: `tensor` is interpreted as Grayscale. -// * 3: `tensor` is interpreted as RGB. -// * 4: `tensor` is interpreted as RGBA. -// -// The images have the same number of channels as the input tensor. For float -// input, the values are normalized one image at a time to fit in the range -// `[0, 255]`. `uint8` values are unchanged. The op uses two different -// normalization algorithms: -// -// * If the input values are all positive, they are rescaled so the largest one -// is 255. -// -// * If any input value is negative, the values are shifted so input value 0.0 -// is at 127. They are then rescaled so that either the smallest value is 0, -// or the largest one is 255. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: -// -// * If `max_images` is 1, the summary value tag is '*tag*/image'. -// * If `max_images` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. -// -// The `bad_color` argument is the color to use in the generated images for -// non-finite input values. It is a `unit8` 1-D tensor of length `channels`. -// Each element must be in the range `[0, 255]` (It represents the value of a -// pixel in the output image). Non-finite values in the input tensor are -// replaced by this tensor in the output image. The default value is the color -// red. -// -// Arguments: -// writer: A handle to a summary writer. -// step: The step to write the summary for. -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 4-D of shape `[batch_size, height, width, channels]` where -// `channels` is 1, 3, or 4. -// bad_color: Color to use for pixels with non-finite values. -// -// Returns the created operation. -func WriteImageSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Output, tensor tf.Output, bad_color tf.Output, optional ...WriteImageSummaryAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "WriteImageSummary", - Input: []tf.Input{ - writer, step, tag, tensor, bad_color, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Outputs a `tf.Event` protocol buffer. -// -// When CreateSummaryDbWriter is being used, this op can be useful for -// importing data from event logs. -// -// Arguments: -// writer: A handle to a summary writer. -// event: A string containing a binary-encoded tf.Event proto. -// -// Returns the created operation. -func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ImportEvent", - Input: []tf.Input{ - writer, event, - }, - } - return scope.AddOperation(opspec) -} - -// Outputs a `Summary` protocol buffer with a tensor. -// -// Arguments: -// writer: A handle to a summary writer. -// step: The step to write the summary for. -// tensor: A tensor to serialize. -// tag: The summary's tag. -// summary_metadata: Serialized SummaryMetadata protocol buffer containing -// plugin-related metadata for this summary. -// -// Returns the created operation. -func WriteSummary(scope *Scope, writer tf.Output, step tf.Output, tensor tf.Output, tag tf.Output, summary_metadata tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WriteSummary", - Input: []tf.Input{ - writer, step, tensor, tag, summary_metadata, - }, - } - return scope.AddOperation(opspec) -} - -// Creates summary database writer accessible by given resource handle. -// -// This can be used to write tensors from the execution graph directly -// to a database. Only SQLite is supported right now. This function -// will create the schema if it doesn't exist. Entries in the Users, -// Experiments, and Runs tables will be created automatically if they -// don't already exist. -// -// Arguments: -// writer: Handle to SummaryWriter resource to overwrite. -// db_uri: For example "file:/tmp/foo.sqlite". -// experiment_name: Can't contain ASCII control characters or <>. Case -// sensitive. If empty, then the Run will not be associated with any -// Experiment. -// run_name: Can't contain ASCII control characters or <>. Case sensitive. -// If empty, then each Tag will not be associated with any Run. -// user_name: Must be valid as both a DNS label and Linux username. If -// empty, then the Experiment will not be associated with any User. -// -// Returns the created operation. -func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CreateSummaryDbWriter", - Input: []tf.Input{ - writer, db_uri, experiment_name, run_name, user_name, - }, - } - return scope.AddOperation(opspec) -} - -// Creates a summary file writer accessible by the given resource handle. -// -// Arguments: -// writer: A handle to the summary writer resource -// logdir: Directory where the event file will be written. -// max_queue: Size of the queue of pending events and summaries. -// flush_millis: How often, in milliseconds, to flush the pending events and -// summaries to disk. -// filename_suffix: Every event file's name is suffixed with this suffix. -// -// Returns the created operation. -func CreateSummaryFileWriter(scope *Scope, writer tf.Output, logdir tf.Output, max_queue tf.Output, flush_millis tf.Output, filename_suffix tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CreateSummaryFileWriter", - Input: []tf.Input{ - writer, logdir, max_queue, flush_millis, filename_suffix, - }, - } - return scope.AddOperation(opspec) -} - // FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient. type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr) @@ -384,105 +202,113 @@ func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs t return op.Output(0), op.Output(1), op.Output(2) } -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) +// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. +type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { +// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. +// If not specified, defaults to -6 +func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { return func(m optionalAttr) { - m["container"] = value + m["min"] = value } } -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { +// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. +// If not specified, defaults to 6 +func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["max"] = value } } -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { +// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { return func(m optionalAttr) { - m["use_node_name_sharing"] = value + m["num_bits"] = value } } -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. -// If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { return func(m optionalAttr) { - m["value_shape"] = value + m["narrow_range"] = value } } -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// Compute gradients for a FakeQuantWithMinMaxArgs operation. // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. // -// Returns Handle to a table. -func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { +// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: +// `gradients * (inputs >= min && inputs <= max)`. +func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - + Type: "FakeQuantWithMinMaxArgsGradient", + Input: []tf.Input{ + gradients, inputs, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) +// FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs. +type FakeQuantWithMinMaxArgsAttr func(optionalAttr) -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// FakeQuantWithMinMaxArgsMin sets the optional min attribute to value. +// If not specified, defaults to -6 +func FakeQuantWithMinMaxArgsMin(value float32) FakeQuantWithMinMaxArgsAttr { + return func(m optionalAttr) { + m["min"] = value + } +} + +// FakeQuantWithMinMaxArgsMax sets the optional max attribute to value. +// If not specified, defaults to 6 +func FakeQuantWithMinMaxArgsMax(value float32) FakeQuantWithMinMaxArgsAttr { + return func(m optionalAttr) { + m["max"] = value + } +} + +// FakeQuantWithMinMaxArgsNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxArgsNumBits(value int64) FakeQuantWithMinMaxArgsAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxArgsNarrowRange sets the optional narrow_range attribute to value. // If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { +func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["narrow_range"] = value } } -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. -// -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. +// Attributes `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // -// Returns the created operation. -func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { +// Quantization is called fake since the output is still in floating point. +func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) { if scope.Err() != nil { return } @@ -491,377 +317,705 @@ func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", + Type: "FakeQuantWithMinMaxArgs", Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, + inputs, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. -type MutableHashTableV2Attr func(optionalAttr) - -// MutableHashTableV2Container sets the optional container attribute to value. +// Scatter `updates` into a new (initially zero) tensor according to `indices`. // -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableV2Container(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// Creates a new tensor by applying sparse `updates` to individual +// values or slices within a zero tensor of the given `shape` according to +// indices. This operator is the inverse of the @{tf.gather_nd} operator which +// extracts values or slices from a given tensor. // -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// **WARNING**: The order in which updates are applied is nondeterministic, so the +// output will be nondeterministic if `indices` contains duplicates. // -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates an empty hash table. +// `indices` is an integer tensor containing indices into a new tensor of shape +// `shape`. The last dimension of `indices` can be at most the rank of `shape`: // -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// indices.shape[-1] <= shape.rank +// +// The last dimension of `indices` corresponds to indices into elements +// (if `indices.shape[-1] = shape.rank`) or slices +// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +// `shape`. `updates` is a tensor with shape +// +// indices.shape[:-1] + shape[indices.shape[-1]:] +// +// The simplest form of scatter is to insert individual elements in a tensor by +// index. For example, say we want to insert 4 scattered elements in a rank-1 +// tensor with 8 elements. +// +//
+// +//
+// +// In Python, this scatter operation would look like this: +// +// ```python +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// shape = tf.constant([8]) +// scatter = tf.scatter_nd(indices, updates, shape) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [0, 11, 0, 10, 9, 0, 0, 12] +// +// We can also, insert entire slices of a higher rank tensor all at once. For +// example, if we wanted to insert two slices in the first dimension of a +// rank-3 tensor with two matrices of new values. +// +//
+// +//
+// +// In Python, this scatter operation would look like this: +// +// ```python +// indices = tf.constant([[0], [2]]) +// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]], +// [[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]]]) +// shape = tf.constant([4, 4, 4]) +// scatter = tf.scatter_nd(indices, updates, shape) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], +// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], +// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], +// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// indices: Index tensor. +// updates: Updates to scatter into output. +// shape: 1-D. The shape of the resulting tensor. // -// Returns Handle to a table. -func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { +// Returns A new tensor with the given shape and updates applied according +// to the indices. +func ScatterNd(scope *Scope, indices tf.Output, updates tf.Output, shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MutableHashTableV2", - - Attrs: attrs, + Type: "ScatterNd", + Input: []tf.Input{ + indices, updates, shape, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. -type MapUnstageNoKeyAttr func(optionalAttr) +// QuantizeAndDequantizeV2Attr is an optional argument to QuantizeAndDequantizeV2. +type QuantizeAndDequantizeV2Attr func(optionalAttr) -// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// QuantizeAndDequantizeV2SignedInput sets the optional signed_input attribute to value. // -// REQUIRES: value >= 0 -func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { +// value: If the quantization is signed or unsigned. +// If not specified, defaults to true +func QuantizeAndDequantizeV2SignedInput(value bool) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["signed_input"] = value } } -// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// QuantizeAndDequantizeV2NumBits sets the optional num_bits attribute to value. // -// REQUIRES: value >= 0 -func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapUnstageNoKeyContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { +// value: The bitwidth of the quantization. +// If not specified, defaults to 8 +func QuantizeAndDequantizeV2NumBits(value int64) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { - m["container"] = value + m["num_bits"] = value } } -// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { +// QuantizeAndDequantizeV2RangeGiven sets the optional range_given attribute to value. +// +// value: If the range is given or should be computed from the tensor. +// If not specified, defaults to false +func QuantizeAndDequantizeV2RangeGiven(value bool) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["range_given"] = value } } -// Op removes and returns a random (key, value) +// Quantizes then dequantizes a tensor. // -// from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { +// This op simulates the precision loss from the quantized forward pass by: +// 1. Quantizing the tensor to fixed point numbers, which should match the target +// quantization method when it is used in inference. +// 2. Dequantizing it back to floating point numbers for the following ops, most +// likely matmul. +// +// There are different ways to quantize. This version does not use the full range +// of the output type, choosing to elide the lowest possible value for symmetry +// (e.g., output range is -127 to 127, not -128 to 127 for signed 8 bit +// quantization), so that 0.0 maps to 0. +// +// To perform this op, we first find the range of values in our tensor. The range +// we use is always centered on 0, so we find m such that +// +// 1. m = max(abs(input_min), abs(input_max)) if range_given is true, +// 2. m = max(abs(min_elem(input)), abs(max_elem(input))) otherwise. +// +// Our input tensor range is then [-m, m]. +// +// Next, we choose our fixed-point quantization buckets, [min_fixed, max_fixed]. +// If signed_input is true, this is +// +// [min_fixed, max_fixed ] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]. +// +// Otherwise, if signed_input is false, the fixed-point range is +// +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. +// +// From this we compute our scaling factor, s: +// +// s = (max_fixed - min_fixed) / (2 * m). +// +// Now we can quantize and dequantize the elements of our tensor. An element e +// is transformed into e': +// +// e' = (e * s).round_to_nearest() / s. +// +// Note that we have a different number of buckets in the signed vs. unsigned +// cases. For example, if num_bits == 8, we get 254 buckets in the signed case +// vs. 255 in the unsigned case. +// +// For example, suppose num_bits = 8 and m = 1. Then +// +// [min_fixed, max_fixed] = [-127, 127], and +// s = (127 + 127) / 2 = 127. +// +// Given the vector {-1, -0.5, 0, 0.3}, this is quantized to +// {-127, -63, 0, 38}, and dequantized to {-1, -63.0/127, 0, 38.0/127}. +// +// Arguments: +// input: Tensor to quantize and then dequantize. +// input_min: If range_given, this is the min of the range, otherwise this input +// will be ignored. +// input_max: If range_given, this is the max of the range, otherwise this input +// will be ignored. +func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, optional ...QuantizeAndDequantizeV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapUnstageNoKey", + Type: "QuantizeAndDequantizeV2", Input: []tf.Input{ - indices, + input, input_min, input_max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - key = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstageNoKey", err) - return - } - return key, values + return op.Output(0) } -// HashTableV2Attr is an optional argument to HashTableV2. -type HashTableV2Attr func(optionalAttr) - -// HashTableV2Container sets the optional container attribute to value. +// Bitcasts a tensor from one type to another without copying data. // -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func HashTableV2Container(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// HashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func HashTableV2SharedName(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates a non-initialized hash table. +// Given a tensor `input`, this operation returns a tensor that has the same buffer +// data as `input` with datatype `type`. // -// This op creates a hash table, specifying the type of its keys and values. -// Before using the table you will have to initialize it. After initialization the -// table will be immutable. +// If the input datatype `T` is larger than the output datatype `type` then the +// shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. // -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. +// If `T` is smaller than `type`, the operator requires that the rightmost +// dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from +// [..., sizeof(`type`)/sizeof(`T`)] to [...]. // -// Returns Handle to a table. -func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { +// *NOTE*: Bitcast is implemented as a low-level cast, so machines with different +// endian orderings will give different results. +func Bitcast(scope *Scope, input tf.Output, type_ tf.DataType) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"type": type_} opspec := tf.OpSpec{ - Type: "HashTableV2", - + Type: "Bitcast", + Input: []tf.Input{ + input, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Replaces the contents of the table with the specified keys and values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. +// Extract `patches` from `images` and put them in the "depth" output dimension. // // Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. +// images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`. +// ksizes: The size of the sliding window for each dimension of `images`. +// strides: 1-D of length 4. How far the centers of two consecutive patches are in +// the images. Must be: `[1, stride_rows, stride_cols, 1]`. +// rates: 1-D of length 4. Must be: `[1, rate_rows, rate_cols, 1]`. This is the +// input stride, specifying how far two consecutive patch samples are in the +// input. Equivalent to extracting patches with +// `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by +// subsampling them spatially by a factor of `rates`. This is equivalent to +// `rate` in dilated (a.k.a. Atrous) convolutions. +// padding: The type of padding algorithm to use. // -// Returns the created operation. -func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// We specify the size-related attributes as: +// +// ```python +// ksizes = [1, ksize_rows, ksize_cols, 1] +// strides = [1, strides_rows, strides_cols, 1] +// rates = [1, rates_rows, rates_cols, 1] +// ``` +// +// Returns 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows * +// ksize_cols * depth]` containing image patches with size +// `ksize_rows x ksize_cols x depth` vectorized in the "depth" dimension. Note +// `out_rows` and `out_cols` are the dimensions of the output patches. +func ExtractImagePatches(scope *Scope, images tf.Output, ksizes []int64, strides []int64, rates []int64, padding string) (patches tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "LookupTableImportV2", + Type: "ExtractImagePatches", Input: []tf.Input{ - table_handle, keys, values, + images, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapPeekAttr is an optional argument to MapPeek. -type MapPeekAttr func(optionalAttr) +// SpaceToDepthAttr is an optional argument to SpaceToDepth. +type SpaceToDepthAttr func(optionalAttr) -// MapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekCapacity(value int64) MapPeekAttr { +// SpaceToDepthDataFormat sets the optional data_format attribute to value. +// If not specified, defaults to "NHWC" +func SpaceToDepthDataFormat(value string) SpaceToDepthAttr { return func(m optionalAttr) { - m["capacity"] = value + m["data_format"] = value } } -// MapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// SpaceToDepth for tensors of type T. // -// REQUIRES: value >= 0 -func MapPeekMemoryLimit(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapPeekContainer(value string) MapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapPeekSharedName(value string) MapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the +// Rearranges blocks of spatial data, into depth. More specifically, +// this op outputs a copy of the input tensor where values from the `height` +// and `width` dimensions are moved to the `depth` dimension. +// The attr `block_size` indicates the input block size. // -// underlying container does not contain this key -// this op will block until it does. -func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { +// * Non-overlapping blocks of size `block_size x block size` are rearranged +// into depth at each location. +// * The depth of the output tensor is `block_size * block_size * input_depth`. +// * The Y, X coordinates within each block of the input become the high order +// component of the output channel index. +// * The input tensor's height and width must be divisible by block_size. +// +// The `data_format` attr specifies the layout of the input and output tensors +// with the following options: +// "NHWC": `[ batch, height, width, channels ]` +// "NCHW": `[ batch, channels, height, width ]` +// "NCHW_VECT_C": +// `qint8 [ batch, channels / 4, height, width, 4 ]` +// +// It is useful to consider the operation as transforming a 6-D Tensor. +// e.g. for data_format = NHWC, +// Each element in the input tensor can be specified via 6 coordinates, +// ordered by decreasing memory layout significance as: +// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates +// within the output image, bX, bY means coordinates +// within the input block, iC means input channels). +// The output would be a transpose to the following layout: +// n,oY,oX,bY,bX,iC +// +// This operation is useful for resizing the activations between convolutions +// (but keeping all data), e.g. instead of pooling. It is also useful for training +// purely convolutional models. +// +// For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and +// block_size = 2: +// +// ``` +// x = [[[[1], [2]], +// [[3], [4]]]] +// ``` +// +// This operation will output a tensor of shape `[1, 1, 1, 4]`: +// +// ``` +// [[[[1, 2, 3, 4]]]] +// ``` +// +// Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, +// the corresponding output will have a single element (i.e. width and height are +// both 1) and will have a depth of 4 channels (1 * block_size * block_size). +// The output element shape is `[1, 1, 4]`. +// +// For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// This operation, for block_size of 2, will return the following tensor of shape +// `[1, 1, 1, 12]` +// +// ``` +// [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] +// ``` +// +// Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: +// +// ``` +// x = [[[[1], [2], [5], [6]], +// [[3], [4], [7], [8]], +// [[9], [10], [13], [14]], +// [[11], [12], [15], [16]]]] +// ``` +// +// the operator will return the following tensor of shape `[1 2 2 4]`: +// +// ``` +// x = [[[[1, 2, 3, 4], +// [5, 6, 7, 8]], +// [[9, 10, 11, 12], +// [13, 14, 15, 16]]]] +// ``` +// +// Arguments: +// +// block_size: The size of the spatial block. +func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...SpaceToDepthAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"block_size": block_size} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapPeek", + Type: "SpaceToDepth", Input: []tf.Input{ - key, indices, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapPeek", err) - return - } - return values + return op.Output(0) } -// Returns (x - y)(x - y) element-wise. +// SpaceToBatch for 4-D tensors of type T. // -// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// This is a legacy version of the more general SpaceToBatchND. +// +// Zero-pads and then rearranges (permutes) blocks of spatial data into batch. +// More specifically, this op outputs a copy of the input tensor where values from +// the `height` and `width` dimensions are moved to the `batch` dimension. After +// the zero-padding, both `height` and `width` of the input must be divisible by the +// block size. +// +// Arguments: +// input: 4-D with shape `[batch, height, width, depth]`. +// paddings: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies +// the padding of the input with zeros across the spatial dimensions as follows: +// +// paddings = [[pad_top, pad_bottom], [pad_left, pad_right]] +// +// The effective spatial dimensions of the zero-padded input tensor will be: +// +// height_pad = pad_top + height + pad_bottom +// width_pad = pad_left + width + pad_right +// +// The attr `block_size` must be greater than one. It indicates the block size. +// +// * Non-overlapping blocks of size `block_size x block size` in the height and +// width dimensions are rearranged into the batch dimension at each location. +// * The batch of the output tensor is `batch * block_size * block_size`. +// * Both height_pad and width_pad must be divisible by block_size. +// +// The shape of the output will be: +// +// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, +// depth] +// +// Some examples: +// +// (1) For the following input of shape `[1, 2, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// The output tensor has shape `[4, 1, 1, 1]` and value: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// (2) For the following input of shape `[1, 2, 2, 3]` and block_size of 2: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// The output tensor has shape `[4, 1, 1, 3]` and value: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// (3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// The output tensor has shape `[4, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// (4) For the following input of shape `[2, 2, 4, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]]], +// [[[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// The output tensor has shape `[8, 1, 2, 1]` and value: +// +// ``` +// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], +// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] +// ``` +// +// Among others, this operation is useful for reducing atrous convolution into +// regular convolution. +// +func SpaceToBatch(scope *Scope, input tf.Output, paddings tf.Output, block_size int64) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"block_size": block_size} opspec := tf.OpSpec{ - Type: "SquaredDifference", + Type: "SpaceToBatch", Input: []tf.Input{ - x, y, + input, paddings, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Forwards the input to the output. +// SpaceToBatch for N-D tensors of type T. // -// This operator represents the loop termination condition used by the -// "pivot" switches of a loop. +// This operation divides "spatial" dimensions `[1, ..., M]` of the input into a +// grid of blocks of shape `block_shape`, and interleaves these blocks with the +// "batch" dimension (0) such that in the output, the spatial dimensions +// `[1, ..., M]` correspond to the position within the grid, and the batch +// dimension combines both the position within a spatial block and the original +// batch position. Prior to division into blocks, the spatial dimensions of the +// input are optionally zero padded according to `paddings`. See below for a +// precise description. // // Arguments: -// input: A boolean scalar, representing the branch predicate of the Switch op. +// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, +// where spatial_shape has `M` dimensions. +// block_shape: 1-D with shape `[M]`, all values must be >= 1. +// paddings: 2-D with shape `[M, 2]`, all values must be >= 0. +// `paddings[i] = [pad_start, pad_end]` specifies the padding for input dimension +// `i + 1`, which corresponds to spatial dimension `i`. It is required that +// `block_shape[i]` divides `input_shape[i + 1] + pad_start + pad_end`. // -// Returns The same tensor as `input`. -func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { +// This operation is equivalent to the following steps: +// +// 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the +// input according to `paddings` to produce `padded` of shape `padded_shape`. +// +// 2. Reshape `padded` to `reshaped_padded` of shape: +// +// [batch] + +// [padded_shape[1] / block_shape[0], +// block_shape[0], +// ..., +// padded_shape[M] / block_shape[M-1], +// block_shape[M-1]] + +// remaining_shape +// +// 3. Permute dimensions of `reshaped_padded` to produce +// `permuted_reshaped_padded` of shape: +// +// block_shape + +// [batch] + +// [padded_shape[1] / block_shape[0], +// ..., +// padded_shape[M] / block_shape[M-1]] + +// remaining_shape +// +// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the batch +// dimension, producing an output tensor of shape: +// +// [batch * prod(block_shape)] + +// [padded_shape[1] / block_shape[0], +// ..., +// padded_shape[M] / block_shape[M-1]] + +// remaining_shape +// +// Some examples: +// +// (1) For the following input of shape `[1, 2, 2, 1]`, `block_shape = [2, 2]`, and +// `paddings = [[0, 0], [0, 0]]`: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// The output tensor has shape `[4, 1, 1, 1]` and value: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// (2) For the following input of shape `[1, 2, 2, 3]`, `block_shape = [2, 2]`, and +// `paddings = [[0, 0], [0, 0]]`: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// The output tensor has shape `[4, 1, 1, 3]` and value: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// (3) For the following input of shape `[1, 4, 4, 1]`, `block_shape = [2, 2]`, and +// `paddings = [[0, 0], [0, 0]]`: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// The output tensor has shape `[4, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// (4) For the following input of shape `[2, 2, 4, 1]`, block_shape = `[2, 2]`, and +// paddings = `[[0, 0], [2, 0]]`: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]]], +// [[[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// The output tensor has shape `[8, 1, 3, 1]` and value: +// +// ``` +// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], +// [[[0], [2], [4]]], [[[0], [10], [12]]], +// [[[0], [5], [7]]], [[[0], [13], [15]]], +// [[[0], [6], [8]]], [[[0], [14], [16]]]] +// ``` +// +// Among others, this operation is useful for reducing atrous convolution into +// regular convolution. +func SpaceToBatchND(scope *Scope, input tf.Output, block_shape tf.Output, paddings tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LoopCond", + Type: "SpaceToBatchND", Input: []tf.Input{ - input, + input, block_shape, paddings, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedMulAttr is an optional argument to QuantizedMul. -type QuantizedMulAttr func(optionalAttr) +// ListDiffAttr is an optional argument to ListDiff. +type ListDiffAttr func(optionalAttr) -// QuantizedMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { +// ListDiffOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func ListDiffOutIdx(value tf.DataType) ListDiffAttr { return func(m optionalAttr) { - m["Toutput"] = value + m["out_idx"] = value } } -// Returns x * y element-wise, working on quantized buffers. +// Computes the difference between two lists of numbers or strings. // -// Arguments: +// Given a list `x` and a list `y`, this operation returns a list `out` that +// represents all values that are in `x` but not in `y`. The returned list `out` +// is sorted in the same order that the numbers appear in `x` (duplicates are +// preserved). This operation also returns a list `idx` that represents the +// position of each `out` element in `x`. In other words: // +// `out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` // -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. +// For example, given this input: // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// ``` +// x = [1, 2, 3, 4, 5, 6] +// y = [1, 3, 5] +// ``` // -// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { +// This operation would return: +// +// ``` +// out ==> [2, 4, 6] +// idx ==> [1, 3, 5] +// ``` +// +// Arguments: +// x: 1-D. Values to keep. +// y: 1-D. Values to remove. +// +// Returns 1-D. Values present in `x` but not in `y`.1-D. Positions of `x` values preserved in `out`. +func ListDiff(scope *Scope, x tf.Output, y tf.Output, optional ...ListDiffAttr) (out tf.Output, idx tf.Output) { if scope.Err() != nil { return } @@ -870,111 +1024,297 @@ func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedMul", + Type: "ListDiff", Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, + x, y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1) } -// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. -type QuantizedMatMulAttr func(optionalAttr) +// Inserts a dimension of 1 into a tensor's shape. +// +// Given a tensor `input`, this operation inserts a dimension of 1 at the +// dimension index `axis` of `input`'s shape. The dimension index `axis` starts at +// zero; if you specify a negative number for `axis` it is counted backward from +// the end. +// +// This operation is useful if you want to add a batch dimension to a single +// element. For example, if you have a single image of shape `[height, width, +// channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, +// which will make the shape `[1, height, width, channels]`. +// +// Other examples: +// +// ``` +// # 't' is a tensor of shape [2] +// shape(expand_dims(t, 0)) ==> [1, 2] +// shape(expand_dims(t, 1)) ==> [2, 1] +// shape(expand_dims(t, -1)) ==> [2, 1] +// +// # 't2' is a tensor of shape [2, 3, 5] +// shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] +// shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] +// shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] +// ``` +// +// This operation requires that: +// +// `-1-input.dims() <= dim <= input.dims()` +// +// This operation is related to `squeeze()`, which removes dimensions of +// size 1. +// +// Arguments: +// +// axis: 0-D (scalar). Specifies the dimension index at which to +// expand the shape of `input`. Must be in the range +// `[-rank(input) - 1, rank(input)]`. +// +// Returns Contains the same data as `input`, but its shape has an additional +// dimension of size 1 added. +func ExpandDims(scope *Scope, input tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExpandDims", + Input: []tf.Input{ + input, axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// QuantizedMatMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Toutput"] = value +// A placeholder op that passes through `input` when its output is not fed. +// +// Arguments: +// input: The default value to produce when `output` is not fed. +// shape: The (possibly partial) shape of the tensor. +// +// Returns A placeholder tensor that defaults to `input` if it is not fed. +func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "PlaceholderWithDefault", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// A placeholder op for a value that will be fed into the computation. // -// value: If true, `a` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value +// DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2. +// +// N.B. This operation will fail with an error if it is executed. It is +// intended as a way to represent a value that will always be fed, and to +// provide attrs that enable the fed value to be checked at runtime. +// +// Arguments: +// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. The shape can be any partially-specified +// shape. To be unconstrained, pass in a shape with unknown rank. +// +// Returns A placeholder tensor that must be replaced using the feed mechanism. +func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + opspec := tf.OpSpec{ + Type: "PlaceholderV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. +// PlaceholderAttr is an optional argument to Placeholder. +type PlaceholderAttr func(optionalAttr) + +// PlaceholderShape sets the optional shape attribute to value. // -// value: If true, `b` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { +// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the +// shape is unconstrained. +// If not specified, defaults to +func PlaceholderShape(value tf.Shape) PlaceholderAttr { return func(m optionalAttr) { - m["transpose_b"] = value + m["shape"] = value } } -// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. +// A placeholder op for a value that will be fed into the computation. +// +// N.B. This operation will fail with an error if it is executed. It is +// intended as a way to represent a value that will always be fed, and to +// provide attrs that enable the fed value to be checked at runtime. +// +// Arguments: +// dtype: The type of elements in the tensor. +// +// Returns A placeholder tensor that must be replaced using the feed mechanism. +func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Placeholder", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. +// +// This operation folds the padded areas of `input` by `MirrorPad` according to the +// `paddings` you specify. `paddings` must be the same as `paddings` argument +// given to the corresponding `MirrorPad` op. +// +// The folded size of each dimension D of the output is: +// +// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. +// # 'paddings' is [[0, 1]], [0, 1]]. +// # 'mode' is SYMMETRIC. +// # rank of 't' is 2. +// pad(t, paddings) ==> [[ 1, 5] +// [11, 28]] +// ``` +// +// Arguments: +// input: The input tensor to be folded. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// mode: The mode used in the `MirrorPad` op. +// +// Returns The folded tensor. +func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"mode": mode} + opspec := tf.OpSpec{ + Type: "MirrorPadGrad", + Input: []tf.Input{ + input, paddings, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Pads a tensor with mirrored values. +// +// This operation pads a `input` with mirrored values according to the `paddings` +// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is +// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many values to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many values to add after the contents of `input` +// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater +// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true +// (if false, respectively). // -// value: The type of output produced by activation function -// following this operation. -// If not specified, defaults to DT_QUINT8 -func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Tactivation"] = value - } -} - -// Perform a quantized matrix multiplication of `a` by the matrix `b`. +// The padded size of each dimension D of the output is: // -// The inputs must be two-dimensional matrices and the inner dimension of -// `a` (after being transposed if `transpose_a` is non-zero) must match the -// outer dimension of `b` (after being transposed if `transposed_b` is -// non-zero). +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 2, 3], [4, 5, 6]]. +// # 'paddings' is [[1, 1]], [2, 2]]. +// # 'mode' is SYMMETRIC. +// # rank of 't' is 2. +// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] +// [2, 1, 1, 2, 3, 3, 2] +// [5, 4, 4, 5, 6, 6, 5] +// [5, 4, 4, 5, 6, 6, 5]] +// ``` // // Arguments: -// a: Must be a two-dimensional tensor. -// b: Must be a two-dimensional tensor. -// min_a: The float value that the lowest quantized `a` value represents. -// max_a: The float value that the highest quantized `a` value represents. -// min_b: The float value that the lowest quantized `b` value represents. -// max_b: The float value that the highest quantized `b` value represents. +// input: The input tensor to be padded. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions +// do not include the borders, while in symmetric mode the padded regions +// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` +// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and +// it is `[1, 2, 3, 3, 2]` in symmetric mode. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { +// Returns The padded tensor. +func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"mode": mode} opspec := tf.OpSpec{ - Type: "QuantizedMatMul", + Type: "MirrorPad", Input: []tf.Input{ - a, b, min_a, max_a, min_b, max_b, + input, paddings, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// A placeholder op that passes through `input` when its output is not fed. +// Pads a tensor. // -// Arguments: -// input: The default value to produce when `output` is not fed. -// shape: The (possibly partial) shape of the tensor. +// This operation pads `input` according to the `paddings` and `constant_values` +// you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is +// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many padding values to add before the contents of `input` in that dimension, +// and `paddings[D, 1]` indicates how many padding values to add after the contents +// of `input` in that dimension. `constant_values` is a scalar tensor of the same +// type as `input` that indicates the value to use for padding `input`. // -// Returns A placeholder tensor that defaults to `input` if it is not fed. -func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # 'constant_values' is 0 +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "PlaceholderWithDefault", + Type: "PadV2", Input: []tf.Input{ - input, + input, paddings, constant_values, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -1398,6 +1738,21 @@ func Sinh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// Computes rectified linear 6: `min(max(features, 0), 6)`. +func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu6", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the sum along segments of a tensor. // // Read @{$math_ops#segmentation$the section on segmentation} for an explanation of @@ -1754,6 +2109,47 @@ func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true i return op.Output(0), op.Output(1), op.Output(2) } +// Returns (x - y)(x - y) element-wise. +// +// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SquaredDifference", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Forwards the input to the output. +// +// This operator represents the loop termination condition used by the +// "pivot" switches of a loop. +// +// Arguments: +// input: A boolean scalar, representing the branch predicate of the Switch op. +// +// Returns The same tensor as `input`. +func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LoopCond", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ApproximateEqualAttr is an optional argument to ApproximateEqual. type ApproximateEqualAttr func(optionalAttr) @@ -1821,59 +2217,6 @@ func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. -type SparseReduceSumSparseAttr func(optionalAttr) - -// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the sum of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a -// SparseTensor. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseReduceSumSparse", - Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // BiasAddAttr is an optional argument to BiasAdd. type BiasAddAttr func(optionalAttr) @@ -1922,36 +2265,41 @@ func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddA return op.Output(0) } -// BiasAddGradAttr is an optional argument to BiasAddGrad. -type BiasAddGradAttr func(optionalAttr) +// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. +type SparseReduceSumSparseAttr func(optionalAttr) -// BiasAddGradDataFormat sets the optional data_format attribute to value. +// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the bias tensor will be added to the last dimension -// of the value tensor. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// The tensor will be added to "in_channels", the third-to-the-last -// dimension. -// If not specified, defaults to "NHWC" -func BiasAddGradDataFormat(value string) BiasAddGradAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { return func(m optionalAttr) { - m["data_format"] = value + m["keep_dims"] = value } } -// The backward operation for "BiasAdd" on the "bias" tensor. +// Computes the sum of elements across dimensions of a SparseTensor. // -// It accumulates all the values from out_backprop into the feature dimension. -// For NHWC data format, the feature dimension is the last. For NCHW data format, -// the feature dimension is the third-to-last. +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a +// SparseTensor. // -// Arguments: -// out_backprop: Any number of dimensions. +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. // -// Returns 1-D with size the feature dimension of `out_backprop`. -func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } @@ -1960,14 +2308,14 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt a(attrs) } opspec := tf.OpSpec{ - Type: "BiasAddGrad", + Type: "SparseReduceSumSparse", Input: []tf.Input{ - out_backprop, + input_indices, input_values, input_shape, reduction_axes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } // Returns x + y element-wise. @@ -2130,50 +2478,6 @@ func Sign(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// QuantizedAddAttr is an optional argument to QuantizedAdd. -type QuantizedAddAttr func(optionalAttr) - -// QuantizedAddToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// Returns x + y element-wise, working on quantized buffers. -// -// Arguments: -// -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -// -// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedAdd", - Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // ArgMinAttr is an optional argument to ArgMin. type ArgMinAttr func(optionalAttr) @@ -2454,125 +2758,6 @@ func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) return op.Output(0) } -// Computes inverse hyperbolic cosine of x element-wise. -func Acosh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Acosh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SerializeManySparseAttr is an optional argument to SerializeManySparse. -type SerializeManySparseAttr func(optionalAttr) - -// SerializeManySparseOutType sets the optional out_type attribute to value. -// -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. -// -// The `SparseTensor` must have rank `R` greater than 1, and the first dimension -// is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The serialized -// `SparseTensor` objects going into each row of `serialized_sparse` will have -// rank `R-1`. -// -// The minibatch size `N` is extracted from `sparse_shape[0]`. -// -// Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. -func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SerializeManySparse", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorArrayV2Attr is an optional argument to TensorArrayV2. -type TensorArrayV2Attr func(optionalAttr) - -// TensorArrayV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayV2ElementShape(value tf.Shape) TensorArrayV2Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// TensorArrayV2DynamicSize sets the optional dynamic_size attribute to value. -// If not specified, defaults to false -func TensorArrayV2DynamicSize(value bool) TensorArrayV2Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV2ClearAfterRead sets the optional clear_after_read attribute to value. -// If not specified, defaults to true -func TensorArrayV2ClearAfterRead(value bool) TensorArrayV2Attr { - return func(m optionalAttr) { - m["clear_after_read"] = value - } -} - -// TensorArrayV2TensorArrayName sets the optional tensor_array_name attribute to value. -// If not specified, defaults to "" -func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr { - return func(m optionalAttr) { - m["tensor_array_name"] = value - } -} - -// Deprecated. Use TensorArrayV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayV3 -func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayV2", - Input: []tf.Input{ - size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the mean along sparse segments of a tensor. // // Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is @@ -2860,61 +3045,6 @@ func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Inserts a dimension of 1 into a tensor's shape. -// -// Given a tensor `input`, this operation inserts a dimension of 1 at the -// dimension index `axis` of `input`'s shape. The dimension index `axis` starts at -// zero; if you specify a negative number for `axis` it is counted backward from -// the end. -// -// This operation is useful if you want to add a batch dimension to a single -// element. For example, if you have a single image of shape `[height, width, -// channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, -// which will make the shape `[1, height, width, channels]`. -// -// Other examples: -// -// ``` -// # 't' is a tensor of shape [2] -// shape(expand_dims(t, 0)) ==> [1, 2] -// shape(expand_dims(t, 1)) ==> [2, 1] -// shape(expand_dims(t, -1)) ==> [2, 1] -// -// # 't2' is a tensor of shape [2, 3, 5] -// shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] -// shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] -// shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] -// ``` -// -// This operation requires that: -// -// `-1-input.dims() <= dim <= input.dims()` -// -// This operation is related to `squeeze()`, which removes dimensions of -// size 1. -// -// Arguments: -// -// axis: 0-D (scalar). Specifies the dimension index at which to -// expand the shape of `input`. Must be in the range -// `[-rank(input) - 1, rank(input)]`. -// -// Returns Contains the same data as `input`, but its shape has an additional -// dimension of size 1 added. -func ExpandDims(scope *Scope, input tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExpandDims", - Input: []tf.Input{ - input, axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // MatrixInverseAttr is an optional argument to MatrixInverse. type MatrixInverseAttr func(optionalAttr) @@ -3346,55 +3476,31 @@ func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input t // Produces the average pool of the input tensor for quantized types. // -// Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// ksize: The size of the window for each dimension of the input tensor. -// The length must be 4 to match the number of dimensions of the input. -// strides: The stride of the sliding window for each dimension of the input -// tensor. The length must be 4 to match the number of dimensions of the input. -// padding: The type of padding algorithm to use. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "QuantizedAvgPool", - Input: []tf.Input{ - input, min_input, max_input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Updates the table to associates keys with values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. -// -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// Arguments: +// input: 4-D with shape `[batch, height, width, channels]`. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// ksize: The size of the window for each dimension of the input tensor. +// The length must be 4 to match the number of dimensions of the input. +// strides: The stride of the sliding window for each dimension of the input +// tensor. The length must be 4 to match the number of dimensions of the input. +// padding: The type of padding algorithm to use. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", + Type: "QuantizedAvgPool", Input: []tf.Input{ - table_handle, keys, values, + input, min_input, max_input, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } // FractionalAvgPoolAttr is an optional argument to FractionalAvgPool. @@ -3533,285 +3639,68 @@ func RandomCropSeed2(value int64) RandomCropAttr { // `size` is a 1-D int64 tensor with 2 elements representing the crop height and // width. The values must be non negative. // -// This Op picks a random location in `image` and crops a `height` by `width` -// rectangle from that location. The random location is picked so the cropped -// area will fit inside the original image. -// -// Arguments: -// image: 3-D of shape `[height, width, channels]`. -// size: 1-D of length 2 containing: `crop_height`, `crop_width`.. -// -// Returns 3-D of shape `[crop_height, crop_width, channels].` -func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...RandomCropAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomCrop", - Input: []tf.Input{ - image, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TopKV2Attr is an optional argument to TopKV2. -type TopKV2Attr func(optionalAttr) - -// TopKV2Sorted sets the optional sorted attribute to value. -// -// value: If true the resulting `k` elements will be sorted by the values in -// descending order. -// If not specified, defaults to true -func TopKV2Sorted(value bool) TopKV2Attr { - return func(m optionalAttr) { - m["sorted"] = value - } -} - -// Finds values and indices of the `k` largest elements for the last dimension. -// -// If the input is a vector (rank-1), finds the `k` largest entries in the vector -// and outputs their values and indices as vectors. Thus `values[j]` is the -// `j`-th largest entry in `input`, and its index is `indices[j]`. -// -// For matrices (resp. higher rank input), computes the top `k` entries in each -// row (resp. vector along the last dimension). Thus, -// -// values.shape = indices.shape = input.shape[:-1] + [k] -// -// If two elements are equal, the lower-index element appears first. -// -// Arguments: -// input: 1-D or higher with last dimension at least `k`. -// k: 0-D. Number of top elements to look for along the last dimension (along each -// row for matrices). -// -// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. -func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) (values tf.Output, indices tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TopKV2", - Input: []tf.Input{ - input, k, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Returns x // y element-wise. -// -// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FloorDiv", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a batched diagonal tensor with a given batched diagonal values. -// -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: -// -// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a -// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: -// -// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. -// -// For example: -// -// ``` -// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// and diagonal.shape = (2, 4) -// -// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// which has shape (2, 4, 4) -// ``` -// -// Arguments: -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. -func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiag", - Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. -// -// Returns Computed Precision at `k` as a `bool Tensor`. -func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"k": k} - opspec := tf.OpSpec{ - Type: "InTopK", - Input: []tf.Input{ - predictions, targets, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Given a quantized tensor described by (input, input_min, input_max), outputs a -// -// range that covers the actual values present in that tensor. This op is -// typically used to produce the requested_output_min and requested_output_max for -// Requantize. -// -// Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// -// Returns The computed min output.the computed max output. -func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RequantizationRange", - Input: []tf.Input{ - input, input_min, input_max, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Returns the truth value of (x <= y) element-wise. -// -// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LessEqual", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes softmax activations. -// -// For each batch `i` and class `j` we have -// -// softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j])) +// This Op picks a random location in `image` and crops a `height` by `width` +// rectangle from that location. The random location is picked so the cropped +// area will fit inside the original image. // // Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. +// image: 3-D of shape `[height, width, channels]`. +// size: 1-D of length 2 containing: `crop_height`, `crop_width`.. // -// Returns Same shape as `logits`. -func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { +// Returns 3-D of shape `[crop_height, crop_width, channels].` +func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...RandomCropAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Softmax", + Type: "RandomCrop", Input: []tf.Input{ - logits, + image, size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeBmpAttr is an optional argument to DecodeBmp. -type DecodeBmpAttr func(optionalAttr) +// TopKV2Attr is an optional argument to TopKV2. +type TopKV2Attr func(optionalAttr) -// DecodeBmpChannels sets the optional channels attribute to value. -// If not specified, defaults to 0 -func DecodeBmpChannels(value int64) DecodeBmpAttr { +// TopKV2Sorted sets the optional sorted attribute to value. +// +// value: If true the resulting `k` elements will be sorted by the values in +// descending order. +// If not specified, defaults to true +func TopKV2Sorted(value bool) TopKV2Attr { return func(m optionalAttr) { - m["channels"] = value + m["sorted"] = value } } -// Decode the first frame of a BMP-encoded image to a uint8 tensor. +// Finds values and indices of the `k` largest elements for the last dimension. // -// The attr `channels` indicates the desired number of color channels for the -// decoded image. +// If the input is a vector (rank-1), finds the `k` largest entries in the vector +// and outputs their values and indices as vectors. Thus `values[j]` is the +// `j`-th largest entry in `input`, and its index is `indices[j]`. // -// Accepted values are: +// For matrices (resp. higher rank input), computes the top `k` entries in each +// row (resp. vector along the last dimension). Thus, // -// * 0: Use the number of channels in the BMP-encoded image. -// * 3: output an RGB image. -// * 4: output an RGBA image. +// values.shape = indices.shape = input.shape[:-1] + [k] +// +// If two elements are equal, the lower-index element appears first. // // Arguments: -// contents: 0-D. The BMP-encoded image. +// input: 1-D or higher with last dimension at least `k`. +// k: 0-D. Number of top elements to look for along the last dimension (along each +// row for matrices). // -// Returns 3-D with shape `[height, width, channels]`. RGB order -func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { +// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. +func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) (values tf.Output, indices tf.Output) { if scope.Err() != nil { return } @@ -3820,204 +3709,150 @@ func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (ima a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeBmp", + Type: "TopKV2", Input: []tf.Input{ - contents, + input, k, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Computes softsign gradients for a softsign operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding softsign operation. -// features: The features passed as input to the corresponding softsign operation. +// Returns x // y element-wise. // -// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. -func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SoftsignGrad", + Type: "FloorDiv", Input: []tf.Input{ - gradients, features, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// BatchMatMulAttr is an optional argument to BatchMatMul. -type BatchMatMulAttr func(optionalAttr) - -// BatchMatMulAdjX sets the optional adj_x attribute to value. +// Returns a batched diagonal tensor with a given batched diagonal values. // -// value: If `True`, adjoint the slices of `x`. Defaults to `False`. -// If not specified, defaults to false -func BatchMatMulAdjX(value bool) BatchMatMulAttr { - return func(m optionalAttr) { - m["adj_x"] = value - } -} - -// BatchMatMulAdjY sets the optional adj_y attribute to value. +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: // -// value: If `True`, adjoint the slices of `y`. Defaults to `False`. -// If not specified, defaults to false -func BatchMatMulAdjY(value bool) BatchMatMulAttr { - return func(m optionalAttr) { - m["adj_y"] = value - } -} - -// Multiplies slices of two tensors in batches. +// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a +// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: // -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be adjointed (to adjoint a matrix -// means to transpose and conjugate it) before multiplication by setting -// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. +// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. // -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. +// For example: // -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// ``` +// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] // -// r_o = c_x if adj_x else r_x -// c_o = r_y if adj_y else c_y +// and diagonal.shape = (2, 4) // -// It is computed as: +// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] // -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// which has shape (2, 4, 4) +// ``` // // Arguments: -// x: 2-D or higher with shape `[..., r_x, c_x]`. -// y: 2-D or higher with shape `[..., r_y, c_y]`. +// diagonal: Rank `k`, where `k >= 1`. // -// Returns 3-D or higher with shape `[..., r_o, c_o]` -func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { +// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. +func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "BatchMatMul", + Type: "MatrixDiag", Input: []tf.Input{ - x, y, + diagonal, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Pads a tensor. -// -// This operation pads `input` according to the `paddings` and `constant_values` -// you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is -// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many padding values to add before the contents of `input` in that dimension, -// and `paddings[D, 1]` indicates how many padding values to add after the contents -// of `input` in that dimension. `constant_values` is a scalar tensor of the same -// type as `input` that indicates the value to use for padding `input`. -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: +// Returns the truth value of (x <= y) element-wise. // -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # 'constant_values' is 0 -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` -func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf.Output) (output tf.Output) { +// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "PadV2", + Type: "LessEqual", Input: []tf.Input{ - input, paddings, constant_values, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns which elements of x are NaN. +// Computes softmax activations. // -// @compatibility(numpy) -// Equivalent to np.isnan -// @end_compatibility -func IsNan(scope *Scope, x tf.Output) (y tf.Output) { +// For each batch `i` and class `j` we have +// +// softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j])) +// +// Arguments: +// logits: 2-D with shape `[batch_size, num_classes]`. +// +// Returns Same shape as `logits`. +func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsNan", + Type: "Softmax", Input: []tf.Input{ - x, + logits, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// FractionalAvgPoolGradAttr is an optional argument to FractionalAvgPoolGrad. -type FractionalAvgPoolGradAttr func(optionalAttr) +// DecodeBmpAttr is an optional argument to DecodeBmp. +type DecodeBmpAttr func(optionalAttr) -// FractionalAvgPoolGradOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` -// -// `value 20 5 16 3 7` -// -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [41/3, 26/3] for fractional avg pooling. -// If not specified, defaults to false -func FractionalAvgPoolGradOverlapping(value bool) FractionalAvgPoolGradAttr { +// DecodeBmpChannels sets the optional channels attribute to value. +// If not specified, defaults to 0 +func DecodeBmpChannels(value int64) DecodeBmpAttr { return func(m optionalAttr) { - m["overlapping"] = value + m["channels"] = value } } -// Computes gradient of the FractionalAvgPool function. +// Decode the first frame of a BMP-encoded image to a uint8 tensor. // -// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for -// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of -// out_backprop to those indices that form the same pooling cell. Therefore, we -// just need to know the shape of original input tensor, instead of the whole -// tensor. +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the BMP-encoded image. +// * 3: output an RGB image. +// * 4: output an RGBA image. // // Arguments: -// orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` -// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients -// w.r.t. the output of `fractional_avg_pool`. -// row_pooling_sequence: row pooling sequence, form pooling region with -// col_pooling_sequence. -// col_pooling_sequence: column pooling sequence, form pooling region with -// row_pooling sequence. +// contents: 0-D. The BMP-encoded image. // -// Returns 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. -func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalAvgPoolGradAttr) (output tf.Output) { +// Returns 3-D with shape `[height, width, channels]`. RGB order +func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { if scope.Err() != nil { return } @@ -4026,9 +3861,9 @@ func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_ a(attrs) } opspec := tf.OpSpec{ - Type: "FractionalAvgPoolGrad", + Type: "DecodeBmp", Input: []tf.Input{ - orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, + contents, }, Attrs: attrs, } @@ -4036,76 +3871,88 @@ func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_ return op.Output(0) } -// Computes gradients for the exponential linear (Elu) operation. +// Computes softsign gradients for a softsign operation. // // Arguments: -// gradients: The backpropagated gradients to the corresponding Elu operation. -// outputs: The outputs of the corresponding Elu operation. +// gradients: The backpropagated gradients to the corresponding softsign operation. +// features: The features passed as input to the corresponding softsign operation. // -// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, -// `gradients` otherwise. -func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. +func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "EluGrad", + Type: "SoftsignGrad", Input: []tf.Input{ - gradients, outputs, + gradients, features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. -// -// Note that the hash function may change from time to time. -// This functionality will be deprecated and it's recommended to use -// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. -// -// Arguments: -// -// num_buckets: The number of buckets. +// BatchMatMulAttr is an optional argument to BatchMatMul. +type BatchMatMulAttr func(optionalAttr) + +// BatchMatMulAdjX sets the optional adj_x attribute to value. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { - if scope.Err() != nil { - return +// value: If `True`, adjoint the slices of `x`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjX(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_x"] = value } - attrs := map[string]interface{}{"num_buckets": num_buckets} - opspec := tf.OpSpec{ - Type: "StringToHashBucket", - Input: []tf.Input{ - string_tensor, - }, - Attrs: attrs, +} + +// BatchMatMulAdjY sets the optional adj_y attribute to value. +// +// value: If `True`, adjoint the slices of `y`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjY(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_y"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Creates a dataset that contains `count` elements from the `input_dataset`. +// Multiplies slices of two tensors in batches. // -// Arguments: +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be adjointed (to adjoint a matrix +// means to transpose and conjugate it) before multiplication by setting +// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. // -// count: A scalar representing the number of elements from the `input_dataset` -// that should be taken. A value of `-1` indicates that all of `input_dataset` -// is taken. +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. // +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: // -func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// r_o = c_x if adj_x else r_x +// c_o = r_y if adj_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// +// Arguments: +// x: 2-D or higher with shape `[..., r_x, c_x]`. +// y: 2-D or higher with shape `[..., r_y, c_y]`. +// +// Returns 3-D or higher with shape `[..., r_o, c_o]` +func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TakeDataset", + Type: "BatchMatMul", Input: []tf.Input{ - input_dataset, count, + x, y, }, Attrs: attrs, } @@ -4113,15 +3960,19 @@ func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_ return op.Output(0) } -// Computes rectified linear 6: `min(max(features, 0), 6)`. -func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { +// Returns which elements of x are NaN. +// +// @compatibility(numpy) +// Equivalent to np.isnan +// @end_compatibility +func IsNan(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Relu6", + Type: "IsNan", Input: []tf.Input{ - features, + x, }, } op := scope.AddOperation(opspec) @@ -4418,174 +4269,28 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) - -// MaxPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. -// -// Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Bucketizes 'input' based on 'boundaries'. -// -// For example, if the inputs are -// boundaries = [0, 10, 100] -// input = [[-5, 10000] -// [150, 10] -// [5, 100]] -// -// then the output will be -// output = [[0, 3] -// [3, 2] -// [1, 3]] -// -// Arguments: -// input: Any shape of Tensor contains with int or float type. -// boundaries: A sorted list of floats gives the boundary of the buckets. -// -// Returns Same shape with 'input', each value of input replaced with bucket index. -// -// @compatibility(numpy) -// Equivalent to np.digitize. -// @end_compatibility -func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"boundaries": boundaries} - opspec := tf.OpSpec{ - Type: "Bucketize", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients of the maxpooling function. -// -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// output of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients w.r.t. the input of `max_pool`. -func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "MaxPoolGradWithArgmax", - Input: []tf.Input{ - input, grad, argmax, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. -type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. -// If not specified, defaults to -6 -func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["min"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. -// If not specified, defaults to 6 -func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["max"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxArgs operation. +// Computes gradients of the maxpooling function. // // Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// output of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: -// `gradients * (inputs >= min && inputs <= max)`. -func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { +// Returns Gradients w.r.t. the input of `max_pool`. +func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxArgsGradient", + Type: "MaxPoolGradWithArgmax", Input: []tf.Input{ - gradients, inputs, + input, grad, argmax, }, Attrs: attrs, } @@ -4661,284 +4366,124 @@ func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Computes square root of x element-wise. -// -// I.e., \\(y = \sqrt{x} = x^{1/2}\\). -func Sqrt(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sqrt", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradients of 3-D convolution with respect to the filter. -// -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropFilterV2 -// -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilter", - Input: []tf.Input{ - input, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient for the rsqrt of `x` wrt its input. -// -// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy` -// is the corresponding input gradient. -func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RsqrtGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ReverseSequenceAttr is an optional argument to ReverseSequence. -type ReverseSequenceAttr func(optionalAttr) +// DepthToSpaceAttr is an optional argument to DepthToSpace. +type DepthToSpaceAttr func(optionalAttr) -// ReverseSequenceBatchDim sets the optional batch_dim attribute to value. -// -// value: The dimension along which reversal is performed. -// If not specified, defaults to 0 -func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr { +// DepthToSpaceDataFormat sets the optional data_format attribute to value. +// If not specified, defaults to "NHWC" +func DepthToSpaceDataFormat(value string) DepthToSpaceAttr { return func(m optionalAttr) { - m["batch_dim"] = value + m["data_format"] = value } } -// Reverses variable length slices. +// DepthToSpace for tensors of type T. // -// This op first slices `input` along the dimension `batch_dim`, and for each -// slice `i`, reverses the first `seq_lengths[i]` elements along -// the dimension `seq_dim`. +// Rearranges data from depth into blocks of spatial data. +// This is the reverse transformation of SpaceToDepth. More specifically, +// this op outputs a copy of the input tensor where values from the `depth` +// dimension are moved in spatial blocks to the `height` and `width` dimensions. +// The attr `block_size` indicates the input block size and how the data is moved. // -// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, -// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. +// * Chunks of data of size `block_size * block_size` from depth are rearranged +// into non-overlapping blocks of size `block_size x block_size` +// * The width the output tensor is `input_depth * block_size`, whereas the +// height is `input_height * block_size`. +// * The Y, X coordinates within each block of the output image are determined +// by the high order component of the input channel index. +// * The depth of the input tensor must be divisible by +// `block_size * block_size`. // -// The output slice `i` along dimension `batch_dim` is then given by input -// slice `i`, with the first `seq_lengths[i]` slices along dimension -// `seq_dim` reversed. +// The `data_format` attr specifies the layout of the input and output tensors +// with the following options: +// "NHWC": `[ batch, height, width, channels ]` +// "NCHW": `[ batch, channels, height, width ]` +// "NCHW_VECT_C": +// `qint8 [ batch, channels / 4, height, width, 4 ]` // -// For example: +// It is useful to consider the operation as transforming a 6-D Tensor. +// e.g. for data_format = NHWC, +// Each element in the input tensor can be specified via 6 coordinates, +// ordered by decreasing memory layout significance as: +// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates +// within the input image, bX, bY means coordinates +// within the output block, oC means output channels). +// The output would be the input transposed to the following layout: +// n,iY,bY,iX,bX,oC // -// ``` -// # Given this: -// batch_dim = 0 -// seq_dim = 1 -// input.dims = (4, 8, ...) -// seq_lengths = [7, 2, 3, 5] +// This operation is useful for resizing the activations between convolutions +// (but keeping all data), e.g. instead of pooling. It is also useful for training +// purely convolutional models. // -// # then slices of input are reversed on seq_dim, but only up to seq_lengths: -// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] -// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] -// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] -// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] +// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and +// block_size = 2: // -// # while entries past seq_lens are copied through: -// output[0, 7:, :, ...] = input[0, 7:, :, ...] -// output[1, 2:, :, ...] = input[1, 2:, :, ...] -// output[2, 3:, :, ...] = input[2, 3:, :, ...] -// output[3, 2:, :, ...] = input[3, 2:, :, ...] // ``` -// -// In contrast, if: +// x = [[[[1, 2, 3, 4]]]] // // ``` -// # Given this: -// batch_dim = 2 -// seq_dim = 0 -// input.dims = (8, ?, 4, ...) -// seq_lengths = [7, 2, 3, 5] // -// # then slices of input are reversed on seq_dim, but only up to seq_lengths: -// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] -// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] -// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] -// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] +// This operation will output a tensor of shape `[1, 2, 2, 1]`: // -// # while entries past seq_lens are copied through: -// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] -// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] -// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] -// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] +// ``` +// [[[[1], [2]], +// [[3], [4]]]] // ``` // -// Arguments: -// input: The input to reverse. -// seq_lengths: 1-D with length `input.dims(batch_dim)` and -// `max(seq_lengths) <= input.dims(seq_dim)` -// seq_dim: The dimension which is partially reversed. -// -// Returns The partially reversed input. It has the same shape as `input`. -func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"seq_dim": seq_dim} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ReverseSequence", - Input: []tf.Input{ - input, seq_lengths, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. -type DepthwiseConv2dNativeAttr func(optionalAttr) - -// DepthwiseConv2dNativeDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. +// Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, +// the corresponding output will have 2x2 elements and will have a depth of +// 1 channel (1 = `4 / (block_size * block_size)`). +// The output element shape is `[2, 2, 1]`. // -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing -// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies -// a different filter to each input channel (expanding from 1 channel to -// `channel_multiplier` channels for each), then concatenates the results -// together. Thus, the output has `in_channels * channel_multiplier` channels. +// For an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g. // // ``` -// for k in 0..in_channels-1 -// for q in 0..channel_multiplier-1 -// output[b, i, j, k * channel_multiplier + q] = -// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * -// filter[di, dj, k, q] +// x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] // ``` // -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// This operation, for block size of 2, will return the following tensor of shape +// `[1, 2, 2, 3]` // -// Arguments: +// ``` +// [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] // +// ``` // -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. -// padding: The type of padding algorithm to use. -func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNative", - Input: []tf.Input{ - input, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. -type MaxPoolGradV2Attr func(optionalAttr) - -// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. +// Similarly, for the following input of shape `[1 2 2 4]`, and a block size of 2: // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the maxpooling function. +// ``` +// x = [[[[1, 2, 3, 4], +// [5, 6, 7, 8]], +// [[9, 10, 11, 12], +// [13, 14, 15, 16]]]] +// ``` +// +// the operator will return the following tensor of shape `[1 4 4 1]`: +// +// ``` +// x = [[[ [1], [2], [5], [6]], +// [ [3], [4], [7], [8]], +// [ [9], [10], [13], [14]], +// [ [11], [12], [15], [16]]]] +// +// ``` // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients w.r.t. the output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. // -// Returns Gradients w.r.t. the input to `max_pool`. -func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { +// block_size: The size of the spatial block, same as in Space2Depth. +func DepthToSpace(scope *Scope, input tf.Output, block_size int64, optional ...DepthToSpaceAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"padding": padding} + attrs := map[string]interface{}{"block_size": block_size} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolGradV2", + Type: "DepthToSpace", Input: []tf.Input{ - orig_input, orig_output, grad, ksize, strides, + input, }, Attrs: attrs, } @@ -4946,69 +4491,62 @@ func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, gr return op.Output(0) } -// Restore a reader to a previously saved state. -// -// Not all Readers support being restored, so this can produce an -// Unimplemented error. -// -// Arguments: -// reader_handle: Handle to a Reader. -// state: Result of a ReaderSerializeState of a Reader with type -// matching reader_handle. +// Conv3DBackpropInputV2Attr is an optional argument to Conv3DBackpropInputV2. +type Conv3DBackpropInputV2Attr func(optionalAttr) + +// Conv3DBackpropInputV2DataFormat sets the optional data_format attribute to value. // -// Returns the created operation. -func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderRestoreStateV2", - Input: []tf.Input{ - reader_handle, state, - }, +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { + return func(m optionalAttr) { + m["data_format"] = value } - return scope.AddOperation(opspec) } -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) - -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// Conv3DBackpropInputV2Dilations sets the optional dilations attribute to value. // -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { return func(m optionalAttr) { - m["element_shape"] = value + m["dilations"] = value } } -// Gather specific elements from the TensorArray into output `value`. -// -// All elements selected by `indices` must have the same shape. +// Computes the gradients of 3-D convolution with respect to the input. // // Arguments: -// handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. -// -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { +// input_sizes: An integer vector representing the tensor shape of `input`, +// where `input` is a 5-D +// `[batch, depth, rows, cols, in_channels]` tensor. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", + Type: "Conv3DBackpropInputV2", Input: []tf.Input{ - handle, indices, flow_in, + input_sizes, filter, out_backprop, }, Attrs: attrs, } @@ -5016,155 +4554,159 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process and will never change. However, it is not suitable for cryptography. -// This function may be used when CPU time is scarce and inputs are trusted or -// unimportant. There is a risk of adversaries constructing inputs that all hash -// to the same bucket. To prevent this problem, use a strong hash function with -// `tf.string_to_hash_bucket_strong`. -// -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. +// Computes square root of x element-wise. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { +// I.e., \\(y = \sqrt{x} = x^{1/2}\\). +func Sqrt(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "StringToHashBucketFast", + Type: "Sqrt", Input: []tf.Input{ - input, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the max of x and y (i.e. x > y ? x : y) element-wise. +// Computes the gradients of 3-D convolution with respect to the filter. // -// *NOTE*: `Maximum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Maximum", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs all keys and values in the table. +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropFilterV2 // // Arguments: -// table_handle: Handle to the table. -// -// -// -// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. -func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + attrs := map[string]interface{}{"strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "LookupTableExportV2", + Type: "Conv3DBackpropFilter", Input: []tf.Input{ - table_handle, + input, filter, out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Real-valued fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most dimension of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the -// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, -// followed by the `fft_length / 2` positive-frequency terms. -// -// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [1]. The FFT length. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most -// dimension of `input` is replaced with the `fft_length / 2 + 1` unique -// frequency components of its 1D Fourier transform. + return op.Output(0) +} + +// Computes the gradient for the rsqrt of `x` wrt its input. // -// @compatibility(numpy) -// Equivalent to np.fft.rfft -// @end_compatibility -func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy` +// is the corresponding input gradient. +func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RFFT", + Type: "RsqrtGrad", Input: []tf.Input{ - input, fft_length, + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ComplexAttr is an optional argument to Complex. -type ComplexAttr func(optionalAttr) +// ReverseSequenceAttr is an optional argument to ReverseSequence. +type ReverseSequenceAttr func(optionalAttr) -// ComplexTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_COMPLEX64 -func ComplexTout(value tf.DataType) ComplexAttr { +// ReverseSequenceBatchDim sets the optional batch_dim attribute to value. +// +// value: The dimension along which reversal is performed. +// If not specified, defaults to 0 +func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr { return func(m optionalAttr) { - m["Tout"] = value + m["batch_dim"] = value } } -// Converts two real numbers to a complex number. +// Reverses variable length slices. // -// Given a tensor `real` representing the real part of a complex number, and a -// tensor `imag` representing the imaginary part of a complex number, this -// operation returns complex numbers elementwise of the form \\(a + bj\\), where -// *a* represents the `real` part and *b* represents the `imag` part. +// This op first slices `input` along the dimension `batch_dim`, and for each +// slice `i`, reverses the first `seq_lengths[i]` elements along +// the dimension `seq_dim`. // -// The input tensors `real` and `imag` must have the same shape. +// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, +// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. +// +// The output slice `i` along dimension `batch_dim` is then given by input +// slice `i`, with the first `seq_lengths[i]` slices along dimension +// `seq_dim` reversed. // // For example: // // ``` -// # tensor 'real' is [2.25, 3.25] -// # tensor `imag` is [4.75, 5.75] -// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] +// # Given this: +// batch_dim = 0 +// seq_dim = 1 +// input.dims = (4, 8, ...) +// seq_lengths = [7, 2, 3, 5] +// +// # then slices of input are reversed on seq_dim, but only up to seq_lengths: +// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] +// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] +// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] +// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] +// +// # while entries past seq_lens are copied through: +// output[0, 7:, :, ...] = input[0, 7:, :, ...] +// output[1, 2:, :, ...] = input[1, 2:, :, ...] +// output[2, 3:, :, ...] = input[2, 3:, :, ...] +// output[3, 2:, :, ...] = input[3, 2:, :, ...] // ``` -func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { +// +// In contrast, if: +// +// ``` +// # Given this: +// batch_dim = 2 +// seq_dim = 0 +// input.dims = (8, ?, 4, ...) +// seq_lengths = [7, 2, 3, 5] +// +// # then slices of input are reversed on seq_dim, but only up to seq_lengths: +// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] +// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] +// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] +// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] +// +// # while entries past seq_lens are copied through: +// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] +// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] +// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] +// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] +// ``` +// +// Arguments: +// input: The input to reverse. +// seq_lengths: 1-D with length `input.dims(batch_dim)` and +// `max(seq_lengths) <= input.dims(seq_dim)` +// seq_dim: The dimension which is partially reversed. +// +// Returns The partially reversed input. It has the same shape as `input`. +func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"seq_dim": seq_dim} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Complex", + Type: "ReverseSequence", Input: []tf.Input{ - real, imag, + input, seq_lengths, }, Attrs: attrs, } @@ -5172,42 +4714,76 @@ func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAt return op.Output(0) } -// ImagAttr is an optional argument to Imag. -type ImagAttr func(optionalAttr) +// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. +type DepthwiseConv2dNativeAttr func(optionalAttr) -// ImagTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func ImagTout(value tf.DataType) ImagAttr { +// DepthwiseConv2dNativeDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr { return func(m optionalAttr) { - m["Tout"] = value + m["data_format"] = value } } -// Returns the imaginary part of a complex number. +// DepthwiseConv2dNativeDilations sets the optional dilations attribute to value. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the imaginary part of each element in `input`. All -// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part returned by this operation. +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. // -// For example: +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, channel_multiplier]`, containing +// `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies +// a different filter to each input channel (expanding from 1 channel to +// `channel_multiplier` channels for each), then concatenates the results +// together. Thus, the output has `in_channels * channel_multiplier` channels. // // ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.imag(input) ==> [4.75, 5.75] +// for k in 0..in_channels-1 +// for q in 0..channel_multiplier-1 +// output[b, i, j, k * channel_multiplier + q] = +// sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * +// filter[di, dj, k, q] // ``` -func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { +// +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// +// Arguments: +// +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. +// padding: The type of padding algorithm to use. +func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Imag", + Type: "DepthwiseConv2dNative", Input: []tf.Input{ - input, + input, filter, }, Attrs: attrs, } @@ -5215,89 +4791,119 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output return op.Output(0) } -// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). +// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. +type MaxPoolGradV2Attr func(optionalAttr) + +// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. // -// The Hurwitz zeta function is defined as: +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the maxpooling function. // +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients w.r.t. the output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) -func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { +// Returns Gradients w.r.t. the input to `max_pool`. +func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Zeta", + Type: "MaxPoolGradV2", Input: []tf.Input{ - x, q, + orig_input, orig_output, grad, ksize, strides, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// LRNGradAttr is an optional argument to LRNGrad. -type LRNGradAttr func(optionalAttr) - -// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// Restore a reader to a previously saved state. // -// value: A depth radius. -// If not specified, defaults to 5 -func LRNGradDepthRadius(value int64) LRNGradAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } -} - -// LRNGradBias sets the optional bias attribute to value. +// Not all Readers support being restored, so this can produce an +// Unimplemented error. // -// value: An offset (usually > 0 to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNGradBias(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} - -// LRNGradAlpha sets the optional alpha attribute to value. +// Arguments: +// reader_handle: Handle to a Reader. +// state: Result of a ReaderSerializeState of a Reader with type +// matching reader_handle. // -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNGradAlpha(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["alpha"] = value +// Returns the created operation. +func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderRestoreStateV2", + Input: []tf.Input{ + reader_handle, state, + }, } + return scope.AddOperation(opspec) } -// LRNGradBeta sets the optional beta attribute to value. +// MaxPoolGradAttr is an optional argument to MaxPoolGrad. +type MaxPoolGradAttr func(optionalAttr) + +// MaxPoolGradDataFormat sets the optional data_format attribute to value. // -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNGradBeta(value float32) LRNGradAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradDataFormat(value string) MaxPoolGradAttr { return func(m optionalAttr) { - m["beta"] = value + m["data_format"] = value } } -// Gradients for Local Response Normalization. +// Computes gradients of the maxpooling function. // // Arguments: -// input_grads: 4-D with shape `[batch, height, width, channels]`. -// input_image: 4-D with shape `[batch, height, width, channels]`. -// output_image: 4-D with shape `[batch, height, width, channels]`. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients w.r.t. the output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns The gradients for LRN. -func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { +// Returns Gradients w.r.t. the input to `max_pool`. +func MaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LRNGrad", + Type: "MaxPoolGrad", Input: []tf.Input{ - input_grads, input_image, output_image, + orig_input, orig_output, grad, }, Attrs: attrs, } @@ -5305,33 +4911,66 @@ func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_ return op.Output(0) } -// AnyAttr is an optional argument to Any. -type AnyAttr func(optionalAttr) +// CropAndResizeAttr is an optional argument to CropAndResize. +type CropAndResizeAttr func(optionalAttr) -// AnyKeepDims sets the optional keep_dims attribute to value. +// CropAndResizeMethod sets the optional method attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AnyKeepDims(value bool) AnyAttr { +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeMethod(value string) CropAndResizeAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["method"] = value } } -// Computes the "logical or" of elements across dimensions of a tensor. +// CropAndResizeExtrapolationValue sets the optional extrapolation_value attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: Value used for extrapolation, when applicable. +// If not specified, defaults to 0 +func CropAndResizeExtrapolationValue(value float32) CropAndResizeAttr { + return func(m optionalAttr) { + m["extrapolation_value"] = value + } +} + +// Extracts crops from the input image tensor and bilinearly resizes them (possibly +// +// with aspect ratio change) to a common output size specified by `crop_size`. This +// is more general than the `crop_to_bounding_box` op which extracts a fixed size +// slice from the input image and does not allow resizing or aspect ratio change. +// +// Returns a tensor with `crops` from the input `image` at positions defined at the +// bounding box locations in `boxes`. The cropped boxes are all resized (with +// bilinear interpolation) to a fixed `size = [crop_height, crop_width]`. The +// result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`. The +// resizing is corner aligned. In particular, if `boxes = [[0, 0, 1, 1]]`, the +// method will give identical results to using `tf.image.resize_bilinear()` +// with `align_corners=True`. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1]` in image height coordinates. We do allow `y1` > `y2`, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`. All +// cropped image patches are resized to this size. The aspect ratio of the image +// content is not preserved. Both `crop_height` and `crop_width` need to be +// positive. // -// Returns The reduced tensor. -func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { +// Returns A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Output, crop_size tf.Output, optional ...CropAndResizeAttr) (crops tf.Output) { if scope.Err() != nil { return } @@ -5340,9 +4979,9 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou a(attrs) } opspec := tf.OpSpec{ - Type: "Any", + Type: "CropAndResize", Input: []tf.Input{ - input, axis, + image, boxes, box_ind, crop_size, }, Attrs: attrs, } @@ -5350,280 +4989,225 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou return op.Output(0) } -// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. -type ResourceApplyFtrlAttr func(optionalAttr) - -// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// Fills empty rows in the input 2-D `SparseTensor` with a default value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the Ftrl-proximal scheme. +// The input `SparseTensor` is represented via the tuple of inputs +// (`indices`, `values`, `dense_shape`). The output `SparseTensor` has the +// same `dense_shape` but with indices `output_indices` and values +// `output_values`. // -// accum_new = accum + grad * grad -// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// This op inserts a single entry for every row that doesn't have any values. +// The index is created as `[row, 0, ..., 0]` and the inserted value +// is `default_value`. +// +// For example, suppose `sp_input` has shape `[5, 6]` and non-empty values: +// +// [0, 1]: a +// [0, 3]: b +// [2, 0]: c +// [3, 1]: d +// +// Rows 1 and 4 are empty, so the output will be of shape `[5, 6]` with values: +// +// [0, 1]: a +// [0, 3]: b +// [1, 0]: default_value +// [2, 0]: c +// [3, 1]: d +// [4, 0]: default_value +// +// The output `SparseTensor` will be in row-major order and will have the +// same shape as the input. +// +// This op also returns an indicator vector shaped `[dense_shape[0]]` such that +// +// empty_row_indicator[i] = True iff row i was an empty row. +// +// And a reverse index map vector shaped `[indices.shape[0]]` that is used during +// backpropagation, +// +// reverse_index_map[j] = out_j s.t. indices[j, :] == output_indices[out_j, :] // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 regulariation. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// indices: 2-D. the indices of the sparse tensor. +// values: 1-D. the values of the sparse tensor. +// dense_shape: 1-D. the shape of the sparse tensor. +// default_value: 0-D. default value to insert into location `[row, 0, ..., 0]` +// for rows missing from the input sparse tensor. +// output indices: 2-D. the indices of the filled sparse tensor. // -// Returns the created operation. -func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { +// Returns 1-D. the values of the filled sparse tensor.1-D. whether the dense row was missing in the +// input sparse tensor.1-D. a map from the input indices to the output indices. +func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output, default_value tf.Output) (output_indices tf.Output, output_values tf.Output, empty_row_indicator tf.Output, reverse_index_map tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyFtrl", + Type: "SparseFillEmptyRows", Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, lr_power, + indices, values, dense_shape, default_value, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// RandomUniformAttr is an optional argument to RandomUniform. -type RandomUniformAttr func(optionalAttr) - -// RandomUniformSeed sets the optional seed attribute to value. +// Reverses specific dimensions of a tensor. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformSeed(value int64) RandomUniformAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomUniformSeed2 sets the optional seed2 attribute to value. +// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions +// of `tensor`, this operation reverses each dimension i of `tensor` where +// `dims[i]` is `True`. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformSeed2(value int64) RandomUniformAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a uniform distribution. +// `tensor` can have up to 8 dimensions. The number of dimensions +// of `tensor` must equal the number of elements in `dims`. In other words: +// +// `rank(tensor) = size(dims)` +// +// For example: +// +// ``` +// # tensor 't' is [[[[ 0, 1, 2, 3], +// # [ 4, 5, 6, 7], +// # [ 8, 9, 10, 11]], +// # [[12, 13, 14, 15], +// # [16, 17, 18, 19], +// # [20, 21, 22, 23]]]] +// # tensor 't' shape is [1, 2, 3, 4] +// +// # 'dims' is [False, False, False, True] +// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], +// [ 7, 6, 5, 4], +// [ 11, 10, 9, 8]], +// [[15, 14, 13, 12], +// [19, 18, 17, 16], +// [23, 22, 21, 20]]]] +// +// # 'dims' is [False, True, False, False] +// reverse(t, dims) ==> [[[[12, 13, 14, 15], +// [16, 17, 18, 19], +// [20, 21, 22, 23] +// [[ 0, 1, 2, 3], +// [ 4, 5, 6, 7], +// [ 8, 9, 10, 11]]]] // -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// # 'dims' is [False, False, True, False] +// reverse(t, dims) ==> [[[[8, 9, 10, 11], +// [4, 5, 6, 7], +// [0, 1, 2, 3]] +// [[20, 21, 22, 23], +// [16, 17, 18, 19], +// [12, 13, 14, 15]]]] +// ``` // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// tensor: Up to 8-D. +// dims: 1-D. The dimensions to reverse. // -// Returns A tensor of the specified shape filled with uniform random values. -func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { +// Returns The same shape as `tensor`. +func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomUniform", + Type: "Reverse", Input: []tf.Input{ - shape, + tensor, dims, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AssertAttr is an optional argument to Assert. -type AssertAttr func(optionalAttr) - -// AssertSummarize sets the optional summarize attribute to value. +// Computes log softmax activations. // -// value: Print this many entries of each tensor. -// If not specified, defaults to 3 -func AssertSummarize(value int64) AssertAttr { - return func(m optionalAttr) { - m["summarize"] = value - } -} - -// Asserts that the given condition is true. +// For each batch `i` and class `j` we have // -// If `condition` evaluates to false, print the list of tensors in `data`. -// `summarize` determines how many entries of the tensors to print. +// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) // // Arguments: -// condition: The condition to evaluate. -// data: The tensors to print out when condition is false. -// -// Returns the created operation. -func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Assert", - Input: []tf.Input{ - condition, tf.OutputList(data), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). -// -// For each entry in `x`, calculates the number of `1` (on) bits in the binary -// representation of that entry. +// logits: 2-D with shape `[batch_size, num_classes]`. // -// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into -// `int32` or `int64` and perform the bitcount on the result, than to feed in -// 8- or 16-bit inputs and then aggregate the resulting counts. -func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) { +// Returns Same shape as `logits`. +func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "PopulationCount", + Type: "LogSoftmax", Input: []tf.Input{ - x, + logits, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Split a `SparseTensor` into `num_split` tensors along one dimension. +// Computes the inverse permutation of a tensor. // -// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices -// `[0 : shape[split_dim] % num_split]` gets one extra dimension. -// For example, if `split_dim = 1` and `num_split = 2` and the input is +// This operation computes the inverse of an index permutation. It takes a 1-D +// integer tensor `x`, which represents the indices of a zero-based array, and +// swaps each value with its index position. In other words, for an output tensor +// `y` and an input tensor `x`, this operation computes the following: // -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] +// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` // -// Graphically the output tensors are: +// The values must include 0. There can be no duplicate values or negative values. // -// output_tensor[0] = shape = [2, 4] -// [ a ] -// [b c ] +// For example: // -// output_tensor[1] = shape = [2, 3] -// [ d e ] -// [ ] +// ``` +// # tensor `x` is [3, 4, 0, 2, 1] +// invert_permutation(x) ==> [2, 4, 3, 0, 1] +// ``` // // Arguments: -// split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(shape))`. -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. -// num_split: The number of ways to split. -// -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_split": num_split} - opspec := tf.OpSpec{ - Type: "SparseSplit", - Input: []tf.Input{ - split_dim, indices, values, shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - return output_indices, output_values, output_shape -} - -// Returns the truth value of (x < y) element-wise. +// x: 1-D. // -// *NOTE*: `Less` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns 1-D. +func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Less", + Type: "InvertPermutation", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedReluXAttr is an optional argument to QuantizedReluX. -type QuantizedReluXAttr func(optionalAttr) +// BiasAddGradAttr is an optional argument to BiasAddGrad. +type BiasAddGradAttr func(optionalAttr) -// QuantizedReluXOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedReluXOutType(value tf.DataType) QuantizedReluXAttr { +// BiasAddGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. +// If not specified, defaults to "NHWC" +func BiasAddGradDataFormat(value string) BiasAddGradAttr { return func(m optionalAttr) { - m["out_type"] = value + m["data_format"] = value } } -// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` -// -// Arguments: +// The backward operation for "BiasAdd" on the "bias" tensor. // +// It accumulates all the values from out_backprop into the feature dimension. +// For NHWC data format, the feature dimension is the last. For NCHW data format, +// the feature dimension is the third-to-last. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// Arguments: +// out_backprop: Any number of dimensions. // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluXAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// Returns 1-D with size the feature dimension of `out_backprop`. +func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -5632,202 +5216,166 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedReluX", + Type: "BiasAddGrad", Input: []tf.Input{ - features, max_value, min_features, max_features, + out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// SummaryWriterAttr is an optional argument to SummaryWriter. -type SummaryWriterAttr func(optionalAttr) +// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2. +type FusedBatchNormV2Attr func(optionalAttr) -// SummaryWriterSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func SummaryWriterSharedName(value string) SummaryWriterAttr { +// FusedBatchNormV2Epsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormV2Epsilon(value float32) FusedBatchNormV2Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["epsilon"] = value } } -// SummaryWriterContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func SummaryWriterContainer(value string) SummaryWriterAttr { +// FusedBatchNormV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormV2DataFormat(value string) FusedBatchNormV2Attr { return func(m optionalAttr) { - m["container"] = value + m["data_format"] = value } } -// Returns a handle to be used to access a summary writer. -// -// The summary writer is an in-graph resource which can be used by ops to write -// summaries to event files. +// FusedBatchNormV2IsTraining sets the optional is_training attribute to value. // -// Returns the summary writer resource. Scalar handle. -func SummaryWriter(scope *Scope, optional ...SummaryWriterAttr) (writer tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SummaryWriter", - - Attrs: attrs, +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormV2IsTraining(value bool) FusedBatchNormV2Attr { + return func(m optionalAttr) { + m["is_training"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes gradients for SparseSegmentMean. +// Batch normalization. // -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. // -// Arguments: -// grad: gradient propagated to the SparseSegmentMean op. -// indices: indices passed to the corresponding SparseSegmentMean op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. -func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. +func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormV2Attr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSegmentMeanGrad", + Type: "FusedBatchNormV2", Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, + x, scale, offset, mean, variance, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Applies softmax to a batched N-D `SparseTensor`. -// -// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` -// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. -// -// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost -// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly -// zero elements do not participate*. Specifically, the algorithm is equivalent -// to the following: +// Returns the rank of a tensor. // -// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix -// with shape `[B, C]`, along the size-C dimension; -// (2) Masks out the original implicitly-zero locations; -// (3) Renormalizes the remaining elements. +// This operation returns an integer representing the rank of `input`. // -// Hence, the `SparseTensor` result has exactly the same non-zero indices and -// shape. +// For example: // -// Arguments: -// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a -// SparseTensor, in canonical ordering. -// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// # shape of tensor 't' is [2, 2, 3] +// rank(t) ==> 3 +// ``` // -// Returns 1-D. The `NNZ` values for the result `SparseTensor`. -func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { +// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank +// of a tensor is the number of indices required to uniquely select each element +// of the tensor. Rank is also known as "order", "degree", or "ndims." +func Rank(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSoftmax", + Type: "Rank", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// RandomPoissonAttr is an optional argument to RandomPoisson. -type RandomPoissonAttr func(optionalAttr) - -// RandomPoissonSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomPoissonSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed2(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Use RandomPoissonV2 instead. +// Transforms a Tensor into a serialized TensorProto proto. // -// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 -func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { +// Arguments: +// tensor: A Tensor of type `T`. +// +// Returns A serialized TensorProto proto of the input tensor. +func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomPoisson", + Type: "SerializeTensor", Input: []tf.Input{ - shape, rate, + tensor, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. -type ResourceSparseApplyFtrlV2Attr func(optionalAttr) +// MatrixSolveAttr is an optional argument to MatrixSolve. +type MatrixSolveAttr func(optionalAttr) -// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// MatrixSolveAdjoint sets the optional adjoint attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. // If not specified, defaults to false -func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { +func MatrixSolveAdjoint(value bool) MatrixSolveAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["adjoint"] = value } } -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// Solves systems of linear equations. // -// That is for rows we have grad for, we update var, accum and linear as follows: -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +// If `adjoint` is `True` then each output matrix satisfies +// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. -// -// lr_power: Scaling factor. Must be a scalar. +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. // -// Returns the created operation. -func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { +// Returns Shape is `[..., M, K]`. +func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -5836,165 +5384,236 @@ func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, li a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrlV2", + Type: "MatrixSolve", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, + matrix, rhs, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Associates the given iterator with the given statistics aggregator. -// -// Returns the created operation. -func IteratorSetStatsAggregator(scope *Scope, iterator_handle tf.Output, stats_aggregator_handle tf.Output) (o *tf.Operation) { +// Computes acos of x element-wise. +func Acos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IteratorSetStatsAggregator", + Type: "Acos", Input: []tf.Input{ - iterator_handle, stats_aggregator_handle, + x, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns element-wise smallest integer in not less than x. -func Ceil(scope *Scope, x tf.Output) (y tf.Output) { +// Real-valued fast Fourier transform. +// +// Computes the 1-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most dimension of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the +// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, +// followed by the `fft_length / 2` positive-frequency terms. +// +// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [1]. The FFT length. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most +// dimension of `input` is replaced with the `fft_length / 2 + 1` unique +// frequency components of its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfft +// @end_compatibility +func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Ceil", + Type: "RFFT", Input: []tf.Input{ - x, + input, fft_length, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the number of elements in the given table. +// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. +type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) + +// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of depthwise convolution with respect to the filter. // // Arguments: -// table_handle: Handle to the table. +// input: 4-D with shape based on `data_format`. For example, if +// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, +// in_width, in_channels]` tensor. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. // -// Returns Scalar that contains number of elements in the table. -func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LookupTableSizeV2", + Type: "DepthwiseConv2dNativeBackpropFilter", Input: []tf.Input{ - table_handle, + input, filter_sizes, out_backprop, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. -type ResizeBilinearGradAttr func(optionalAttr) +// LRNGradAttr is an optional argument to LRNGrad. +type LRNGradAttr func(optionalAttr) -// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. +// LRNGradDepthRadius sets the optional depth_radius attribute to value. // -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { +// value: A depth radius. +// If not specified, defaults to 5 +func LRNGradDepthRadius(value int64) LRNGradAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["depth_radius"] = value } } -// Computes the gradient of bilinear interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. +// LRNGradBias sets the optional bias attribute to value. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinearGrad", - Input: []tf.Input{ - grads, original_image, - }, - Attrs: attrs, +// value: An offset (usually > 0 to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNGradBias(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["bias"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// LRNGradAlpha sets the optional alpha attribute to value. // -// N is the size of the segment being reduced. +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNGradAlpha(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNGradBeta sets the optional beta attribute to value. // -// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of -// segments. +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNGradBeta(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Gradients for Local Response Normalization. // // Arguments: +// input_grads: 4-D with shape `[batch, height, width, channels]`. +// input_image: 4-D with shape `[batch, height, width, channels]`. +// output_image: 4-D with shape `[batch, height, width, channels]`. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns The gradients for LRN. +func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtN", + Type: "LRNGrad", Input: []tf.Input{ - data, indices, segment_ids, + input_grads, input_image, output_image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) +// AnyAttr is an optional argument to Any. +type AnyAttr func(optionalAttr) -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// AnyKeepDims sets the optional keep_dims attribute to value. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AnyKeepDims(value bool) AnyAttr { return func(m optionalAttr) { - m["dtype"] = value + m["keep_dims"] = value } } -// Outputs deterministic pseudorandom values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. +// Computes the "logical or" of elements across dimensions of a tensor. // -// The outputs are a deterministic function of `shape` and `seed`. +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { +// Returns The reduced tensor. +func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -6003,9 +5622,9 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", + Type: "Any", Input: []tf.Input{ - shape, seed, + input, axis, }, Attrs: attrs, } @@ -6013,137 +5632,136 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt return op.Output(0) } -// RestoreSliceAttr is an optional argument to RestoreSlice. -type RestoreSliceAttr func(optionalAttr) +// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. +type ResourceApplyFtrlAttr func(optionalAttr) -// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. +// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. // -// value: Index of file to open first if multiple files match -// `file_pattern`. See the documentation for `Restore`. -// If not specified, defaults to -1 -func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { return func(m optionalAttr) { - m["preferred_shard"] = value + m["use_locking"] = value } } -// Restores a tensor from checkpoint files. -// -// This is like `Restore` except that restored tensor can be listed as filling -// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the -// larger tensor and the slice that the restored tensor covers. +// Update '*var' according to the Ftrl-proximal scheme. // -// The `shape_and_slice` input has the same format as the -// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// accum_new = accum + grad * grad +// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// shape_and_slice: Scalar. The shapes and slice specifications to use when -// restoring a tensors. -// dt: The type of the tensor to be restored. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 regulariation. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. // -// Returns The restored tensor. -func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { +// Returns the created operation. +func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dt": dt} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RestoreSlice", + Type: "ResourceApplyFtrl", Input: []tf.Input{ - file_pattern, tensor_name, shape_and_slice, + var_, accum, linear, grad, lr, l1, l2, lr_power, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// UniqueWithCountsAttr is an optional argument to UniqueWithCounts. -type UniqueWithCountsAttr func(optionalAttr) +// RandomUniformAttr is an optional argument to RandomUniform. +type RandomUniformAttr func(optionalAttr) -// UniqueWithCountsOutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func UniqueWithCountsOutIdx(value tf.DataType) UniqueWithCountsAttr { +// RandomUniformSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformSeed(value int64) RandomUniformAttr { return func(m optionalAttr) { - m["out_idx"] = value + m["seed"] = value } } -// Finds unique elements in a 1-D tensor. -// -// This operation returns a tensor `y` containing all of the unique elements of `x` -// sorted in the same order that they occur in `x`. This operation also returns a -// tensor `idx` the same size as `x` that contains the index of each value of `x` -// in the unique output `y`. Finally, it returns a third tensor `count` that -// contains the count of each element of `y` in `x`. In other words: -// -// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` +// RandomUniformSeed2 sets the optional seed2 attribute to value. // -// For example: +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformSeed2(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a uniform distribution. // -// ``` -// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] -// y, idx, count = unique_with_counts(x) -// y ==> [1, 2, 4, 7, 8] -// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] -// count ==> [2, 1, 3, 1, 2] -// ``` +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. // // Arguments: -// x: 1-D. +// shape: The shape of the output tensor. +// dtype: The type of the output. // -// Returns 1-D.1-D.1-D. -func UniqueWithCounts(scope *Scope, x tf.Output, optional ...UniqueWithCountsAttr) (y tf.Output, idx tf.Output, count tf.Output) { +// Returns A tensor of the specified shape filled with uniform random values. +func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "UniqueWithCounts", + Type: "RandomUniform", Input: []tf.Input{ - x, + shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. -type StatelessRandomNormalAttr func(optionalAttr) +// AssertAttr is an optional argument to Assert. +type AssertAttr func(optionalAttr) -// StatelessRandomNormalDtype sets the optional dtype attribute to value. +// AssertSummarize sets the optional summarize attribute to value. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { +// value: Print this many entries of each tensor. +// If not specified, defaults to 3 +func AssertSummarize(value int64) AssertAttr { return func(m optionalAttr) { - m["dtype"] = value + m["summarize"] = value } } -// Outputs deterministic pseudorandom values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. +// Asserts that the given condition is true. // -// The outputs are a deterministic function of `shape` and `seed`. +// If `condition` evaluates to false, print the list of tensors in `data`. +// `summarize` determines how many entries of the tensors to print. // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// condition: The condition to evaluate. +// data: The tensors to print out when condition is false. // -// Returns Random values with specified shape. -func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { +// Returns the created operation. +func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -6152,202 +5770,184 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomNormal", + Type: "Assert", Input: []tf.Input{ - shape, seed, + condition, tf.OutputList(data), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Reshapes a quantized tensor as per the Reshape op. -// -// ``` -// -// Arguments: +// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). // -// shape: Defines the shape of the output tensor. -// input_min: The minimum value of the input. -// input_max: The maximum value of the input. +// For each entry in `x`, calculates the number of `1` (on) bits in the binary +// representation of that entry. // -// Returns This value is copied from input_min.This value is copied from input_max. -func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into +// `int32` or `int64` and perform the bitcount on the result, than to feed in +// 8- or 16-bit inputs and then aggregate the resulting counts. +func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QuantizedReshape", + Type: "PopulationCount", Input: []tf.Input{ - tensor, shape, input_min, input_max, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// GatherAttr is an optional argument to Gather. -type GatherAttr func(optionalAttr) - -// GatherValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func GatherValidateIndices(value bool) GatherAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } + return op.Output(0) } -// Gather slices from `params` according to `indices`. +// Split a `SparseTensor` into `num_split` tensors along one dimension. // -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: +// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices +// `[0 : shape[split_dim] % num_split]` gets one extra dimension. +// For example, if `split_dim = 1` and `num_split = 2` and the input is // -// ```python -// # Scalar indices -// output[:, ..., :] = params[indices, :, ... :] +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] // -// # Vector indices -// output[i, :, ..., :] = params[indices[i], :, ... :] +// Graphically the output tensors are: // -// # Higher rank indices -// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] -// ``` +// output_tensor[0] = shape = [2, 4] +// [ a ] +// [b c ] // -// If `indices` is a permutation and `len(indices) == params.shape[0]` then -// this operation will permute `params` accordingly. +// output_tensor[1] = shape = [2, 3] +// [ d e ] +// [ ] // -// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in -// `indices` are always validated to be within range. If assigned to GPU, -// out-of-bound indices result in safe but unspecified behavior, which may include -// raising an error. +// Arguments: +// split_dim: 0-D. The dimension along which to split. Must be in the range +// `[0, rank(shape))`. +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// num_split: The number of ways to split. // -//
-// -//
-func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...GatherAttr) (output tf.Output) { +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"num_split": num_split} opspec := tf.OpSpec{ - Type: "Gather", + Type: "SparseSplit", Input: []tf.Input{ - params, indices, + split_dim, indices, values, shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the truth value of (x != y) element-wise. -// -// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "NotEqual", - Input: []tf.Input{ - x, y, - }, + var idx int + var err error + if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { + scope.UpdateErr("SparseSplit", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + return output_indices, output_values, output_shape } -// Inverse 3D real-valued fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 3 dimensions of `input`. -// -// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 3 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along each axis `IRFFT3D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 3D real Fourier transform. +// RandomPoissonAttr is an optional argument to RandomPoisson. +type RandomPoissonAttr func(optionalAttr) + +// RandomPoissonSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed(value int64) RandomPoissonAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomPoissonSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed2(value int64) RandomPoissonAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Use RandomPoissonV2 instead. // -// @compatibility(numpy) -// Equivalent to np.irfftn with 3 dimensions. -// @end_compatibility -func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 +func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IRFFT3D", + Type: "RandomPoisson", Input: []tf.Input{ - input, fft_length, + shape, rate, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StringSplitAttr is an optional argument to StringSplit. -type StringSplitAttr func(optionalAttr) +// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. +type ResourceSparseApplyFtrlV2Attr func(optionalAttr) -// StringSplitSkipEmpty sets the optional skip_empty attribute to value. +// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. // -// value: A `bool`. If `True`, skip the empty strings from the result. -// If not specified, defaults to true -func StringSplitSkipEmpty(value bool) StringSplitAttr { +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { return func(m optionalAttr) { - m["skip_empty"] = value + m["use_locking"] = value } } -// Split elements of `input` based on `delimiter` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). Split each -// element of `input` based on `delimiter` and return a `SparseTensor` -// containing the splitted tokens. Empty tokens are ignored. -// -// `delimiter` can be empty, or a string of split characters. If `delimiter` is an -// empty string, each element of `input` is split into individual single-byte -// character strings, including splitting of UTF-8 multibyte sequences. Otherwise -// every character of `delimiter` is a potential split point. -// -// For example: -// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output -// will be +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. // -// indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// shape = [2, 3] -// values = ['hello', 'world', 'a', 'b', 'c'] +// That is for rows we have grad for, we update var, accum and linear as follows: +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// input: 1-D. Strings to split. -// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 shrinkage regulariation. Must be a scalar. // -// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse -// tensor, where the first value is N and the second value is the maximum number -// of tokens in a single input entry. -func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -6356,98 +5956,63 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional .. a(attrs) } opspec := tf.OpSpec{ - Type: "StringSplit", + Type: "ResourceSparseApplyFtrlV2", Input: []tf.Input{ - input, delimiter, + var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// WriteAudioSummaryAttr is an optional argument to WriteAudioSummary. -type WriteAudioSummaryAttr func(optionalAttr) - -// WriteAudioSummaryMaxOutputs sets the optional max_outputs attribute to value. -// -// value: Max number of batch elements to generate audio for. -// If not specified, defaults to 3 -// -// REQUIRES: value >= 1 -func WriteAudioSummaryMaxOutputs(value int64) WriteAudioSummaryAttr { - return func(m optionalAttr) { - m["max_outputs"] = value - } + return scope.AddOperation(opspec) } -// Writes a `Summary` protocol buffer with audio. -// -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: -// -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. -// -// Arguments: -// writer: A handle to a summary writer. -// step: The step to write the summary for. -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. +// Associates the given iterator with the given statistics aggregator. // // Returns the created operation. -func WriteAudioSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...WriteAudioSummaryAttr) (o *tf.Operation) { +func IteratorSetStatsAggregator(scope *Scope, iterator_handle tf.Output, stats_aggregator_handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "WriteAudioSummary", + Type: "IteratorSetStatsAggregator", Input: []tf.Input{ - writer, step, tag, tensor, sample_rate, + iterator_handle, stats_aggregator_handle, }, - Attrs: attrs, } return scope.AddOperation(opspec) } -// ProdAttr is an optional argument to Prod. -type ProdAttr func(optionalAttr) +// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. +type DataFormatVecPermuteAttr func(optionalAttr) -// ProdKeepDims sets the optional keep_dims attribute to value. +// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func ProdKeepDims(value bool) ProdAttr { +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["src_format"] = value } } -// Computes the product of elements across dimensions of a tensor. +// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { + return func(m optionalAttr) { + m["dst_format"] = value + } +} + +// Returns the permuted vector/tensor in the destination data format given the +// +// one in the source data format. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. // -// Returns The reduced tensor. -func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) { +// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. +func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { if scope.Err() != nil { return } @@ -6456,9 +6021,9 @@ func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) ( a(attrs) } opspec := tf.OpSpec{ - Type: "Prod", + Type: "DataFormatVecPermute", Input: []tf.Input{ - input, axis, + x, }, Attrs: attrs, } @@ -6466,33 +6031,58 @@ func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) ( return op.Output(0) } -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) +// Computes tan of x element-wise. +func Tan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) + +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. // -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["use_locking"] = value } } -// Resize `images` to `size` using bilinear interpolation. +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. // -// Input images can be of different types but output images are always float. +// That is for rows we have grad for, we update var, accum and linear as follows: +// accum_new = accum + grad * grad +// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { +// Returns the created operation. +func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -6501,199 +6091,184 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ... a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBilinear", + Type: "ResourceSparseApplyFtrl", Input: []tf.Input{ - images, size, + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Returns which elements of x are Inf. +// +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsInf", + Input: []tf.Input{ + x, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softsign: `features / (abs(features) + 1)`. -func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// +// N is the size of the segment being reduced. +// +// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +// segments. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Softsign", + Type: "SparseSegmentSqrtN", Input: []tf.Input{ - features, + data, indices, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. -type GenerateVocabRemappingAttr func(optionalAttr) +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) -// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. -// -// value: Number of entries in the old vocab file to consider. If -1, -// use the entire old vocabulary. -// If not specified, defaults to -1 +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. // -// REQUIRES: value >= -1 -func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { return func(m optionalAttr) { - m["old_vocab_size"] = value + m["dtype"] = value } } -// Given a path to new and old vocabulary files, returns a remapping Tensor of -// -// length `num_new_vocab`, where `remapping[i]` contains the row number in the old -// vocabulary that corresponds to row `i` in the new vocabulary (starting at line -// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` -// in the new vocabulary is not in the old vocabulary. The old vocabulary is -// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the -// default value of -1. -// -// `num_vocab_offset` enables -// use in the partitioned variable case, and should generally be set through -// examining partitioning info. The format of the files should be a text file, -// with each line containing a single entity within the vocabulary. -// -// For example, with `new_vocab_file` a text file containing each of the following -// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], -// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be -// `[0, -1, 2]`. +// Outputs deterministic pseudorandom values from a truncated normal distribution. // -// The op also returns a count of how many entries in the new vocabulary -// were present in the old vocabulary, which is used to calculate the number of -// values to initialize in a weight matrix remapping +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. // -// This functionality can be used to remap both row vocabularies (typically, -// features) and column vocabularies (typically, classes) from TensorFlow -// checkpoints. Note that the partitioning logic relies on contiguous vocabularies -// corresponding to div-partitioned variables. Moreover, the underlying remapping -// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should -// use the corresponding index_table_from_file() as the FeatureColumn framework -// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// new_vocab_file: Path to the new vocab file. -// old_vocab_file: Path to the old vocab file. -// new_vocab_offset: How many entries into the new vocab file to start reading. -// num_new_vocab: Number of entries in the new vocab file to remap. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns A Tensor of length num_new_vocab where the element at index i -// is equal to the old ID that maps to the new ID i. This element is -1 for any -// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. -func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { +// Returns Random values with specified shape. +func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "GenerateVocabRemapping", + Type: "StatelessTruncatedNormal", Input: []tf.Input{ - new_vocab_file, old_vocab_file, + shape, seed, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Assigns sparse updates to the variable referenced by `resource`. -// -// This operation computes +// RestoreSliceAttr is an optional argument to RestoreSlice. +type RestoreSliceAttr func(optionalAttr) + +// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. // -// # Scalar indices -// ref[indices, ...] = updates[...] +// value: Index of file to open first if multiple files match +// `file_pattern`. See the documentation for `Restore`. +// If not specified, defaults to -1 +func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { + return func(m optionalAttr) { + m["preferred_shard"] = value + } +} + +// Restores a tensor from checkpoint files. // -// # Vector indices (for each i) -// ref[indices[i], ...] = updates[i, ...] +// This is like `Restore` except that restored tensor can be listed as filling +// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +// larger tensor and the slice that the restored tensor covers. // -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] +// The `shape_and_slice` input has the same format as the +// elements of the `shapes_and_slices` input of the `SaveSlices` op. // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// shape_and_slice: Scalar. The shapes and slice specifications to use when +// restoring a tensors. +// dt: The type of the tensor to be restored. // -// Returns the created operation. -func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// Returns The restored tensor. +func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceScatterUpdate", + Type: "RestoreSlice", Input: []tf.Input{ - resource, indices, updates, + file_pattern, tensor_name, shape_and_slice, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// CumsumAttr is an optional argument to Cumsum. -type CumsumAttr func(optionalAttr) - -// CumsumExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumsum. -// If not specified, defaults to false -func CumsumExclusive(value bool) CumsumAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} +// ImagAttr is an optional argument to Imag. +type ImagAttr func(optionalAttr) -// CumsumReverse sets the optional reverse attribute to value. -// -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumsumReverse(value bool) CumsumAttr { +// ImagTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func ImagTout(value tf.DataType) ImagAttr { return func(m optionalAttr) { - m["reverse"] = value + m["Tout"] = value } } -// Compute the cumulative sum of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumsum, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is -// performed instead: +// Returns the imaginary part of a complex number. // -// ```python -// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] -// ``` +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the imaginary part of each element in `input`. All +// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part returned by this operation. // -// By setting the `reverse` kwarg to `True`, the cumsum is performed in the -// opposite direction: +// For example: // -// ```python -// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] // ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.imag(input) ==> [4.75, 5.75] // ``` -// -// Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { +func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -6702,9 +6277,9 @@ func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) ( a(attrs) } opspec := tf.OpSpec{ - Type: "Cumsum", + Type: "Imag", Input: []tf.Input{ - x, axis, + input, }, Attrs: attrs, } @@ -6712,26 +6287,34 @@ func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) ( return op.Output(0) } -// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. -type QuantizedRelu6Attr func(optionalAttr) +// ComplexAttr is an optional argument to Complex. +type ComplexAttr func(optionalAttr) -// QuantizedRelu6OutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { +// ComplexTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_COMPLEX64 +func ComplexTout(value tf.DataType) ComplexAttr { return func(m optionalAttr) { - m["out_type"] = value + m["Tout"] = value } } -// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` +// Converts two real numbers to a complex number. // -// Arguments: +// Given a tensor `real` representing the real part of a complex number, and a +// tensor `imag` representing the imaginary part of a complex number, this +// operation returns complex numbers elementwise of the form \\(a + bj\\), where +// *a* represents the `real` part and *b* represents the `imag` part. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. +// The input tensors `real` and `imag` must have the same shape. // -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// For example: +// +// ``` +// # tensor 'real' is [2.25, 3.25] +// # tensor `imag` is [4.75, 5.75] +// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] +// ``` +func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { if scope.Err() != nil { return } @@ -6740,274 +6323,299 @@ func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, ma a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedRelu6", + Type: "Complex", Input: []tf.Input{ - features, min_features, max_features, + real, imag, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. -type FixedLengthRecordReaderV2Attr func(optionalAttr) +// UniqueWithCountsAttr is an optional argument to UniqueWithCounts. +type UniqueWithCountsAttr func(optionalAttr) -// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. -// -// value: Number of bytes in the header, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { +// UniqueWithCountsOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func UniqueWithCountsOutIdx(value tf.DataType) UniqueWithCountsAttr { return func(m optionalAttr) { - m["header_bytes"] = value + m["out_idx"] = value } } -// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. +// Finds unique elements in a 1-D tensor. // -// value: Number of bytes in the footer, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["footer_bytes"] = value - } -} - -// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. +// This operation returns a tensor `y` containing all of the unique elements of `x` +// sorted in the same order that they occur in `x`. This operation also returns a +// tensor `idx` the same size as `x` that contains the index of each value of `x` +// in the unique output `y`. Finally, it returns a third tensor `count` that +// contains the count of each element of `y` in `x`. In other words: // -// value: Number of bytes to hop before each read. Default of 0 means using -// record_bytes. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["hop_bytes"] = value - } -} - -// FixedLengthRecordReaderV2Container sets the optional container attribute to value. +// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value +// For example: +// +// ``` +// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +// y, idx, count = unique_with_counts(x) +// y ==> [1, 2, 4, 7, 8] +// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +// count ==> [2, 1, 3, 1, 2] +// ``` +// +// Arguments: +// x: 1-D. +// +// Returns 1-D.1-D.1-D. +func UniqueWithCounts(scope *Scope, x tf.Output, optional ...UniqueWithCountsAttr) (y tf.Output, idx tf.Output, count tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UniqueWithCounts", + Input: []tf.Input{ + x, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} +// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. +type StatelessRandomNormalAttr func(optionalAttr) -// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. +// StatelessRandomNormalDtype sets the optional dtype attribute to value. // -// value: The type of encoding for the file. Currently ZLIB and GZIP -// are supported. Defaults to none. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { return func(m optionalAttr) { - m["encoding"] = value + m["dtype"] = value } } -// A Reader that outputs fixed-length records from a file. +// Outputs deterministic pseudorandom values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// record_bytes: Number of bytes in the record. +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). // -// Returns The handle to reference the Reader. -func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { +// Returns Random values with specified shape. +func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"record_bytes": record_bytes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FixedLengthRecordReaderV2", - + Type: "StatelessRandomNormal", + Input: []tf.Input{ + shape, seed, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// The gradient operator for the SparseAdd op. +// Reshapes a quantized tensor as per the Reshape op. // -// The SparseAdd op calculates A + B, where A, B, and the sum are all represented -// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. -// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty -// values of A and B. +// ``` // // Arguments: -// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to -// the non-empty values of the sum. -// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. -// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. -// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size -// `[nnz(sum), ndims]`. // -// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the -// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the -// non-empty values of B. -func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { +// shape: Defines the shape of the output tensor. +// input_min: The minimum value of the input. +// input_max: The maximum value of the input. +// +// Returns This value is copied from input_min.This value is copied from input_max. +func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAddGrad", + Type: "QuantizedReshape", Input: []tf.Input{ - backprop_val_grad, a_indices, b_indices, sum_indices, + tensor, shape, input_min, input_max, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes atan of x element-wise. -func Atan(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atan", - Input: []tf.Input{ - x, - }, +// GatherAttr is an optional argument to Gather. +type GatherAttr func(optionalAttr) + +// GatherValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func GatherValidateIndices(value bool) GatherAttr { + return func(m optionalAttr) { + m["validate_indices"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Encode audio data using the WAV file format. +// Gather slices from `params` according to `indices`. // -// This operation will generate a string suitable to be saved out to create a .wav -// audio file. It will be encoded in the 16-bit PCM format. It takes in float -// values in the range -1.0f to 1.0f, and any outside that value will be clamped to -// that range. +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: // -// `audio` is a 2-D float Tensor of shape `[length, channels]`. -// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). +// ```python +// # Scalar indices +// output[:, ..., :] = params[indices, :, ... :] // -// Arguments: -// audio: 2-D with shape `[length, channels]`. -// sample_rate: Scalar containing the sample frequency. +// # Vector indices +// output[i, :, ..., :] = params[indices[i], :, ... :] // -// Returns 0-D. WAV-encoded file contents. -func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { +// # Higher rank indices +// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] +// ``` +// +// If `indices` is a permutation and `len(indices) == params.shape[0]` then +// this operation will permute `params` accordingly. +// +// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in +// `indices` are always validated to be within range. If assigned to GPU, +// out-of-bound indices result in safe but unspecified behavior, which may include +// raising an error. +// +//
+// +//
+func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...GatherAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "EncodeWav", + Type: "Gather", Input: []tf.Input{ - audio, sample_rate, + params, indices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. The hash function is a keyed hash function, where attribute `key` -// defines the key of the hash function. `key` is an array of 2 elements. -// -// A strong hash is important when inputs may be malicious, e.g. URLs with -// additional components. Adversaries could try to make their inputs hash to the -// same bucket for a denial-of-service attack or to skew the results. A strong -// hash prevents this by making it difficult, if not infeasible, to compute inputs -// that hash to the same bucket. This comes at a cost of roughly 4x higher compute -// time than `tf.string_to_hash_bucket_fast`. -// -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// key: The key for the keyed hash function passed as a list of two uint64 -// elements. +// Returns the truth value of (x != y) element-wise. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { +// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} opspec := tf.OpSpec{ - Type: "StringToHashBucketStrong", + Type: "NotEqual", Input: []tf.Input{ - input, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Generates values in an interval. +// Inverse 3D real-valued fast Fourier transform. // -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. +// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 3 dimensions of `input`. // -// For example: +// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 3 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. // -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` +// Along each axis `IRFFT3D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// start: First entry in the range. -// stop: Last entry in the range. -// num: Number of values to generate. +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. // -// Returns 1-D. The generated values. -func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { +// Returns A float32 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 3D real Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.irfftn with 3 dimensions. +// @end_compatibility +func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LinSpace", + Type: "IRFFT3D", Input: []tf.Input{ - start, stop, num, + input, fft_length, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) +// StringSplitAttr is an optional argument to StringSplit. +type StringSplitAttr func(optionalAttr) -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. +// StringSplitSkipEmpty sets the optional skip_empty attribute to value. // -// value: whether to ignore the error when the resource -// doesn't exist. +// value: A `bool`. If `True`, skip the empty strings from the result. // If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { +func StringSplitSkipEmpty(value bool) StringSplitAttr { return func(m optionalAttr) { - m["ignore_lookup_error"] = value + m["skip_empty"] = value } } -// Deletes the resource specified by the handle. +// Split elements of `input` based on `delimiter` into a `SparseTensor`. // -// All subsequent operations using the resource will result in a NotFound -// error status. +// Let N be the size of source (typically N will be the batch size). Split each +// element of `input` based on `delimiter` and return a `SparseTensor` +// containing the splitted tokens. Empty tokens are ignored. +// +// `delimiter` can be empty, or a string of split characters. If `delimiter` is an +// empty string, each element of `input` is split into individual single-byte +// character strings, including splitting of UTF-8 multibyte sequences. Otherwise +// every character of `delimiter` is a potential split point. +// +// For example: +// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output +// will be +// +// indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// shape = [2, 3] +// values = ['hello', 'world', 'a', 'b', 'c'] // // Arguments: -// resource: handle to the resource to delete. +// input: 1-D. Strings to split. +// delimiter: 0-D. Delimiter characters (bytes), or empty string. // -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { +// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse +// tensor, where the first value is N and the second value is the maximum number +// of tokens in a single input entry. +func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { if scope.Err() != nil { return } @@ -7016,76 +6624,43 @@ func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyReso a(attrs) } opspec := tf.OpSpec{ - Type: "DestroyResourceOp", + Type: "StringSplit", Input: []tf.Input{ - resource, + input, delimiter, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// CumprodAttr is an optional argument to Cumprod. -type CumprodAttr func(optionalAttr) - -// CumprodExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumprod. -// If not specified, defaults to false -func CumprodExclusive(value bool) CumprodAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} +// ResizeBilinearAttr is an optional argument to ResizeBilinear. +type ResizeBilinearAttr func(optionalAttr) -// CumprodReverse sets the optional reverse attribute to value. +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. // -// value: A `bool` (default: False). +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. // If not specified, defaults to false -func CumprodReverse(value bool) CumprodAttr { +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { return func(m optionalAttr) { - m["reverse"] = value + m["align_corners"] = value } } -// Compute the cumulative product of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumprod, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is -// performed instead: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumprod is performed in the -// opposite direction: -// -// ```python -// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: +// Resize `images` to `size` using bilinear interpolation. // -// ```python -// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] -// ``` +// Input images can be of different types but output images are always float. // // Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { if scope.Err() != nil { return } @@ -7094,9 +6669,9 @@ func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) a(attrs) } opspec := tf.OpSpec{ - Type: "Cumprod", + Type: "ResizeBilinear", Input: []tf.Input{ - x, axis, + images, size, }, Attrs: attrs, } @@ -7104,395 +6679,391 @@ func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) return op.Output(0) } -// Computes the mean along segments of a tensor. -// -// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of -// segments. -// -// Computes a tensor such that -// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is -// over `j` such that `segment_ids[j] == i` and `N` is the total number of -// values summed. -// -// If the mean is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -//
-// -// Arguments: -// -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// Computes softsign: `features / (abs(features) + 1)`. +func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentMean", + Type: "Softsign", Input: []tf.Input{ - data, segment_ids, + features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. -type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) +// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. +type GenerateVocabRemappingAttr func(optionalAttr) -// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. // -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { +// value: Number of entries in the old vocab file to consider. If -1, +// use the entire old vocabulary. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["old_vocab_size"] = value } } -// Update '*var' according to the centered RMSProp algorithm. +// Given a path to new and old vocabulary files, returns a remapping Tensor of // -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. +// length `num_new_vocab`, where `remapping[i]` contains the row number in the old +// vocabulary that corresponds to row `i` in the new vocabulary (starting at line +// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` +// in the new vocabulary is not in the old vocabulary. The old vocabulary is +// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the +// default value of -1. // -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. +// `num_vocab_offset` enables +// use in the partitioned variable case, and should generally be set through +// examining partitioning info. The format of the files should be a text file, +// with each line containing a single entity within the vocabulary. // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// For example, with `new_vocab_file` a text file containing each of the following +// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], +// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be +// `[0, -1, 2]`. // -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// The op also returns a count of how many entries in the new vocabulary +// were present in the old vocabulary, which is used to calculate the number of +// values to initialize in a weight matrix remapping // -// Arguments: -// var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. +// This functionality can be used to remap both row vocabularies (typically, +// features) and column vocabularies (typically, classes) from TensorFlow +// checkpoints. Note that the partitioning logic relies on contiguous vocabularies +// corresponding to div-partitioned variables. Moreover, the underlying remapping +// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should +// use the corresponding index_table_from_file() as the FeatureColumn framework +// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). // -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. +// Arguments: +// new_vocab_file: Path to the new vocab file. +// old_vocab_file: Path to the old vocab file. +// new_vocab_offset: How many entries into the new vocab file to start reading. +// num_new_vocab: Number of entries in the new vocab file to remap. // -// Returns the created operation. -func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { +// Returns A Tensor of length num_new_vocab where the element at index i +// is equal to the old ID that maps to the new ID i. This element is -1 for any +// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. +func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyCenteredRMSProp", + Type: "GenerateVocabRemapping", Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, + new_vocab_file, old_vocab_file, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Creates a dataset that batches `batch_size` elements from `input_dataset`. +// Assigns sparse updates to the variable referenced by `resource`. // -// Arguments: +// This operation computes // -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. +// # Scalar indices +// ref[indices, ...] = updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] = updates[i, ...] // +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] // -func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "BatchDataset", + Type: "ResourceScatterUpdate", Input: []tf.Input{ - input_dataset, batch_size, + resource, indices, updates, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Inverse fast Fourier transform. +// AvgPoolGradAttr is an optional argument to AvgPoolGrad. +type AvgPoolGradAttr func(optionalAttr) + +// AvgPoolGradDataFormat sets the optional data_format attribute to value. // -// Computes the inverse 1-dimensional discrete Fourier transform over the -// inner-most dimension of `input`. +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the average pooling function. // // Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier transform. +// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. +// the output of `avg_pool`. +// ksize: The size of the sliding window for each dimension of the input. +// strides: The stride of the sliding window for each dimension of the input. +// padding: The type of padding algorithm to use. // -// @compatibility(numpy) -// Equivalent to np.fft.ifft -// @end_compatibility -func IFFT(scope *Scope, input tf.Output) (output tf.Output) { +// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. +func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IFFT", + Type: "AvgPoolGrad", Input: []tf.Input{ - input, + orig_input_shape, grad, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// LRNAttr is an optional argument to LRN. -type LRNAttr func(optionalAttr) +// StageClearAttr is an optional argument to StageClear. +type StageClearAttr func(optionalAttr) -// LRNDepthRadius sets the optional depth_radius attribute to value. +// StageClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: 0-D. Half-width of the 1-D normalization window. -// If not specified, defaults to 5 -func LRNDepthRadius(value int64) LRNAttr { +// REQUIRES: value >= 0 +func StageClearCapacity(value int64) StageClearAttr { return func(m optionalAttr) { - m["depth_radius"] = value + m["capacity"] = value } } -// LRNBias sets the optional bias attribute to value. +// StageClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: An offset (usually positive to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNBias(value float32) LRNAttr { +// REQUIRES: value >= 0 +func StageClearMemoryLimit(value int64) StageClearAttr { return func(m optionalAttr) { - m["bias"] = value + m["memory_limit"] = value } } -// LRNAlpha sets the optional alpha attribute to value. -// -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNAlpha(value float32) LRNAttr { +// StageClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StageClearContainer(value string) StageClearAttr { return func(m optionalAttr) { - m["alpha"] = value + m["container"] = value } } -// LRNBeta sets the optional beta attribute to value. -// -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNBeta(value float32) LRNAttr { +// StageClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StageClearSharedName(value string) StageClearAttr { return func(m optionalAttr) { - m["beta"] = value + m["shared_name"] = value } } -// Local Response Normalization. -// -// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last -// dimension), and each vector is normalized independently. Within a given vector, -// each component is divided by the weighted, squared sum of inputs within -// `depth_radius`. In detail, -// -// sqr_sum[a, b, c, d] = -// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) -// output = input / (bias + alpha * sqr_sum) ** beta -// -// For details, see [Krizhevsky et al., ImageNet classification with deep -// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// Op removes all elements in the underlying container. // -// Arguments: -// input: 4-D. -func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { +// Returns the created operation. +func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LRN", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} + Type: "StageClear", -// Creates a dataset that zips together `input_datasets`. -func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ZipDataset", - Input: []tf.Input{ - tf.OutputList(input_datasets), - }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Writes a `GraphDef` protocol buffer to a `SummaryWriter`. -// -// Arguments: -// writer: Handle of `SummaryWriter`. -// step: The step to write the summary for. -// tensor: A scalar string of the serialized tf.GraphDef proto. +// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. +type ComputeAccidentalHitsAttr func(optionalAttr) + +// ComputeAccidentalHitsSeed sets the optional seed attribute to value. // -// Returns the created operation. -func WriteGraphSummary(scope *Scope, writer tf.Output, step tf.Output, tensor tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WriteGraphSummary", - Input: []tf.Input{ - writer, step, tensor, - }, +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { + return func(m optionalAttr) { + m["seed"] = value } - return scope.AddOperation(opspec) } -// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. -type ResourceSparseApplyAdagradAttr func(optionalAttr) - -// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. +// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["seed2"] = value } } -// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// Computes the ids of the positions in sampled_candidates that match true_labels. // -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// When doing log-odds NCE, the result of this op should be passed through a +// SparseToDense op, then added to the logits of the sampled candidates. This has +// the effect of 'removing' the sampled labels that match the true labels by +// making the classifier sure that they are sampled labels. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// true_classes: The true_classes output of UnpackSparseLabels. +// sampled_candidates: The sampled_candidates output of CandidateSampler. +// num_true: Number of true labels per context. // -// Returns the created operation. -func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { +// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label +// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element +// is -FLOAT_MAX. +func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagrad", + Type: "ComputeAccidentalHits", Input: []tf.Input{ - var_, accum, lr, grad, indices, + true_classes, sampled_candidates, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// 2D real-valued fast Fourier transform. +// CumsumAttr is an optional argument to Cumsum. +type CumsumAttr func(optionalAttr) + +// CumsumExclusive sets the optional exclusive attribute to value. // -// Computes the 2-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 2 dimensions of `input`. +// value: If `True`, perform exclusive cumsum. +// If not specified, defaults to false +func CumsumExclusive(value bool) CumsumAttr { + return func(m optionalAttr) { + m["exclusive"] = value + } +} + +// CumsumReverse sets the optional reverse attribute to value. // -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumsumReverse(value bool) CumsumAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Compute the cumulative sum of the tensor `x` along `axis`. // -// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// By default, this op performs an inclusive cumsum, which means that the first +// element of the input is identical to the first element of the output: // -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. +// ```python +// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] +// ``` // -// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. +// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +// performed instead: // -// @compatibility(numpy) -// Equivalent to np.fft.rfft2 -// @end_compatibility -func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// ```python +// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumsum is performed in the +// opposite direction: +// +// ```python +// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RFFT2D", + Type: "Cumsum", Input: []tf.Input{ - input, fft_length, + x, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeAreaAttr is an optional argument to ResizeArea. -type ResizeAreaAttr func(optionalAttr) +// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. +type QuantizedRelu6Attr func(optionalAttr) -// ResizeAreaAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { +// QuantizedRelu6OutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { return func(m optionalAttr) { - m["align_corners"] = value + m["out_type"] = value } } -// Resize `images` to `size` using area interpolation. -// -// Input images can be of different types but output images are always float. -// -// Each output pixel is computed by first transforming the pixel's footprint into -// the input tensor and then averaging the pixels that intersect the footprint. An -// input pixel's contribution to the average is weighted by the fraction of its -// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. +// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { if scope.Err() != nil { return } @@ -7501,102 +7072,129 @@ func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...Resi a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeArea", + Type: "QuantizedRelu6", Input: []tf.Input{ - images, size, + features, min_features, max_features, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. -type StatelessRandomUniformAttr func(optionalAttr) +// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. +type FixedLengthRecordReaderV2Attr func(optionalAttr) -// StatelessRandomUniformDtype sets the optional dtype attribute to value. +// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { +// value: Number of bytes in the header, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { return func(m optionalAttr) { - m["dtype"] = value + m["header_bytes"] = value } } -// Outputs deterministic pseudorandom random values from a uniform distribution. +// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. // -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// value: Number of bytes in the footer, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["footer_bytes"] = value + } +} + +// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. // -// The outputs are a deterministic function of `shape` and `seed`. +// value: Number of bytes to hop before each read. Default of 0 means using +// record_bytes. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["hop_bytes"] = value + } +} + +// FixedLengthRecordReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. +// +// value: The type of encoding for the file. Currently ZLIB and GZIP +// are supported. Defaults to none. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["encoding"] = value + } +} + +// A Reader that outputs fixed-length records from a file. // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// record_bytes: Number of bytes in the record. // -// Returns Random values with specified shape. -func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { +// Returns The handle to reference the Reader. +func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"record_bytes": record_bytes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomUniform", - Input: []tf.Input{ - shape, seed, - }, + Type: "FixedLengthRecordReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AngleAttr is an optional argument to Angle. -type AngleAttr func(optionalAttr) - -// AngleTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func AngleTout(value tf.DataType) AngleAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Returns the argument of a complex number. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the argument of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part. +// The hash function is deterministic on the content of the string within the +// process. // -// The argument returned by this operation is of the form \\(atan2(b, a)\\). +// Note that the hash function may change from time to time. +// This functionality will be deprecated and it's recommended to use +// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. // -// For example: +// Arguments: // -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.angle(input) ==> [2.0132, 1.056] -// ``` +// num_buckets: The number of buckets. // -// @compatibility(numpy) -// Equivalent to np.angle. -// @end_compatibility -func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "Angle", + Type: "StringToHashBucket", Input: []tf.Input{ - input, + string_tensor, }, Attrs: attrs, } @@ -7604,411 +7202,350 @@ func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Outp return op.Output(0) } -// VarHandleOpAttr is an optional argument to VarHandleOp. -type VarHandleOpAttr func(optionalAttr) - -// VarHandleOpContainer sets the optional container attribute to value. +// Computes gradients for the exponential linear (Elu) operation. // -// value: the container this variable is placed in. -// If not specified, defaults to "" -func VarHandleOpContainer(value string) VarHandleOpAttr { - return func(m optionalAttr) { - m["container"] = value +// Arguments: +// gradients: The backpropagated gradients to the corresponding Elu operation. +// outputs: The outputs of the corresponding Elu operation. +// +// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, +// `gradients` otherwise. +func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "EluGrad", + Input: []tf.Input{ + gradients, outputs, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// VarHandleOpSharedName sets the optional shared_name attribute to value. +// Creates a dataset that contains `count` elements from the `input_dataset`. // -// value: the name by which this variable is referred to. -// If not specified, defaults to "" -func VarHandleOpSharedName(value string) VarHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be taken. A value of `-1` indicates that all of `input_dataset` +// is taken. +// +// +func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TakeDataset", + Input: []tf.Input{ + input_dataset, count, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Creates a handle to a Variable resource. +// The gradient operator for the SparseAdd op. +// +// The SparseAdd op calculates A + B, where A, B, and the sum are all represented +// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. +// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty +// values of A and B. // // Arguments: -// dtype: the type of this variable. Must agree with the dtypes -// of all ops using this variable. -// shape: The (possibly partially specified) shape of this variable. -func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { +// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to +// the non-empty values of the sum. +// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. +// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. +// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size +// `[nnz(sum), ndims]`. +// +// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the +// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the +// non-empty values of B. +func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "VarHandleOp", - - Attrs: attrs, + Type: "SparseAddGrad", + Input: []tf.Input{ + backprop_val_grad, a_indices, b_indices, sum_indices, + }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Elementwise computes the bitwise XOR of `x` and `y`. -// -// The result will have those bits set, that are different in `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Computes atan of x element-wise. +func Atan(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BitwiseXor", + Type: "Atan", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Deserialize `SparseTensor` objects. -// -// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where -// the last dimension stores serialized `SparseTensor` objects and the other N -// dimensions (N >= 0) correspond to a batch. The ranks of the original -// `SparseTensor` objects must all match. When the final `SparseTensor` is -// created, its rank is the rank of the incoming `SparseTensor` objects plus N; -// the sparse tensors have been concatenated along new dimensions, one for each -// batch. -// -// The output `SparseTensor` object's shape values for the original dimensions -// are the max across the input `SparseTensor` objects' shape values for the -// corresponding dimensions. The new dimensions match the size of the batch. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: -// -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// -// and -// -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] +// Encode audio data using the WAV file format. // -// then the final deserialized `SparseTensor` will be: +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. // -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // // Arguments: -// serialized_sparse: The serialized `SparseTensor` objects. The last dimension -// must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. +// +// Returns 0-D. WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "DeserializeSparse", + Type: "EncodeWav", Input: []tf.Input{ - serialized_sparse, + audio, sample_rate, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. -type ResourceApplyRMSPropAttr func(optionalAttr) - -// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } + return op.Output(0) } -// Update '*var' according to the RMSProp algorithm. -// -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// The hash function is deterministic on the content of the string within the +// process. The hash function is a keyed hash function, where attribute `key` +// defines the key of the hash function. `key` is an array of 2 elements. // -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom +// A strong hash is important when inputs may be malicious, e.g. URLs with +// additional components. Adversaries could try to make their inputs hash to the +// same bucket for a denial-of-service attack or to skew the results. A strong +// hash prevents this by making it difficult, if not infeasible, to compute inputs +// that hash to the same bucket. This comes at a cost of roughly 4x higher compute +// time than `tf.string_to_hash_bucket_fast`. // // Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// key: The key for the keyed hash function passed as a list of two uint64 +// elements. // -// Returns the created operation. -func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} opspec := tf.OpSpec{ - Type: "ResourceApplyRMSProp", + Type: "StringToHashBucketStrong", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, + input, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// SizeAttr is an optional argument to Size. -type SizeAttr func(optionalAttr) - -// SizeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func SizeOutType(value tf.DataType) SizeAttr { - return func(m optionalAttr) { - m["out_type"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the size of a tensor. +// Applies softmax to a batched N-D `SparseTensor`. // -// This operation returns an integer representing the number of elements in -// `input`. +// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` +// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. // -// For example: +// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost +// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly +// zero elements do not participate*. Specifically, the algorithm is equivalent +// to the following: // -// ``` -// # 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]] -// size(t) ==> 12 -// ``` -func Size(scope *Scope, input tf.Output, optional ...SizeAttr) (output tf.Output) { +// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix +// with shape `[B, C]`, along the size-C dimension; +// (2) Masks out the original implicitly-zero locations; +// (3) Renormalizes the remaining elements. +// +// Hence, the `SparseTensor` result has exactly the same non-zero indices and +// shape. +// +// Arguments: +// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a +// SparseTensor, in canonical ordering. +// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// +// Returns 1-D. The `NNZ` values for the result `SparseTensor`. +func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Size", + Type: "SparseSoftmax", Input: []tf.Input{ - input, + sp_indices, sp_values, sp_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. -type ResourceScatterNdUpdateAttr func(optionalAttr) - -// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. -// -// value: An optional bool. Defaults to True. If True, the assignment will -// be protected by a lock; otherwise the behavior is undefined, -// but may exhibit less contention. -// If not specified, defaults to true -func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Applies sparse `updates` to individual values or slices within a given -// -// variable according to `indices`. -// -// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `ref`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// Partitions `data` into `num_partitions` tensors using indices from `partitions`. // -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -// dimension of `ref`. +// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +// are placed in `outputs[i]` in lexicographic order of `js`, and the first +// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +// In detail, // -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// ```python +// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] // -// ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) // ``` // -// For example, say we want to update 4 scattered elements to a rank-1 tensor to -// 8 elements. In Python, that update would look like this: +// `data.shape` must start with `partitions.shape`. +// +// For example: // // ```python -// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1] ,[7]]) -// updates = tf.constant([9, 10, 11, 12]) -// update = tf.scatter_nd_update(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(update) -// ``` +// # Scalar partitions. +// partitions = 1 +// num_partitions = 2 +// data = [10, 20] +// outputs[0] = [] # Empty with shape [0, 2] +// outputs[1] = [[10, 20]] // -// The resulting update to ref would look like this: +// # Vector partitions. +// partitions = [0, 0, 1, 1, 0] +// num_partitions = 2 +// data = [10, 20, 30, 40, 50] +// outputs[0] = [10, 20, 50] +// outputs[1] = [30, 40] +// ``` // -// [1, 11, 3, 10, 9, 6, 7, 12] +// See `dynamic_stitch` for an example on how to merge partitions back. // -// See @{tf.scatter_nd} for more details about how to make updates to -// slices. +//
+// +//
// // Arguments: -// ref: A resource handle. Must be from a VarHandleOp. -// indices: A Tensor. Must be one of the following types: int32, int64. -// A tensor of indices into ref. -// updates: A Tensor. Must have the same type as ref. A tensor of updated -// values to add to ref. // -// Returns the created operation. -func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) { +// partitions: Any shape. Indices in the range `[0, num_partitions)`. +// num_partitions: The number of partitions to output. +func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"num_partitions": num_partitions} opspec := tf.OpSpec{ - Type: "ResourceScatterNdUpdate", + Type: "DynamicPartition", Input: []tf.Input{ - ref, indices, updates, + data, partitions, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// StageSizeAttr is an optional argument to StageSize. -type StageSizeAttr func(optionalAttr) - -// StageSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageSizeCapacity(value int64) StageSizeAttr { - return func(m optionalAttr) { - m["capacity"] = value + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return } -} - -// StageSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageSizeMemoryLimit(value int64) StageSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("DynamicPartition", err) + return } + return outputs } -// StageSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageSizeContainer(value string) StageSizeAttr { - return func(m optionalAttr) { - m["container"] = value - } -} +// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. +type ResourceApplyAdagradAttr func(optionalAttr) -// StageSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageSizeSharedName(value string) StageSizeAttr { +// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["use_locking"] = value } } -// Op returns the number of elements in the underlying container. -func StageSize(scope *Scope, dtypes []tf.DataType, optional ...StageSizeAttr) (size tf.Output) { +// Update '*var' according to the adagrad scheme. +// +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StageSize", - + Type: "ResourceApplyAdagrad", + Input: []tf.Input{ + var_, accum, lr, grad, + }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. -type NonMaxSuppressionAttr func(optionalAttr) +// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. +type ResourceApplyPowerSignAttr func(optionalAttr) -// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. +// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. // -// value: A float representing the threshold for deciding whether boxes -// overlap too much with respect to IOU. -// If not specified, defaults to 0.5 -func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { return func(m optionalAttr) { - m["iou_threshold"] = value + m["use_locking"] = value } } -// Greedily selects a subset of bounding boxes in descending order of score, +// Update '*var' according to the AddSign update. // -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g +// variable <- variable - lr_t * update // // Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// logbase: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { +// Returns the created operation. +func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -8017,101 +7554,76 @@ func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_outp a(attrs) } opspec := tf.OpSpec{ - Type: "NonMaxSuppression", + Type: "ResourceApplyPowerSign", Input: []tf.Input{ - boxes, scores, max_output_size, + var_, m, lr, logbase, sign_decay, beta, grad, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset that emits `components` as a tuple of tensors once. -func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TensorDataset", - Input: []tf.Input{ - tf.OutputList(components), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// CumprodAttr is an optional argument to Cumprod. +type CumprodAttr func(optionalAttr) -// Component-wise multiplies a SparseTensor by a dense Tensor. -// -// The output locations corresponding to the implicitly zero elements in the sparse -// tensor will be zero (i.e., will not take up storage space), regardless of the -// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). -// -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. -// -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// CumprodExclusive sets the optional exclusive attribute to value. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseDenseCwiseMul", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, - }, +// value: If `True`, perform exclusive cumprod. +// If not specified, defaults to false +func CumprodExclusive(value bool) CumprodAttr { + return func(m optionalAttr) { + m["exclusive"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) - -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// CumprodReverse sets the optional reverse attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. +// value: A `bool` (default: False). // If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { +func CumprodReverse(value bool) CumprodAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["reverse"] = value } } -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// Compute the cumulative product of the tensor `x` along `axis`. +// +// By default, this op performs an inclusive cumprod, which means that the first +// element of the input is identical to the first element of the output: +// +// ```python +// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +// ``` +// +// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +// performed instead: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumprod is performed in the +// opposite direction: // -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// ```python +// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +// ``` // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// This is more efficient than using separate `tf.reverse` ops. // -// Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { if scope.Err() != nil { return } @@ -8120,57 +7632,82 @@ func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, line a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", + Type: "Cumprod", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, + x, axis, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns which elements of x are Inf. +// Computes the mean along segments of a tensor. // -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { +// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +// segments. +// +// Computes a tensor such that +// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +// over `j` such that `segment_ids[j] == i` and `N` is the total number of +// values summed. +// +// If the mean is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
+// +// Arguments: +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsInf", + Type: "SegmentMean", Input: []tf.Input{ - x, + data, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. -type ResourceSparseApplyRMSPropAttr func(optionalAttr) +// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. +type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) -// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { +func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the RMSProp algorithm. +// Update '*var' according to the centered RMSProp algorithm. // -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. +// +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, // and mom will not update in iterations during which the grad is zero. // // mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// mean_grad = decay * mean_grad + (1-decay) * gradient +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) // // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -8178,6 +7715,7 @@ func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSProp // // Arguments: // var_: Should be from a Variable(). +// mg: Should be from a Variable(). // ms: Should be from a Variable(). // mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. @@ -8188,7 +7726,7 @@ func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSProp // indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. -func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { +func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -8197,168 +7735,200 @@ func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyRMSProp", + Type: "ResourceSparseApplyCenteredRMSProp", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns the truth value of (x > y) element-wise. +// Creates a dataset that batches `batch_size` elements from `input_dataset`. // -// *NOTE*: `Greater` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// +// +func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Greater", + Type: "BatchDataset", Input: []tf.Input{ - x, y, + input_dataset, batch_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. -type SampleDistortedBoundingBoxAttr func(optionalAttr) - -// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. +// Inverse fast Fourier transform. // -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed"] = value +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its inverse 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft +// @end_compatibility +func IFFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT", + Input: []tf.Input{ + input, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. +// Generates values in an interval. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["seed2"] = value +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +// so that the last one is exactly `stop`. +// +// For example: +// +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +// ``` +// +// Arguments: +// start: First entry in the range. +// stop: Last entry in the range. +// num: Number of values to generate. +// +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LinSpace", + Input: []tf.Input{ + start, stop, num, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) + +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. // -// value: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. -// If not specified, defaults to 0.1 -func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { +// value: whether to ignore the error when the resource +// doesn't exist. +// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { return func(m optionalAttr) { - m["min_object_covered"] = value + m["ignore_lookup_error"] = value } } -// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. +// Deletes the resource specified by the handle. // -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { - return func(m optionalAttr) { - m["aspect_ratio_range"] = value +// All subsequent operations using the resource will result in a NotFound +// error status. +// +// Arguments: +// resource: handle to the resource to delete. +// +// Returns the created operation. +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DestroyResourceOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, } + return scope.AddOperation(opspec) } -// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. +// LRNAttr is an optional argument to LRN. +type LRNAttr func(optionalAttr) + +// LRNDepthRadius sets the optional depth_radius attribute to value. // -// value: The cropped area of the image must contain a fraction of the -// supplied image within in this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { +// value: 0-D. Half-width of the 1-D normalization window. +// If not specified, defaults to 5 +func LRNDepthRadius(value int64) LRNAttr { return func(m optionalAttr) { - m["area_range"] = value + m["depth_radius"] = value } } -// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. +// LRNBias sets the optional bias attribute to value. // -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { +// value: An offset (usually positive to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNBias(value float32) LRNAttr { return func(m optionalAttr) { - m["max_attempts"] = value + m["bias"] = value } } -// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// LRNAlpha sets the optional alpha attribute to value. // -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. -// If not specified, defaults to false -func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNAlpha(value float32) LRNAttr { return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value + m["alpha"] = value } } -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. -// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. -// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, +// LRNBeta sets the optional beta attribute to value. // -// ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNBeta(value float32) LRNAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Local Response Normalization. // -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) +// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last +// dimension), and each vector is normalized independently. Within a given vector, +// each component is divided by the weighted, squared sum of inputs within +// `depth_radius`. In detail, // -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) -// ``` +// sqr_sum[a, b, c, d] = +// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) +// output = input / (bias + alpha * sqr_sum) ** beta // -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. +// For details, see [Krizhevsky et al., ImageNet classification with deep +// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). // // Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. -// -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// input: 4-D. +func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8367,167 +7937,152 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBox", + Type: "LRN", Input: []tf.Input{ - image_size, bounding_boxes, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns x / y element-wise for integer types. -// -// Truncation designates that negative numbers will round fractional quantities -// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different -// than Python semantics. See `FloorDiv` for a division function that matches -// Python Semantics. -// -// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Creates a dataset that zips together `input_datasets`. +func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TruncateDiv", + Type: "ZipDataset", Input: []tf.Input{ - x, y, + tf.OutputList(input_datasets), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Restores tensors from a V2 checkpoint. -// -// For backward compatibility with the V1 format, this Op currently allows -// restoring from a V1 checkpoint as well: -// - This Op first attempts to find the V2 index file pointed to by "prefix", and -// if found proceed to read it as a V2 checkpoint; -// - Otherwise the V1 read path is invoked. -// Relying on this behavior is not recommended, as the ability to fall back to read -// V1 might be deprecated and eventually removed. +// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. +type ResourceSparseApplyAdagradAttr func(optionalAttr) + +// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. // -// By default, restores the named tensors in full. If the caller wishes to restore -// specific slices of stored tensors, "shape_and_slices" should be non-empty -// strings and correspondingly well-formed. +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. // -// Callers must ensure all the named tensors are indeed stored in the checkpoint. +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) // // Arguments: -// prefix: Must have a single element. The prefix of a V2 checkpoint. -// tensor_names: shape {N}. The names of the tensors to be restored. -// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. -// Empty strings indicate that they are non-partitioned tensors. -// dtypes: shape {N}. The list of expected dtype for the tensors. Must match -// those stored in the checkpoint. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // -// Returns shape {N}. The restored tensors, whose shapes are read from the -// checkpoint directly. -func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { +// Returns the created operation. +func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RestoreV2", + Type: "ResourceSparseApplyAdagrad", Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, + var_, accum, lr, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { - scope.UpdateErr("RestoreV2", err) - return - } - return tensors + return scope.AddOperation(opspec) } -// Decode web-safe base64-encoded strings. +// 2D real-valued fast Fourier transform. // -// Input may or may not have padding at the end. See EncodeBase64 for padding. -// Web-safe means that input must use - and _ instead of + and /. +// Computes the 2-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 2 dimensions of `input`. // -// Arguments: -// input: Base64 strings to decode. +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. // -// Returns Decoded strings. -func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeBase64", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Store the input tensor in the state of the current session. +// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: -// value: The tensor to be stored. +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. // -// Returns The handle for the tensor stored in the session state, represented -// as a string. -func GetSessionHandle(scope *Scope, value tf.Output) (handle tf.Output) { +// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfft2 +// @end_compatibility +func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GetSessionHandle", + Type: "RFFT2D", Input: []tf.Input{ - value, + input, fft_length, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. -type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) +// ResizeAreaAttr is an optional argument to ResizeArea. +type ResizeAreaAttr func(optionalAttr) -// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// ResizeAreaAlignCorners sets the optional align_corners attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. // If not specified, defaults to false -func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { +func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["align_corners"] = value } } -// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. +// Resize `images` to `size` using area interpolation. // -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// prox_v = var -// prox_v -= lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// Input images can be of different types but output images are always float. +// +// Each output pixel is computed by first transforming the pixel's footprint into +// the input tensor and then averaging the pixels that intersect the footprint. An +// input pixel's contribution to the average is weighted by the fraction of its +// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns the created operation. -func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { if scope.Err() != nil { return } @@ -8536,113 +8091,112 @@ func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.O a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalAdagrad", + Type: "ResizeArea", Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, indices, + images, size, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Returns element-wise largest integer not greater than x. -func Floor(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Floor", - Input: []tf.Input{ - x, - }, - } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the Gauss error function of `x` element-wise. -func Erf(scope *Scope, x tf.Output) (y tf.Output) { +// Pads a tensor with zeros. +// +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` +// in that dimension. +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Erf", + Type: "Pad", Input: []tf.Input{ - x, + input, paddings, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reads the value of a variable. -// -// The tensor returned by this operation is immutable. -// -// The value returned by this operation is guaranteed to be influenced by all the -// writes on which this operation depends directly or indirectly, and to not be -// influenced by any of the writes which depend directly or indirectly on this -// operation. +// Checks whether a resource handle-based variable has been initialized. // // Arguments: -// resource: handle to the resource in which to store the variable. -// dtype: the dtype of the value. -func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { +// resource: the input resource handle. +// +// Returns a scalar boolean which is true if the variable has been +// initialized. +func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ReadVariableOp", + Type: "VarIsInitializedOp", Input: []tf.Input{ resource, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. -type MaxPool3DGradAttr func(optionalAttr) +// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. +type StatelessRandomUniformAttr func(optionalAttr) -// MaxPool3DGradDataFormat sets the optional data_format attribute to value. +// StatelessRandomUniformDtype sets the optional dtype attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { return func(m optionalAttr) { - m["data_format"] = value + m["dtype"] = value } } -// Computes gradients of max pooling function. +// Outputs deterministic pseudorandom random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// The outputs are a deterministic function of `shape` and `seed`. // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. +func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3DGrad", + Type: "StatelessRandomUniform", Input: []tf.Input{ - orig_input, orig_output, grad, + shape, seed, }, Attrs: attrs, } @@ -8650,188 +8204,143 @@ func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, gr return op.Output(0) } -// SparseReduceSumAttr is an optional argument to SparseReduceSum. -type SparseReduceSumAttr func(optionalAttr) - -// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the sum of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// Makes its input available to the next iteration. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// data: The tensor to be made available to the next iteration. // -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { +// Returns The same tensor as `data`. +func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseReduceSum", + Type: "NextIteration", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + data, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Partitions `data` into `num_partitions` tensors using indices from `partitions`. -// -// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` -// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` -// are placed in `outputs[i]` in lexicographic order of `js`, and the first -// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. -// In detail, -// -// ```python -// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fact", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AngleAttr is an optional argument to Angle. +type AngleAttr func(optionalAttr) + +// AngleTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func AngleTout(value tf.DataType) AngleAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Returns the argument of a complex number. // -// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) -// ``` +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the argument of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part. // -// `data.shape` must start with `partitions.shape`. +// The argument returned by this operation is of the form \\(atan2(b, a)\\). // // For example: // -// ```python -// # Scalar partitions. -// partitions = 1 -// num_partitions = 2 -// data = [10, 20] -// outputs[0] = [] # Empty with shape [0, 2] -// outputs[1] = [[10, 20]] -// -// # Vector partitions. -// partitions = [0, 0, 1, 1, 0] -// num_partitions = 2 -// data = [10, 20, 30, 40, 50] -// outputs[0] = [10, 20, 50] -// outputs[1] = [30, 40] +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.angle(input) ==> [2.0132, 1.056] // ``` // -// See `dynamic_stitch` for an example on how to merge partitions back. -// -//
-// -//
-// -// Arguments: -// -// partitions: Any shape. Indices in the range `[0, num_partitions)`. -// num_partitions: The number of partitions to output. -func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { +// @compatibility(numpy) +// Equivalent to np.angle. +// @end_compatibility +func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_partitions": num_partitions} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DynamicPartition", + Type: "Angle", Input: []tf.Input{ - data, partitions, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("DynamicPartition", err) - return - } - return outputs + return op.Output(0) } -// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. -type ResourceApplyAdagradAttr func(optionalAttr) +// VarHandleOpAttr is an optional argument to VarHandleOp. +type VarHandleOpAttr func(optionalAttr) -// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. +// VarHandleOpContainer sets the optional container attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { +// value: the container this variable is placed in. +// If not specified, defaults to "" +func VarHandleOpContainer(value string) VarHandleOpAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["container"] = value } } -// Update '*var' according to the adagrad scheme. +// VarHandleOpSharedName sets the optional shared_name attribute to value. // -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// value: the name by which this variable is referred to. +// If not specified, defaults to "" +func VarHandleOpSharedName(value string) VarHandleOpAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a handle to a Variable resource. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { +// dtype: the type of this variable. Must agree with the dtypes +// of all ops using this variable. +// shape: The (possibly partially specified) shape of this variable. +func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdagrad", - Input: []tf.Input{ - var_, accum, lr, grad, - }, + Type: "VarHandleOp", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns element-wise remainder of division. This emulates C semantics in that -// -// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * -// y + truncate_mod(x, y) = x`. +// Elementwise computes the bitwise XOR of `x` and `y`. // -// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// The result will have those bits set, that are different in `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TruncateMod", + Type: "BitwiseXor", Input: []tf.Input{ x, y, }, @@ -8840,149 +8349,150 @@ func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Inverse 2D real-valued fast Fourier transform. +// Deserialize `SparseTensor` objects. // -// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 2 dimensions of `input`. +// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +// the last dimension stores serialized `SparseTensor` objects and the other N +// dimensions (N >= 0) correspond to a batch. The ranks of the original +// `SparseTensor` objects must all match. When the final `SparseTensor` is +// created, its rank is the rank of the incoming `SparseTensor` objects plus N; +// the sparse tensors have been concatenated along new dimensions, one for each +// batch. // -// The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 2 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. +// The output `SparseTensor` object's shape values for the original dimensions +// are the max across the input `SparseTensor` objects' shape values for the +// corresponding dimensions. The new dimensions match the size of the batch. // -// Along each axis `IRFFT2D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. // -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: // -// Returns A float32 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 2D Fourier transform. +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] // -// @compatibility(numpy) -// Equivalent to np.fft.irfft2 -// @end_compatibility -func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IRFFT2D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Transforms a vector of brain.Example protos (as strings) into typed tensors. +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] // // Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// names: A vector containing the names of the serialized protos. -// May contain, for example, table key (descriptive) names for the -// corresponding serialized protos. These are purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty vector if no names are available. -// If non-empty, this vector must be the same length as "serialized". -// sparse_keys: A list of Nsparse string Tensors (scalars). -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: A list of Ndense string Tensors (scalars). -// The keys expected in the Examples' features associated with dense values. -// dense_defaults: A list of Ndense Tensors (some may be empty). -// dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// sparse_types: A list of Nsparse types; the data types of data in each Feature -// given in sparse_keys. -// Currently the ParseExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature -// given in dense_keys. -// The number of elements in the Feature corresponding to dense_key[j] -// must always equal dense_shapes[j].NumEntries(). -// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output -// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): -// The dense outputs are just the inputs row-stacked by batch. -// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case -// the shape of the output Tensor dense_values[j] will be -// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks -// of elements of length D1 * .... * DN, across all minibatch entries -// in the input. Any minibatch entry with less than M blocks of elements of -// length D1 * ... * DN will be padded with the corresponding default_value -// scalar element along the second dimension. -func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { +// serialized_sparse: The serialized `SparseTensor` objects. The last dimension +// must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ParseExample", + Type: "DeserializeSparse", Input: []tf.Input{ - serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), + serialized_sparse, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseExample", err) - return + return op.Output(0), op.Output(1), op.Output(2) +} + +// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. +type ResourceApplyRMSPropAttr func(optionalAttr) + +// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseExample", err) +} + +// Update '*var' according to the RMSProp algorithm. +// +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { return } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseExample", err) - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseExample", err) - return + opspec := tf.OpSpec{ + Type: "ResourceApplyRMSProp", + Input: []tf.Input{ + var_, ms, mom, lr, rho, momentum, epsilon, grad, + }, + Attrs: attrs, } - return sparse_indices, sparse_values, sparse_shapes, dense_values + return scope.AddOperation(opspec) } -// VariableShapeAttr is an optional argument to VariableShape. -type VariableShapeAttr func(optionalAttr) +// SizeAttr is an optional argument to Size. +type SizeAttr func(optionalAttr) -// VariableShapeOutType sets the optional out_type attribute to value. +// SizeOutType sets the optional out_type attribute to value. // If not specified, defaults to DT_INT32 -func VariableShapeOutType(value tf.DataType) VariableShapeAttr { +func SizeOutType(value tf.DataType) SizeAttr { return func(m optionalAttr) { m["out_type"] = value } } -// Returns the shape of the variable pointed to by `resource`. +// Returns the size of a tensor. // -// This operation returns a 1-D integer tensor representing the shape of `input`. +// This operation returns an integer representing the number of elements in +// `input`. // // For example: // // ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] +// # 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]] +// size(t) ==> 12 // ``` -func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { +func Size(scope *Scope, input tf.Output, optional ...SizeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8991,7 +8501,7 @@ func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) a(attrs) } opspec := tf.OpSpec{ - Type: "VariableShape", + Type: "Size", Input: []tf.Input{ input, }, @@ -9001,232 +8511,250 @@ func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) return op.Output(0) } -// Fills empty rows in the input 2-D `SparseTensor` with a default value. +// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. +type ResourceScatterNdUpdateAttr func(optionalAttr) + +// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value. // -// The input `SparseTensor` is represented via the tuple of inputs -// (`indices`, `values`, `dense_shape`). The output `SparseTensor` has the -// same `dense_shape` but with indices `output_indices` and values -// `output_values`. +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. +// If not specified, defaults to true +func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Applies sparse `updates` to individual values or slices within a given // -// This op inserts a single entry for every row that doesn't have any values. -// The index is created as `[row, 0, ..., 0]` and the inserted value -// is `default_value`. +// variable according to `indices`. // -// For example, suppose `sp_input` has shape `[5, 6]` and non-empty values: +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. // -// [0, 1]: a -// [0, 3]: b -// [2, 0]: c -// [3, 1]: d +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. // -// Rows 1 and 4 are empty, so the output will be of shape `[5, 6]` with values: +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. // -// [0, 1]: a -// [0, 3]: b -// [1, 0]: default_value -// [2, 0]: c -// [3, 1]: d -// [4, 0]: default_value +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: // -// The output `SparseTensor` will be in row-major order and will have the -// same shape as the input. +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +// ``` // -// This op also returns an indicator vector shaped `[dense_shape[0]]` such that +// For example, say we want to update 4 scattered elements to a rank-1 tensor to +// 8 elements. In Python, that update would look like this: // -// empty_row_indicator[i] = True iff row i was an empty row. +// ```python +// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1] ,[7]]) +// updates = tf.constant([9, 10, 11, 12]) +// update = tf.scatter_nd_update(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(update) +// ``` // -// And a reverse index map vector shaped `[indices.shape[0]]` that is used during -// backpropagation, +// The resulting update to ref would look like this: // -// reverse_index_map[j] = out_j s.t. indices[j, :] == output_indices[out_j, :] +// [1, 11, 3, 10, 9, 6, 7, 12] +// +// See @{tf.scatter_nd} for more details about how to make updates to +// slices. // // Arguments: -// indices: 2-D. the indices of the sparse tensor. -// values: 1-D. the values of the sparse tensor. -// dense_shape: 1-D. the shape of the sparse tensor. -// default_value: 0-D. default value to insert into location `[row, 0, ..., 0]` -// for rows missing from the input sparse tensor. -// output indices: 2-D. the indices of the filled sparse tensor. +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of updated +// values to add to ref. // -// Returns 1-D. the values of the filled sparse tensor.1-D. whether the dense row was missing in the -// input sparse tensor.1-D. a map from the input indices to the output indices. -func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output, default_value tf.Output) (output_indices tf.Output, output_values tf.Output, empty_row_indicator tf.Output, reverse_index_map tf.Output) { +// Returns the created operation. +func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseFillEmptyRows", + Type: "ResourceScatterNdUpdate", Input: []tf.Input{ - indices, values, dense_shape, default_value, + ref, indices, updates, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) + return scope.AddOperation(opspec) } -// Reverses specific dimensions of a tensor. +// SqueezeAttr is an optional argument to Squeeze. +type SqueezeAttr func(optionalAttr) + +// SqueezeAxis sets the optional axis attribute to value. // -// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions -// of `tensor`, this operation reverses each dimension i of `tensor` where -// `dims[i]` is `True`. +// value: If specified, only squeezes the dimensions listed. The dimension +// index starts at 0. It is an error to squeeze a dimension that is not 1. Must +// be in the range `[-rank(input), rank(input))`. +// If not specified, defaults to <> // -// `tensor` can have up to 8 dimensions. The number of dimensions -// of `tensor` must equal the number of elements in `dims`. In other words: +// REQUIRES: len(value) >= 0 +func SqueezeAxis(value []int64) SqueezeAttr { + return func(m optionalAttr) { + m["squeeze_dims"] = value + } +} + +// Removes dimensions of size 1 from the shape of a tensor. // -// `rank(tensor) = size(dims)` +// Given a tensor `input`, this operation returns a tensor of the same type with +// all dimensions of size 1 removed. If you don't want to remove all size 1 +// dimensions, you can remove specific size 1 dimensions by specifying +// `axis`. // // For example: // // ``` -// # tensor 't' is [[[[ 0, 1, 2, 3], -// # [ 4, 5, 6, 7], -// # [ 8, 9, 10, 11]], -// # [[12, 13, 14, 15], -// # [16, 17, 18, 19], -// # [20, 21, 22, 23]]]] -// # tensor 't' shape is [1, 2, 3, 4] -// -// # 'dims' is [False, False, False, True] -// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], -// [ 7, 6, 5, 4], -// [ 11, 10, 9, 8]], -// [[15, 14, 13, 12], -// [19, 18, 17, 16], -// [23, 22, 21, 20]]]] +// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +// shape(squeeze(t)) ==> [2, 3] +// ``` // -// # 'dims' is [False, True, False, False] -// reverse(t, dims) ==> [[[[12, 13, 14, 15], -// [16, 17, 18, 19], -// [20, 21, 22, 23] -// [[ 0, 1, 2, 3], -// [ 4, 5, 6, 7], -// [ 8, 9, 10, 11]]]] +// Or, to remove specific size 1 dimensions: // -// # 'dims' is [False, False, True, False] -// reverse(t, dims) ==> [[[[8, 9, 10, 11], -// [4, 5, 6, 7], -// [0, 1, 2, 3]] -// [[20, 21, 22, 23], -// [16, 17, 18, 19], -// [12, 13, 14, 15]]]] +// ``` +// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] +// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] // ``` // // Arguments: -// tensor: Up to 8-D. -// dims: 1-D. The dimensions to reverse. +// input: The `input` to squeeze. // -// Returns The same shape as `tensor`. -func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) { +// Returns Contains the same data as `input`, but has one or more dimensions of +// size 1 removed. +func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Reverse", - Input: []tf.Input{ - tensor, dims, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes log softmax activations. -// -// For each batch `i` and class `j` we have -// -// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) -// -// Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. -// -// Returns Same shape as `logits`. -func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "LogSoftmax", + Type: "Squeeze", Input: []tf.Input{ - logits, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the inverse permutation of a tensor. -// -// This operation computes the inverse of an index permutation. It takes a 1-D -// integer tensor `x`, which represents the indices of a zero-based array, and -// swaps each value with its index position. In other words, for an output tensor -// `y` and an input tensor `x`, this operation computes the following: -// -// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` -// -// The values must include 0. There can be no duplicate values or negative values. +// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. +type ResourceApplyAdadeltaAttr func(optionalAttr) + +// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. // -// For example: +// value: If True, updating of the var, accum and update_accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the adadelta scheme. // -// ``` -// # tensor `x` is [3, 4, 0, 2, 1] -// invert_permutation(x) ==> [2, 4, 3, 0, 1] -// ``` +// accum = rho() * accum + (1 - rho()) * grad.square(); +// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; +// update_accum = rho() * update_accum + (1 - rho()) * update.square(); +// var -= update; // // Arguments: -// x: 1-D. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// accum_update: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. // -// Returns 1-D. -func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { +// Returns the created operation. +func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "InvertPermutation", + Type: "ResourceApplyAdadelta", Input: []tf.Input{ - x, + var_, accum, accum_update, lr, rho, epsilon, grad, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. -// -// This operation folds the padded areas of `input` by `MirrorPad` according to the -// `paddings` you specify. `paddings` must be the same as `paddings` argument -// given to the corresponding `MirrorPad` op. -// -// The folded size of each dimension D of the output is: -// -// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)` +// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. +type NonMaxSuppressionAttr func(optionalAttr) + +// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. // -// For example: +// value: A float representing the threshold for deciding whether boxes +// overlap too much with respect to IOU. +// If not specified, defaults to 0.5 +func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { + return func(m optionalAttr) { + m["iou_threshold"] = value + } +} + +// Greedily selects a subset of bounding boxes in descending order of score, // -// ``` -// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. -// # 'paddings' is [[0, 1]], [0, 1]]. -// # 'mode' is SYMMETRIC. -// # rank of 't' is 2. -// pad(t, paddings) ==> [[ 1, 5] -// [11, 28]] -// ``` +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: -// input: The input tensor to be folded. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// mode: The mode used in the `MirrorPad` op. +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. // -// Returns The folded tensor. -func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"mode": mode} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "MirrorPadGrad", + Type: "NonMaxSuppression", Input: []tf.Input{ - input, paddings, + boxes, scores, max_output_size, }, Attrs: attrs, } @@ -9234,92 +8762,95 @@ func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode strin return op.Output(0) } -// Computes softmax cross entropy cost and gradients to backpropagate. -// -// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept -// a matrix of label probabilities, but rather a single label per row -// of features. This label is considered to have probability 1.0 for the -// given row. -// -// Inputs are the logits, not probabilities. -// -// Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size vector with values in [0, num_classes). -// This is the label for the given minibatch entry. -// -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { +// Creates a dataset that emits `components` as a tuple of tensors once. +func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SparseSoftmaxCrossEntropyWithLogits", + Type: "TensorDataset", Input: []tf.Input{ - features, labels, + tf.OutputList(components), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Fast Fourier transform. +// Component-wise multiplies a SparseTensor by a dense Tensor. // -// Computes the 1-dimensional discrete Fourier transform over the inner-most -// dimension of `input`. +// The output locations corresponding to the implicitly zero elements in the sparse +// tensor will be zero (i.e., will not take up storage space), regardless of the +// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). // -// Arguments: -// input: A complex64 tensor. +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its 1D Fourier transform. +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. // -// @compatibility(numpy) -// Equivalent to np.fft.fft -// @end_compatibility -func FFT(scope *Scope, input tf.Output) (output tf.Output) { +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FFT", + Type: "SparseDenseCwiseMul", Input: []tf.Input{ - input, + sp_indices, sp_values, sp_shape, dense, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. -type ResourceSparseApplyAdagradDAAttr func(optionalAttr) +// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. +type ResourceSparseApplyRMSPropAttr func(optionalAttr) -// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { +func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. +// Update '*var' according to the RMSProp algorithm. +// +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // // Arguments: // var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. +// indices: A vector of indices into the first dimension of var, ms and mom. // // Returns the created operation. -func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { +func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -9328,203 +8859,273 @@ func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumul a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagradDA", - Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns the truth value of NOT x element-wise. -func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalNot", + Type: "ResourceSparseApplyRMSProp", Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 3D real-valued fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 3 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. -// -// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the their 3D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. + var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns the truth value of (x > y) element-wise. // -// @compatibility(numpy) -// Equivalent to np.fft.rfftn with 3 dimensions. -// @end_compatibility -func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// *NOTE*: `Greater` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RFFT3D", + Type: "Greater", Input: []tf.Input{ - input, fft_length, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorArrayV3Attr is an optional argument to TensorArrayV3. -type TensorArrayV3Attr func(optionalAttr) +// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox. +type SampleDistortedBoundingBoxAttr func(optionalAttr) -// TensorArrayV3ElementShape sets the optional element_shape attribute to value. +// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value. // -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["element_shape"] = value + m["seed"] = value } } -// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. +// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value. // -// value: A boolean that determines whether writes to the TensorArray -// are allowed to grow the size. By default, this is not allowed. -// If not specified, defaults to false -func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["dynamic_size"] = value + m["seed2"] = value } } -// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. +// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value. // -// value: If true (default), Tensors in the TensorArray are cleared -// after being read. This disables multiple read semantics but allows early -// release of memory. -// If not specified, defaults to true -func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { +// value: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. +// If not specified, defaults to 0.1 +func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["clear_after_read"] = value + m["min_object_covered"] = value } } -// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value. // -// value: If true (default is false), then all -// elements in the TensorArray will be expected to have have identical shapes. -// This allows certain behaviors, like dynamically checking for -// consistent shapes on write, and being able to fill in properly -// shaped zero tensors on stack -- even if the element_shape attribute -// is not fully defined. -// If not specified, defaults to false -func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["identical_element_shapes"] = value + m["aspect_ratio_range"] = value } } -// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. +// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value. // -// value: Overrides the name used for the temporary tensor_array -// resource. Default value is the name of the 'TensorArray' op (which -// is guaranteed unique). -// If not specified, defaults to "" -func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { +// value: The cropped area of the image must contain a fraction of the +// supplied image within in this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr { return func(m optionalAttr) { - m["tensor_array_name"] = value + m["area_range"] = value } } -// An array of Tensors of given size. +// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value. // -// Write data via Write and read via Read or Pack. +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. +// If not specified, defaults to 100 +func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["max_attempts"] = value + } +} + +// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value + } +} + +// Generate a single randomly distorted bounding box for an image. +// +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. +// +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. +// +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, +// +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) +// +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) +// +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` +// +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. // // Arguments: -// size: The size of the array. -// dtype: The type of the elements on the tensor_array. +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. // -// Returns The handle to the TensorArray.A scalar used to control gradient flow. -func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayV3", + Type: "SampleDistortedBoundingBox", Input: []tf.Input{ - size, + image_size, bounding_boxes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// MaxPool3DAttr is an optional argument to MaxPool3D. -type MaxPool3DAttr func(optionalAttr) +// Converts each string in the input Tensor to its hash mod by a number of buckets. +// +// The hash function is deterministic on the content of the string within the +// process and will never change. However, it is not suitable for cryptography. +// This function may be used when CPU time is scarce and inputs are trusted or +// unimportant. There is a risk of adversaries constructing inputs that all hash +// to the same bucket. To prevent this problem, use a strong hash function with +// `tf.string_to_hash_bucket_strong`. +// +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_buckets": num_buckets} + opspec := tf.OpSpec{ + Type: "StringToHashBucketFast", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// MaxPool3DDataFormat sets the optional data_format attribute to value. +// Returns the max of x and y (i.e. x > y ? x : y) element-wise. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DDataFormat(value string) MaxPool3DAttr { +// *NOTE*: `Maximum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Maximum", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. +type TensorArrayGatherV3Attr func(optionalAttr) + +// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { return func(m optionalAttr) { - m["data_format"] = value + m["element_shape"] = value } } -// Performs 3D max pooling on the input. +// Gather specific elements from the TensorArray into output `value`. +// +// All elements selected by `indices` must have the same shape. // // Arguments: -// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// handle: The handle to a TensorArray. +// indices: The locations in the TensorArray from which to read tensor elements. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. // -// Returns The max pooled output tensor. -func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DAttr) (output tf.Output) { +// Returns All of the elements in the TensorArray, concatenated along a new +// axis (the new dimension 0). +func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3D", + Type: "TensorArrayGatherV3", Input: []tf.Input{ - input, + handle, indices, flow_in, }, Attrs: attrs, } @@ -9532,308 +9133,391 @@ func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, pa return op.Output(0) } -// Computes the gradients of 3-D convolution with respect to the input. +// Returns x / y element-wise for integer types. // -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// Truncation designates that negative numbers will round fractional quantities +// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different +// than Python semantics. See `FloorDiv` for a division function that matches +// Python Semantics. // -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string) (output tf.Output) { +// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "Conv3DBackpropInput", + Type: "TruncateDiv", Input: []tf.Input{ - input, filter, out_backprop, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Inverse 2D fast Fourier transform. +// Restores tensors from a V2 checkpoint. // -// Computes the inverse 2-dimensional discrete Fourier transform over the -// inner-most 2 dimensions of `input`. +// For backward compatibility with the V1 format, this Op currently allows +// restoring from a V1 checkpoint as well: +// - This Op first attempts to find the V2 index file pointed to by "prefix", and +// if found proceed to read it as a V2 checkpoint; +// - Otherwise the V1 read path is invoked. +// Relying on this behavior is not recommended, as the ability to fall back to read +// V1 might be deprecated and eventually removed. // -// Arguments: -// input: A complex64 tensor. +// By default, restores the named tensors in full. If the caller wishes to restore +// specific slices of stored tensors, "shape_and_slices" should be non-empty +// strings and correspondingly well-formed. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// Callers must ensure all the named tensors are indeed stored in the checkpoint. // -// @compatibility(numpy) -// Equivalent to np.fft.ifft2 -// @end_compatibility -func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { +// Arguments: +// prefix: Must have a single element. The prefix of a V2 checkpoint. +// tensor_names: shape {N}. The names of the tensors to be restored. +// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. +// Empty strings indicate that they are non-partitioned tensors. +// dtypes: shape {N}. The list of expected dtype for the tensors. Must match +// those stored in the checkpoint. +// +// Returns shape {N}. The restored tensors, whose shapes are read from the +// checkpoint directly. +func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} opspec := tf.OpSpec{ - Type: "IFFT2D", + Type: "RestoreV2", Input: []tf.Input{ - input, + prefix, tensor_names, shape_and_slices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { + scope.UpdateErr("RestoreV2", err) + return + } + return tensors } -// Creates a tensor filled with a scalar value. -// -// This operation creates a tensor of shape `dims` and fills it with `value`. +// Creates a dataset that skips `count` elements from the `input_dataset`. // -// For example: +// Arguments: // -// ``` -// # Output tensor has shape [2, 3]. -// fill([2, 3], 9) ==> [[9, 9, 9] -// [9, 9, 9]] -// ``` +// count: A scalar representing the number of elements from the `input_dataset` +// that should be skipped. If count is -1, skips everything. // -// Arguments: -// dims: 1-D. Represents the shape of the output tensor. -// value: 0-D (scalar). Value to fill the returned tensor. // -// @compatibility(numpy) -// Equivalent to np.full -// @end_compatibility -func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) { +func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Fill", + Type: "SkipDataset", Input: []tf.Input{ - dims, value, + input_dataset, count, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// 2D fast Fourier transform. +// Computes the maximum along segments of a tensor. // -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. +// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +// segments. +// +// Computes a tensor such that +// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such +// that `segment_ids[j] == i`. +// +// If the max is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
// // Arguments: -// input: A complex64 tensor. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. // -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FFT2D", + Type: "SegmentMax", Input: []tf.Input{ - input, + data, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. -type ResourceApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value +// Computes hyperbolic tangent of `x` element-wise. +func Tanh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tanh", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Update '*var' as FOBOS algorithm with fixed learning rate. +// Decode web-safe base64-encoded strings. // -// prox_v = var - alpha * delta -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// Input may or may not have padding at the end. See EncodeBase64 for padding. +// Web-safe means that input must use - and _ instead of + and /. // // Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// delta: The change. +// input: Base64 strings to decode. // -// Returns the created operation. -func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { +// Returns Decoded strings. +func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyProximalGradientDescent", + Type: "DecodeBase64", Input: []tf.Input{ - var_, alpha, l1, l2, delta, + input, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the gradient for the sqrt of `x` wrt its input. +// Store the input tensor in the state of the current session. // -// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` -// is the corresponding input gradient. -func SqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// Arguments: +// value: The tensor to be stored. +// +// Returns The handle for the tensor stored in the session state, represented +// as a string. +func GetSessionHandle(scope *Scope, value tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SqrtGrad", + Type: "GetSessionHandle", Input: []tf.Input{ - y, dy, + value, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Get the value of the tensor specified by its handle. +// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. +type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) + +// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// prox_v = var +// prox_v -= lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} // // Arguments: -// handle: The handle for a tensor stored in the session state. -// dtype: The type of the output value. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // -// Returns The tensor for the given handle. -func GetSessionTensor(scope *Scope, handle tf.Output, dtype tf.DataType) (value tf.Output) { +// Returns the created operation. +func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "GetSessionTensor", + Type: "ResourceSparseApplyProximalAdagrad", Input: []tf.Input{ - handle, + var_, accum, lr, l1, l2, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Returns x - y element-wise. -// -// *NOTE*: `Subtract` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns element-wise largest integer not greater than x. +func Floor(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Sub", + Type: "Floor", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softmax cross entropy cost and gradients to backpropagate. -// -// Inputs are the logits, not probabilities. -// -// Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size x num_classes matrix -// The caller must ensure that each batch of labels represents a valid -// probability distribution. -// -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { +// Computes the Gauss error function of `x` element-wise. +func Erf(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SoftmaxCrossEntropyWithLogits", + Type: "Erf", Input: []tf.Input{ - features, labels, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// ReduceJoinAttr is an optional argument to ReduceJoin. -type ReduceJoinAttr func(optionalAttr) +// OneHotAttr is an optional argument to OneHot. +type OneHotAttr func(optionalAttr) -// ReduceJoinKeepDims sets the optional keep_dims attribute to value. +// OneHotAxis sets the optional axis attribute to value. // -// value: If `True`, retain reduced dimensions with length `1`. -// If not specified, defaults to false -func ReduceJoinKeepDims(value bool) ReduceJoinAttr { +// value: The axis to fill (default: -1, a new inner-most axis). +// If not specified, defaults to -1 +func OneHotAxis(value int64) OneHotAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["axis"] = value } } -// ReduceJoinSeparator sets the optional separator attribute to value. +// Returns a one-hot tensor. // -// value: The separator to use when joining. -// If not specified, defaults to "" -func ReduceJoinSeparator(value string) ReduceJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } -} - -// Joins a string Tensor across the given dimensions. +// The locations represented by indices in `indices` take value `on_value`, +// while all other locations take value `off_value`. // -// Computes the string join across dimensions in the given string Tensor of shape -// `[d_0, d_1, ..., d_n-1]`. Returns a new Tensor created by joining the input -// strings with the given separator (default: empty string). Negative indices are -// counted backwards from the end, with `-1` being equivalent to `n - 1`. +// If the input `indices` is rank `N`, the output will have rank `N+1`, +// The new axis is created at dimension `axis` (default: the new axis is +// appended at the end). // -// For example: +// If `indices` is a scalar the output shape will be a vector of length `depth`. // -// ```python -// # tensor `a` is [["a", "b"], ["c", "d"]] -// tf.reduce_join(a, 0) ==> ["ac", "bd"] -// tf.reduce_join(a, 1) ==> ["ab", "cd"] -// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"] -// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"] -// tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] -// tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] -// tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] -// tf.reduce_join(a, [0, 1]) ==> ["acbd"] -// tf.reduce_join(a, [1, 0]) ==> ["abcd"] -// tf.reduce_join(a, []) ==> ["abcd"] +// If `indices` is a vector of length `features`, the output shape will be: +// ``` +// features x depth if axis == -1 +// depth x features if axis == 0 +// ``` +// +// If `indices` is a matrix (batch) with shape `[batch, features]`, +// the output shape will be: +// ``` +// batch x features x depth if axis == -1 +// batch x depth x features if axis == 1 +// depth x batch x features if axis == 0 +// ``` +// +// +// Examples +// ========= +// +// Suppose that +// +// ``` +// indices = [0, 2, -1, 1] +// depth = 3 +// on_value = 5.0 +// off_value = 0.0 +// axis = -1 +// ``` +// +// Then output is `[4 x 3]`: +// +// ```output = +// [5.0 0.0 0.0] // one_hot(0) +// [0.0 0.0 5.0] // one_hot(2) +// [0.0 0.0 0.0] // one_hot(-1) +// [0.0 5.0 0.0] // one_hot(1) +// ``` +// +// Suppose that +// +// ``` +// indices = [0, 2, -1, 1] +// depth = 3 +// on_value = 0.0 +// off_value = 3.0 +// axis = 0 +// ``` +// +// Then output is `[3 x 4]`: +// +// ```output = +// [0.0 3.0 3.0 3.0] +// [3.0 3.0 3.0 0.0] +// [3.0 3.0 3.0 3.0] +// [3.0 0.0 3.0 3.0] +// // ^ one_hot(0) +// // ^ one_hot(2) +// // ^ one_hot(-1) +// // ^ one_hot(1) +// ``` +// Suppose that +// +// ``` +// indices = [[0, 2], [1, -1]] +// depth = 3 +// on_value = 1.0 +// off_value = 0.0 +// axis = -1 // ``` // +// Then output is `[2 x 2 x 3]`: +// +// ```output = +// [ +// [1.0, 0.0, 0.0] // one_hot(0) +// [0.0, 0.0, 1.0] // one_hot(2) +// ][ +// [0.0, 1.0, 0.0] // one_hot(1) +// [0.0, 0.0, 0.0] // one_hot(-1) +// ]``` +// // Arguments: -// inputs: The input to be joined. All reduced indices must have non-zero size. -// reduction_indices: The dimensions to reduce over. Dimensions are reduced in the -// order specified. Omitting `reduction_indices` is equivalent to passing -// `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported. +// indices: A tensor of indices. +// depth: A scalar defining the depth of the one hot dimension. +// on_value: A scalar defining the value to fill in output when `indices[j] = i`. +// off_value: A scalar defining the value to fill in output when `indices[j] != i`. // -// Returns Has shape equal to that of the input with reduced dimensions removed or -// set to `1` depending on `keep_dims`. -func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, optional ...ReduceJoinAttr) (output tf.Output) { +// Returns The one-hot tensor. +func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -9842,9 +9526,9 @@ func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, opt a(attrs) } opspec := tf.OpSpec{ - Type: "ReduceJoin", + Type: "OneHot", Input: []tf.Input{ - inputs, reduction_indices, + indices, depth, on_value, off_value, }, Attrs: attrs, } @@ -9852,201 +9536,118 @@ func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, opt return op.Output(0) } -// Computes cos of x element-wise. -func Cos(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cos", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedBatchNormGradAttr is an optional argument to FusedBatchNormGrad. -type FusedBatchNormGradAttr func(optionalAttr) - -// FusedBatchNormGradEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradEpsilon(value float32) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradDataFormat(value string) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormGradIsTraining sets the optional is_training attribute to value. +// Reads the value of a variable. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormGradIsTraining(value bool) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Gradient for batch normalization. +// The tensor returned by this operation is immutable. // -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// The value returned by this operation is guaranteed to be influenced by all the +// writes on which this operation depends directly or indirectly, and to not be +// influenced by any of the writes which depend directly or indirectly on this +// operation. // // Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. -// -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradAttr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { +// resource: handle to the resource in which to store the variable. +// dtype: the dtype of the value. +func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "FusedBatchNormGrad", + Type: "ReadVariableOp", Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, + resource, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } -// TopKAttr is an optional argument to TopK. -type TopKAttr func(optionalAttr) +// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. +type MaxPool3DGradAttr func(optionalAttr) -// TopKSorted sets the optional sorted attribute to value. +// MaxPool3DGradDataFormat sets the optional data_format attribute to value. // -// value: If true the resulting `k` elements will be sorted by the values in -// descending order. -// If not specified, defaults to true -func TopKSorted(value bool) TopKAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { return func(m optionalAttr) { - m["sorted"] = value + m["data_format"] = value } } -// Finds values and indices of the `k` largest elements for the last dimension. -// -// DEPRECATED at GraphDef version 7: Use TopKV2 instead -// -// If the input is a vector (rank-1), finds the `k` largest entries in the vector -// and outputs their values and indices as vectors. Thus `values[j]` is the -// `j`-th largest entry in `input`, and its index is `indices[j]`. -// -// For matrices (resp. higher rank input), computes the top `k` entries in each -// row (resp. vector along the last dimension). Thus, -// -// values.shape = indices.shape = input.shape[:-1] + [k] -// -// If two elements are equal, the lower-index element appears first. -// -// If `k` varies dynamically, use `TopKV2` below. +// Computes gradients of max pooling function. // // Arguments: -// input: 1-D or higher with last dimension at least `k`. -// k: Number of top elements to look for along the last dimension (along each -// row for matrices). -// -// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. -func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values tf.Output, indices tf.Output) { +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"k": k} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TopK", + Type: "MaxPool3DGrad", Input: []tf.Input{ - input, + orig_input, orig_output, grad, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Transforms a Tensor into a serialized TensorProto proto. -// -// Arguments: -// tensor: A Tensor of type `T`. -// -// Returns A serialized TensorProto proto of the input tensor. -func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SerializeTensor", - Input: []tf.Input{ - tensor, - }, - } - op := scope.AddOperation(opspec) return op.Output(0) } -// MatrixSolveAttr is an optional argument to MatrixSolve. -type MatrixSolveAttr func(optionalAttr) +// SparseReduceSumAttr is an optional argument to SparseReduceSum. +type SparseReduceSumAttr func(optionalAttr) -// MatrixSolveAdjoint sets the optional adjoint attribute to value. +// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. // -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. +// value: If true, retain reduced dimensions with length 1. // If not specified, defaults to false -func MatrixSolveAdjoint(value bool) MatrixSolveAttr { +func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { return func(m optionalAttr) { - m["adjoint"] = value + m["keep_dims"] = value } } -// Solves systems of linear equations. +// Computes the sum of elements across dimensions of a SparseTensor. // -// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is -// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix -// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `True` then each output matrix satisfies -// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. // // Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. // -// Returns Shape is `[..., M, K]`. -func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -10055,9 +9656,9 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr a(attrs) } opspec := tf.OpSpec{ - Type: "MatrixSolve", + Type: "SparseReduceSum", Input: []tf.Input{ - matrix, rhs, + input_indices, input_values, input_shape, reduction_axes, }, Attrs: attrs, } @@ -10065,183 +9666,285 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr return op.Output(0) } -// Looks up keys in a table, outputs the corresponding values. -// -// The tensor `keys` must of the same type as the keys of the table. -// The output `values` is of the type of the table values. -// -// The scalar `default_value` is the value output for keys not present in the -// table. It must also be of the same type as the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. +// Returns element-wise remainder of division. This emulates C semantics in that // +// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +// y + truncate_mod(x, y) = x`. // -// Returns Same shape as `keys`. Values found in the table, or `default_values` -// for missing keys. -func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { +// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LookupTableFindV2", + Type: "TruncateMod", Input: []tf.Input{ - table_handle, keys, default_value, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Inverse 3D fast Fourier transform. +// Inverse 2D real-valued fast Fourier transform. // -// Computes the inverse 3-dimensional discrete Fourier transform over the -// inner-most 3 dimensions of `input`. +// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 2 dimensions of `input`. +// +// The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 2 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along each axis `IRFFT2D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. // // Arguments: // input: A complex64 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. // -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their inverse 3D Fourier transform. +// Returns A float32 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 2D Fourier transform. // // @compatibility(numpy) -// Equivalent to np.fft.ifftn with 3 dimensions. +// Equivalent to np.fft.irfft2 // @end_compatibility -func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { +func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IFFT3D", + Type: "IRFFT2D", Input: []tf.Input{ - input, + input, fft_length, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeJpegAttr is an optional argument to DecodeJpeg. +type DecodeJpegAttr func(optionalAttr) + +// DecodeJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeJpegChannels(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodeJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeJpegRatio(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeJpegDctMethod sets the optional dct_method attribute to value. +// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeJpegDctMethod(value string) DecodeJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } } -// Adds `bias` to `value`. +// Decode a JPEG-encoded image to a uint8 tensor. // -// This is a deprecated version of BiasAdd and will be soon removed. +// The attr `channels` indicates the desired number of color channels for the +// decoded image. // -// This is a special case of `tf.add` where `bias` is restricted to be 1-D. -// Broadcasting is supported, so `value` may have any number of dimensions. +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// This op also supports decoding PNGs and non-animated GIFs since the interface is +// the same, though it is cleaner to use `tf.image.decode_image`. // // Arguments: -// value: Any number of dimensions. -// bias: 1-D with size the last dimension of `value`. +// contents: 0-D. The JPEG-encoded image. // -// Returns Broadcasted sum of `value` and `bias`. -func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) { +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BiasAddV1", + Type: "DecodeJpeg", Input: []tf.Input{ - value, bias, + contents, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reverses specific dimensions of a tensor. -// -// NOTE `tf.reverse` has now changed behavior in preparation for 1.0. -// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0. -// -// Given a `tensor`, and a `int32` tensor `axis` representing the set of -// dimensions of `tensor` to reverse. This operation reverses each dimension -// `i` for which there exists `j` s.t. `axis[j] == i`. -// -// `tensor` can have up to 8 dimensions. The number of dimensions specified -// in `axis` may be 0 or more entries. If an index is specified more than -// once, a InvalidArgument error is raised. -// -// For example: -// -// ``` -// # tensor 't' is [[[[ 0, 1, 2, 3], -// # [ 4, 5, 6, 7], -// # [ 8, 9, 10, 11]], -// # [[12, 13, 14, 15], -// # [16, 17, 18, 19], -// # [20, 21, 22, 23]]]] -// # tensor 't' shape is [1, 2, 3, 4] -// -// # 'dims' is [3] or 'dims' is [-1] -// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], -// [ 7, 6, 5, 4], -// [ 11, 10, 9, 8]], -// [[15, 14, 13, 12], -// [19, 18, 17, 16], -// [23, 22, 21, 20]]]] -// -// # 'dims' is '[1]' (or 'dims' is '[-3]') -// reverse(t, dims) ==> [[[[12, 13, 14, 15], -// [16, 17, 18, 19], -// [20, 21, 22, 23] -// [[ 0, 1, 2, 3], -// [ 4, 5, 6, 7], -// [ 8, 9, 10, 11]]]] -// -// # 'dims' is '[2]' (or 'dims' is '[-2]') -// reverse(t, dims) ==> [[[[8, 9, 10, 11], -// [4, 5, 6, 7], -// [0, 1, 2, 3]] -// [[20, 21, 22, 23], -// [16, 17, 18, 19], -// [12, 13, 14, 15]]]] -// ``` +// Transforms a vector of brain.Example protos (as strings) into typed tensors. // // Arguments: -// tensor: Up to 8-D. -// axis: 1-D. The indices of the dimensions to reverse. Must be in the range -// `[-rank(tensor), rank(tensor))`. -// -// Returns The same shape as `tensor`. -func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) { +// serialized: A vector containing a batch of binary serialized Example protos. +// names: A vector containing the names of the serialized protos. +// May contain, for example, table key (descriptive) names for the +// corresponding serialized protos. These are purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty vector if no names are available. +// If non-empty, this vector must be the same length as "serialized". +// sparse_keys: A list of Nsparse string Tensors (scalars). +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: A list of Ndense string Tensors (scalars). +// The keys expected in the Examples' features associated with dense values. +// dense_defaults: A list of Ndense Tensors (some may be empty). +// dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// sparse_types: A list of Nsparse types; the data types of data in each Feature +// given in sparse_keys. +// Currently the ParseExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature +// given in dense_keys. +// The number of elements in the Feature corresponding to dense_key[j] +// must always equal dense_shapes[j].NumEntries(). +// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output +// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): +// The dense outputs are just the inputs row-stacked by batch. +// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case +// the shape of the output Tensor dense_values[j] will be +// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks +// of elements of length D1 * .... * DN, across all minibatch entries +// in the input. Any minibatch entry with less than M blocks of elements of +// length D1 * ... * DN will be padded with the corresponding default_value +// scalar element along the second dimension. +func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} opspec := tf.OpSpec{ - Type: "ReverseV2", + Type: "ParseExample", Input: []tf.Input{ - tensor, axis, + serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values } -// RealAttr is an optional argument to Real. -type RealAttr func(optionalAttr) +// VariableShapeAttr is an optional argument to VariableShape. +type VariableShapeAttr func(optionalAttr) -// RealTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func RealTout(value tf.DataType) RealAttr { +// VariableShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func VariableShapeOutType(value tf.DataType) VariableShapeAttr { return func(m optionalAttr) { - m["Tout"] = value + m["out_type"] = value } } -// Returns the real part of a complex number. +// Returns the shape of the variable pointed to by `resource`. // -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the real part of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real -// part returned by this operation and *b* is the imaginary part. +// This operation returns a 1-D integer tensor representing the shape of `input`. // // For example: // // ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.real(input) ==> [-2.25, 3.25] +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] // ``` -func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { +func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -10250,7 +9953,7 @@ func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "Real", + Type: "VariableShape", Input: []tf.Input{ input, }, @@ -10260,55 +9963,80 @@ func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output return op.Output(0) } -// AudioSummaryAttr is an optional argument to AudioSummary. -type AudioSummaryAttr func(optionalAttr) +// Computes softmax cross entropy cost and gradients to backpropagate. +// +// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +// a matrix of label probabilities, but rather a single label per row +// of features. This label is considered to have probability 1.0 for the +// given row. +// +// Inputs are the logits, not probabilities. +// +// Arguments: +// features: batch_size x num_classes matrix +// labels: batch_size vector with values in [0, num_classes). +// This is the label for the given minibatch entry. +// +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSoftmaxCrossEntropyWithLogits", + Input: []tf.Input{ + features, labels, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} -// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value. +// Fast Fourier transform. // -// value: Max number of batch elements to generate audio for. -// If not specified, defaults to 3 +// Computes the 1-dimensional discrete Fourier transform over the inner-most +// dimension of `input`. // -// REQUIRES: value >= 1 -func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr { - return func(m optionalAttr) { - m["max_outputs"] = value +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft +// @end_compatibility +func FFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT", + Input: []tf.Input{ + input, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Outputs a `Summary` protocol buffer with audio. -// -// DEPRECATED at GraphDef version 15: Use AudioSummaryV2. -// -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: -// -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. +// Transforms a serialized tensorflow.TensorProto proto into a Tensor. // // Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. +// serialized: A scalar string containing a serialized TensorProto proto. +// out_type: The type of the serialized tensor. The provided type must match the +// type of the serialized tensor and no implicit conversion will take place. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) { +// Returns A Tensor of type `out_type`. +func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"sample_rate": sample_rate} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "AudioSummary", + Type: "ParseTensor", Input: []tf.Input{ - tag, tensor, + serialized, }, Attrs: attrs, } @@ -10316,51 +10044,46 @@ func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate flo return op.Output(0) } -// QrAttr is an optional argument to Qr. -type QrAttr func(optionalAttr) +// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. +type MaxPoolWithArgmaxAttr func(optionalAttr) -// QrFullMatrices sets the optional full_matrices attribute to value. -// -// value: If true, compute full-sized `q` and `r`. If false -// (the default), compute only the leading `P` columns of `q`. -// If not specified, defaults to false -func QrFullMatrices(value bool) QrAttr { +// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. +// If not specified, defaults to DT_INT64 +func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { return func(m optionalAttr) { - m["full_matrices"] = value + m["Targmax"] = value } } -// Computes the QR decompositions of one or more matrices. +// Performs max pooling on the input and outputs both max values and indices. // -// Computes the QR decomposition of each inner matrix in `tensor` such that -// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` +// The indices in `argmax` are flattened, so that a maximum value at position +// `[b, y, x, c]` becomes flattened index +// `((b * height + y) * width + x) * channels + c`. // -// ```python -// # a is a tensor. -// # q is a tensor of orthonormal matrices. -// # r is a tensor of upper triangular matrices. -// q, r = qr(a) -// q_full, r_full = qr(a, full_matrices=True) -// ``` +// The indices returned are always in `[0, height) x [0, width)` before flattening, +// even if padding is involved and the mathematically correct answer is outside +// (either negative or too large). This is a bug, but fixing it is difficult to do +// in a safe backwards compatible way, especially due to flattening. // // Arguments: -// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. +// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then -// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is -// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is -// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`. -func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) { +// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. +func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Qr", + Type: "MaxPoolWithArgmax", Input: []tf.Input{ input, }, @@ -10370,53 +10093,35 @@ func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Ou return op.Output(0), op.Output(1) } -// Records the bytes size of each element of `input_dataset` in a StatsAggregator. -func BytesProducedStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "BytesProducedStatsDataset", - Input: []tf.Input{ - input_dataset, tag, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. -type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) +// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. +type ResourceSparseApplyAdagradDAAttr func(optionalAttr) -// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. // -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. // If not specified, defaults to false -func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { +func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Sparse update '*var' as FOBOS algorithm with fixed learning rate. -// -// That is for rows we have grad for, we update var as follows: -// prox_v = var - alpha * grad -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. // // Arguments: // var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). // grad: The gradient. // indices: A vector of indices into the first dimension of var and accum. +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. // // Returns the created operation. -func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { +func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10425,42 +10130,133 @@ func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, al a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalGradientDescent", + Type: "ResourceSparseApplyAdagradDA", Input: []tf.Input{ - var_, alpha, l1, l2, grad, indices, + var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// MeanAttr is an optional argument to Mean. -type MeanAttr func(optionalAttr) +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) -// MeanKeepDims sets the optional keep_dims attribute to value. +// EncodeJpegFormat sets the optional format attribute to value. // -// value: If true, retain reduced dimensions with length 1. +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["format"] = value + } +} + +// EncodeJpegQuality sets the optional quality attribute to value. +// +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["quality"] = value + } +} + +// EncodeJpegProgressive sets the optional progressive attribute to value. +// +// value: If True, create a JPEG that loads progressively (coarse to fine). // If not specified, defaults to false -func MeanKeepDims(value bool) MeanAttr { +func EncodeJpegProgressive(value bool) EncodeJpegAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["progressive"] = value } } -// Computes the mean of elements across dimensions of a tensor. +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: If True, spend CPU/RAM to reduce size with no quality change. +// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value + } +} + +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value + } +} + +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["density_unit"] = value + } +} + +// EncodeJpegXDensity sets the optional x_density attribute to value. +// +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value + } +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} + +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// image: 3-D with shape `[height, width, channels]`. // -// Returns The reduced tensor. -func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) (output tf.Output) { +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { if scope.Err() != nil { return } @@ -10469,9 +10265,9 @@ func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) ( a(attrs) } opspec := tf.OpSpec{ - Type: "Mean", + Type: "EncodeJpeg", Input: []tf.Input{ - input, axis, + image, }, Attrs: attrs, } @@ -10479,90 +10275,48 @@ func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) ( return op.Output(0) } -// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. -type InitializeTableFromTextFileV2Attr func(optionalAttr) +// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) -// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. -// -// value: Number of elements of the file, use -1 if unknown. -// If not specified, defaults to -1 +// MultinomialSeed sets the optional seed attribute to value. // -// REQUIRES: value >= -1 -func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func MultinomialSeed(value int64) MultinomialAttr { return func(m optionalAttr) { - m["vocab_size"] = value + m["seed"] = value } } -// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. +// MultinomialSeed2 sets the optional seed2 attribute to value. // -// value: Delimiter to separate fields in a line. -// If not specified, defaults to "\t" -func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func MultinomialSeed2(value int64) MultinomialAttr { return func(m optionalAttr) { - m["delimiter"] = value - } -} - -// Initializes a table from a text file. -// -// It inserts one key-value pair into the table for each line of the file. -// The key and value is extracted from the whole line content, elements from the -// split line based on `delimiter` or the line number (starting from zero). -// Where to extract the key and value from a line is specified by `key_index` and -// `value_index`. -// -// - A value of -1 means use the line number(starting from zero), expects `int64`. -// - A value of -2 means use the whole line content, expects `string`. -// - A value >= 0 means use the index (starting at zero) of the split line based -// on `delimiter`. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// filename: Filename of a vocabulary text file. -// key_index: Column index in a line to get the table `key` values from. -// value_index: Column index that represents information of a line to get the table -// `value` values from. -// -// Returns the created operation. -func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "InitializeTableFromTextFileV2", - Input: []tf.Input{ - table_handle, filename, - }, - Attrs: attrs, + m["seed2"] = value } - return scope.AddOperation(opspec) } -// QuantizedReluAttr is an optional argument to QuantizedRelu. -type QuantizedReluAttr func(optionalAttr) - -// QuantizedReluOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedReluOutType(value tf.DataType) QuantizedReluAttr { +// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { return func(m optionalAttr) { - m["out_type"] = value + m["output_dtype"] = value } } -// Computes Quantized Rectified Linear: `max(features, 0)` +// Draws samples from a multinomial distribution. // // Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. // -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. -// -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -10571,214 +10325,270 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedRelu", + Type: "Multinomial", Input: []tf.Input{ - features, min_features, max_features, + logits, num_samples, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Reshapes a SparseTensor to represent values in a new dense shape. -// -// This operation has the same semantics as reshape on the represented dense -// tensor. The `input_indices` are recomputed based on the requested `new_shape`. -// -// If one component of `new_shape` is the special value -1, the size of that -// dimension is computed so that the total dense size remains constant. At -// most one component of `new_shape` can be -1. The number of dense elements -// implied by `new_shape` must be the same as the number of dense elements -// originally implied by `input_shape`. -// -// Reshaping does not affect the order of values in the SparseTensor. -// -// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` -// has length `R_out`, then `input_indices` has shape `[N, R_in]`, -// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and -// `output_shape` has length `R_out`. -// -// Arguments: -// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a -// SparseTensor. -// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. -// new_shape: 1-D. `R_out` vector with the requested new dense shape. -// -// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty -// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output -// SparseTensor. This is the same as `new_shape` but with any -1 dimensions -// filled in. -func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { +// Returns the truth value of NOT x element-wise. +func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseReshape", + Type: "LogicalNot", Input: []tf.Input{ - input_indices, input_shape, new_shape, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Deprecated. Use TensorArraySplitV3 +// 3D real-valued fast Fourier transform. // -// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 -func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Computes the 3-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 3 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the their 3D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfftn with 3 dimensions. +// @end_compatibility +func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySplitV2", + Type: "RFFT3D", Input: []tf.Input{ - handle, value, lengths, flow_in, + input, fft_length, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// PackAttr is an optional argument to Pack. -type PackAttr func(optionalAttr) +// TensorArrayV3Attr is an optional argument to TensorArrayV3. +type TensorArrayV3Attr func(optionalAttr) -// PackAxis sets the optional axis attribute to value. +// TensorArrayV3ElementShape sets the optional element_shape attribute to value. // -// value: Dimension along which to pack. Negative values wrap around, so the -// valid range is `[-(R+1), R+1)`. -// If not specified, defaults to 0 -func PackAxis(value int64) PackAttr { +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { return func(m optionalAttr) { - m["axis"] = value + m["element_shape"] = value } } -// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. +// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. // -// Packs the `N` tensors in `values` into a tensor with rank one higher than each -// tensor in `values`, by packing them along the `axis` dimension. -// Given a list of tensors of shape `(A, B, C)`; +// value: A boolean that determines whether writes to the TensorArray +// are allowed to grow the size. By default, this is not allowed. +// If not specified, defaults to false +func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["dynamic_size"] = value + } +} + +// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. // -// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. -// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. -// Etc. +// value: If true (default), Tensors in the TensorArray are cleared +// after being read. This disables multiple read semantics but allows early +// release of memory. +// If not specified, defaults to true +func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["clear_after_read"] = value + } +} + +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. // -// For example: +// value: If true (default is false), then all +// elements in the TensorArray will be expected to have have identical shapes. +// This allows certain behaviors, like dynamically checking for +// consistent shapes on write, and being able to fill in properly +// shaped zero tensors on stack -- even if the element_shape attribute +// is not fully defined. +// If not specified, defaults to false +func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["identical_element_shapes"] = value + } +} + +// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. // -// ``` -// # 'x' is [1, 4] -// # 'y' is [2, 5] -// # 'z' is [3, 6] -// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] -// ``` +// value: Overrides the name used for the temporary tensor_array +// resource. Default value is the name of the 'TensorArray' op (which +// is guaranteed unique). +// If not specified, defaults to "" +func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { + return func(m optionalAttr) { + m["tensor_array_name"] = value + } +} + +// An array of Tensors of given size. // -// This is the opposite of `unpack`. +// Write data via Write and read via Read or Pack. // // Arguments: -// values: Must be of same shape and type. +// size: The size of the array. +// dtype: The type of the elements on the tensor_array. // -// Returns The packed tensor. -func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) { +// Returns The handle to the TensorArray.A scalar used to control gradient flow. +func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Pack", + Type: "TensorArrayV3", Input: []tf.Input{ - tf.OutputList(values), + size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Reorders a SparseTensor into the canonical, row-major ordering. -// -// Note that by convention, all sparse ops preserve the canonical ordering along -// increasing dimension number. The only time ordering can be violated is during -// manual manipulation of the indices and values vectors to add entries. -// -// Reordering does not affect the shape of the SparseTensor. +// MaxPool3DAttr is an optional argument to MaxPool3D. +type MaxPool3DAttr func(optionalAttr) + +// MaxPool3DDataFormat sets the optional data_format attribute to value. // -// If the tensor has rank `R` and `N` non-empty values, `input_indices` has -// shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DDataFormat(value string) MaxPool3DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs 3D max pooling on the input. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. +// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns 2-D. `N x R` matrix with the same indices as input_indices, but -// in canonical row-major ordering.1-D. `N` non-empty values corresponding to `output_indices`. -func SparseReorder(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { +// Returns The max pooled output tensor. +func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseReorder", + Type: "MaxPool3D", Input: []tf.Input{ - input_indices, input_values, input_shape, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Computes rectified linear: `max(features, 0)`. -func Relu(scope *Scope, features tf.Output) (activations tf.Output) { +// Computes the gradients of 3-D convolution with respect to the input. +// +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "Relu", + Type: "Conv3DBackpropInput", Input: []tf.Input{ - features, + input, filter, out_backprop, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. -type ResourceApplyAddSignAttr func(optionalAttr) +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. +type ResourceApplyProximalAdagradAttr func(optionalAttr) -// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. // If not specified, defaults to false -func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the AddSign update. +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. // -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- (alpha + sign_decay * sign(g) *sign(m)) * g -// variable <- variable - lr_t * update +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} // // Arguments: // var_: Should be from a Variable(). -// m: Should be from a Variable(). +// accum: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// alpha: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. // grad: The gradient. // // Returns the created operation. -func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10787,96 +10597,201 @@ func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Outpu a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAddSign", + Type: "ResourceApplyProximalAdagrad", Input: []tf.Input{ - var_, m, lr, alpha, sign_decay, beta, grad, + var_, accum, lr, l1, l2, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// FractionalMaxPoolGradAttr is an optional argument to FractionalMaxPoolGrad. -type FractionalMaxPoolGradAttr func(optionalAttr) +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) -// FractionalMaxPoolGradOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. // -// `value 20 5 16 3 7` +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. // -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. // If not specified, defaults to false -func FractionalMaxPoolGradOverlapping(value bool) FractionalMaxPoolGradAttr { +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { return func(m optionalAttr) { - m["overlapping"] = value + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse 2D fast Fourier transform. +// +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft2 +// @end_compatibility +func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a tensor filled with a scalar value. +// +// This operation creates a tensor of shape `dims` and fills it with `value`. +// +// For example: +// +// ``` +// # Output tensor has shape [2, 3]. +// fill([2, 3], 9) ==> [[9, 9, 9] +// [9, 9, 9]] +// ``` +// +// Arguments: +// dims: 1-D. Represents the shape of the output tensor. +// value: 0-D (scalar). Value to fill the returned tensor. +// +// @compatibility(numpy) +// Equivalent to np.full +// @end_compatibility +func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fill", + Input: []tf.Input{ + dims, value, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes gradient of the FractionalMaxPool function. +// 2D fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. // // Arguments: -// orig_input: Original input for `fractional_max_pool` -// orig_output: Original output for `fractional_max_pool` -// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients -// w.r.t. the output of `fractional_max_pool`. -// row_pooling_sequence: row pooling sequence, form pooling region with -// col_pooling_sequence. -// col_pooling_sequence: column pooling sequence, form pooling region with -// row_pooling sequence. +// input: A complex64 tensor. // -// Returns 4-D. Gradients w.r.t. the input of `fractional_max_pool`. -func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalMaxPoolGradAttr) (output tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "FractionalMaxPoolGrad", + Type: "FFT2D", Input: []tf.Input{ - orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. -type ResourceApplyAdagradDAAttr func(optionalAttr) +// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. +type ResourceApplyProximalGradientDescentAttr func(optionalAttr) -// ResourceApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. // If not specified, defaults to false -func ResourceApplyAdagradDAUseLocking(value bool) ResourceApplyAdagradDAAttr { +func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the proximal adagrad scheme. +// Update '*var' as FOBOS algorithm with fixed learning rate. +// +// prox_v = var - alpha * delta +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} // // Arguments: // var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. +// alpha: Scaling factor. Must be a scalar. // l1: L1 regularization. Must be a scalar. // l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. +// delta: The change. // // Returns the created operation. -func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceApplyAdagradDAAttr) (o *tf.Operation) { +func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -10885,85 +10800,49 @@ func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator t a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdagradDA", + Type: "ResourceApplyProximalGradientDescent", Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, + var_, alpha, l1, l2, delta, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// SparseReduceMaxSparseAttr is an optional argument to SparseReduceMaxSparse. -type SparseReduceMaxSparseAttr func(optionalAttr) - -// SparseReduceMaxSparseKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceMaxSparseKeepDims(value bool) SparseReduceMaxSparseAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the max of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In contrast to SparseReduceMax, this Op returns a -// SparseTensor. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// Computes the gradient for the sqrt of `x` wrt its input. // -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -func SparseReduceMaxSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` +// is the corresponding input gradient. +func SqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseReduceMaxSparse", + Type: "SqrtGrad", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + y, dy, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Creates a dataset that emits the outputs of `input_dataset` `count` times. +// Get the value of the tensor specified by its handle. // // Arguments: +// handle: The handle for a tensor stored in the session state. +// dtype: The type of the output value. // -// count: A scalar representing the number of times that `input_dataset` should -// be repeated. A value of `-1` indicates that it should be repeated infinitely. -// -// -func RepeatDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns The tensor for the given handle. +func GetSessionTensor(scope *Scope, handle tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "RepeatDataset", + Type: "GetSessionTensor", Input: []tf.Input{ - input_dataset, count, + handle, }, Attrs: attrs, } @@ -10971,110 +10850,104 @@ func RepeatDataset(scope *Scope, input_dataset tf.Output, count tf.Output, outpu return op.Output(0) } -// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. -type AddManySparseToTensorsMapAttr func(optionalAttr) - -// AddManySparseToTensorsMapContainer sets the optional container attribute to value. +// Returns x - y element-wise. // -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddManySparseToTensorsMapContainer(value string) AddManySparseToTensorsMapAttr { - return func(m optionalAttr) { - m["container"] = value +// *NOTE*: `Subtract` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return } -} - -// AddManySparseToTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. -// If not specified, defaults to "" -func AddManySparseToTensorsMapSharedName(value string) AddManySparseToTensorsMapAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "Sub", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. -// -// A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`, where -// -// ```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` -// -// An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` -// having a first `sparse_indices` column taking values between `[0, N)`, where -// the minibatch size `N == sparse_shape[0]`. -// -// The input `SparseTensor` must have rank `R` greater than 1, and the first -// dimension is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The stored -// `SparseTensor` objects pointed to by each row of the output `sparse_handles` -// will have rank `R-1`. +// Computes softmax cross entropy cost and gradients to backpropagate. // -// The `SparseTensor` values can then be read out as part of a minibatch by passing -// the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddManySparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. +// Inputs are the logits, not probabilities. // // Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// `sparse_indices[:, 0]` must be ordered values in `[0, N)`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. -// The minibatch size `N == sparse_shape[0]`. +// features: batch_size x num_classes matrix +// labels: batch_size x num_classes matrix +// The caller must ensure that each batch of labels represents a valid +// probability distribution. // -// Returns 1-D. The handles of the `SparseTensor` now stored in the -// `SparseTensorsMap`. Shape: `[N]`. -func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddManySparseToTensorsMapAttr) (sparse_handles tf.Output) { +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "AddManySparseToTensorsMap", + Type: "SoftmaxCrossEntropyWithLogits", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + features, labels, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// MinAttr is an optional argument to Min. -type MinAttr func(optionalAttr) +// ReduceJoinAttr is an optional argument to ReduceJoin. +type ReduceJoinAttr func(optionalAttr) -// MinKeepDims sets the optional keep_dims attribute to value. +// ReduceJoinKeepDims sets the optional keep_dims attribute to value. // -// value: If true, retain reduced dimensions with length 1. +// value: If `True`, retain reduced dimensions with length `1`. // If not specified, defaults to false -func MinKeepDims(value bool) MinAttr { +func ReduceJoinKeepDims(value bool) ReduceJoinAttr { return func(m optionalAttr) { m["keep_dims"] = value } } -// Computes the minimum of elements across dimensions of a tensor. +// ReduceJoinSeparator sets the optional separator attribute to value. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// value: The separator to use when joining. +// If not specified, defaults to "" +func ReduceJoinSeparator(value string) ReduceJoinAttr { + return func(m optionalAttr) { + m["separator"] = value + } +} + +// Joins a string Tensor across the given dimensions. +// +// Computes the string join across dimensions in the given string Tensor of shape +// `[d_0, d_1, ..., d_n-1]`. Returns a new Tensor created by joining the input +// strings with the given separator (default: empty string). Negative indices are +// counted backwards from the end, with `-1` being equivalent to `n - 1`. +// +// For example: +// +// ```python +// # tensor `a` is [["a", "b"], ["c", "d"]] +// tf.reduce_join(a, 0) ==> ["ac", "bd"] +// tf.reduce_join(a, 1) ==> ["ab", "cd"] +// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"] +// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"] +// tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] +// tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] +// tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] +// tf.reduce_join(a, [0, 1]) ==> ["acbd"] +// tf.reduce_join(a, [1, 0]) ==> ["abcd"] +// tf.reduce_join(a, []) ==> ["abcd"] +// ``` // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// inputs: The input to be joined. All reduced indices must have non-zero size. +// reduction_indices: The dimensions to reduce over. Dimensions are reduced in the +// order specified. Omitting `reduction_indices` is equivalent to passing +// `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported. // -// Returns The reduced tensor. -func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { +// Returns Has shape equal to that of the input with reduced dimensions removed or +// set to `1` depending on `keep_dims`. +func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, optional ...ReduceJoinAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -11083,9 +10956,9 @@ func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (ou a(attrs) } opspec := tf.OpSpec{ - Type: "Min", + Type: "ReduceJoin", Input: []tf.Input{ - input, axis, + inputs, reduction_indices, }, Attrs: attrs, } @@ -11093,323 +10966,199 @@ func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (ou return op.Output(0) } -// Shuffle dimensions of x according to a permutation. -// -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { +// Computes cos of x element-wise. +func Cos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Transpose", + Type: "Cos", Input: []tf.Input{ - x, perm, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. -type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) +// FusedBatchNormGradAttr is an optional argument to FusedBatchNormGrad. +type FusedBatchNormGradAttr func(optionalAttr) -// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. +// FusedBatchNormGradEpsilon sets the optional epsilon attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradEpsilon(value float32) FusedBatchNormGradAttr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". // If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { +func FusedBatchNormGradDataFormat(value string) FusedBatchNormGradAttr { return func(m optionalAttr) { m["data_format"] = value } } -// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. +// FusedBatchNormGradIsTraining sets the optional is_training attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradIsTraining(value bool) FusedBatchNormGradAttr { return func(m optionalAttr) { - m["dilations"] = value + m["is_training"] = value } } -// Computes the gradients of depthwise convolution with respect to the filter. +// Gradient for batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. // // Arguments: -// input: 4-D with shape based on `data_format`. For example, if -// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, -// in_width, in_channels]` tensor. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. // -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradAttr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropFilter", + Type: "FusedBatchNormGrad", Input: []tf.Input{ - input, filter_sizes, out_backprop, + y_backprop, x, scale, reserve_space_1, reserve_space_2, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Flushes the writer's unwritten events. -// -// Arguments: -// writer: A handle to the summary writer resource. -// -// Returns the created operation. -func FlushSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FlushSummaryWriter", - Input: []tf.Input{ - writer, - }, - } - return scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// QuantizeV2Attr is an optional argument to QuantizeV2. -type QuantizeV2Attr func(optionalAttr) - -// QuantizeV2Mode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func QuantizeV2Mode(value string) QuantizeV2Attr { - return func(m optionalAttr) { - m["mode"] = value - } -} +// TopKAttr is an optional argument to TopK. +type TopKAttr func(optionalAttr) -// QuantizeV2RoundMode sets the optional round_mode attribute to value. -// If not specified, defaults to "HALF_AWAY_FROM_ZERO" -func QuantizeV2RoundMode(value string) QuantizeV2Attr { +// TopKSorted sets the optional sorted attribute to value. +// +// value: If true the resulting `k` elements will be sorted by the values in +// descending order. +// If not specified, defaults to true +func TopKSorted(value bool) TopKAttr { return func(m optionalAttr) { - m["round_mode"] = value + m["sorted"] = value } } -// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. The -// 'round_mode' attribute controls which rounding tie-breaking algorithm is used -// when rounding float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) -// if T == qint8, out[i] -= (range(T) + 1) / 2.0 -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// Assume the input is type float and has a possible range of [0.0, 6.0] and the -// output type is quint8 ([0, 255]). The min_range and max_range values should be -// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each -// value of the input by 255/6 and cast to quint8. -// -// If the output type was qint8 ([-128, 127]), the operation will additionally -// subtract each value by 128 prior to casting, so that the range of values aligns -// with the range of qint8. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ``` -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = num_discrete_values / range -// quantized = round(input * range_scale) - round(range_min * range_scale) + -// numeric_limits::min() -// quantized = max(quantized, numeric_limits::min()) -// quantized = min(quantized, numeric_limits::max()) -// ``` -// -// The biggest difference between this and MIN_COMBINED is that the minimum range -// is rounded first, before it's subtracted from the rounded value. With -// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing -// and dequantizing will introduce a larger and larger error. -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` +// Finds values and indices of the `k` largest elements for the last dimension. // -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` +// DEPRECATED at GraphDef version 7: Use TopKV2 instead // -// From this we compute our scaling factor, s: -// ```c++ -// s = (max_fixed - min_fixed) / (2 * m) -// ``` +// If the input is a vector (rank-1), finds the `k` largest entries in the vector +// and outputs their values and indices as vectors. Thus `values[j]` is the +// `j`-th largest entry in `input`, and its index is `indices[j]`. // -// Now we can quantize the elements of our tensor: -// ```c++ -// result = round(input * s) -// ``` +// For matrices (resp. higher rank input), computes the top `k` entries in each +// row (resp. vector along the last dimension). Thus, // -// One thing to watch out for is that the operator may choose to adjust the -// requested minimum and maximum values slightly during the quantization process, -// so you should always use the output ports as the range for further calculations. -// For example, if the requested minimum and maximum values are close to equal, -// they will be separated by a small epsilon value to prevent ill-formed quantized -// buffers from being created. Otherwise, you can end up with buffers where all the -// quantized values map to the same float value, which causes problems for -// operations that have to perform further calculations on them. +// values.shape = indices.shape = input.shape[:-1] + [k] // -// Arguments: +// If two elements are equal, the lower-index element appears first. // -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. +// If `k` varies dynamically, use `TopKV2` below. // +// Arguments: +// input: 1-D or higher with last dimension at least `k`. +// k: Number of top elements to look for along the last dimension (along each +// row for matrices). // -// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. -func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. +func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values tf.Output, indices tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T} + attrs := map[string]interface{}{"k": k} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeV2", + Type: "TopK", Input: []tf.Input{ - input, min_range, max_range, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1) } -// Component-wise divides a SparseTensor by a dense Tensor. +// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). // -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. +// The Hurwitz zeta function is defined as: // -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { +// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) +func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseDenseCwiseDiv", + Type: "Zeta", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, + x, q, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. -type ResourceApplyMomentumAttr func(optionalAttr) - -// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} +// ProdAttr is an optional argument to Prod. +type ProdAttr func(optionalAttr) -// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. +// ProdKeepDims sets the optional keep_dims attribute to value. // -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. +// value: If true, retain reduced dimensions with length 1. // If not specified, defaults to false -func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { +func ProdKeepDims(value bool) ProdAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["keep_dims"] = value } } -// Update '*var' according to the momentum scheme. Set use_nesterov = True if you -// -// want to use Nesterov momentum. +// Computes the product of elements across dimensions of a tensor. // -// accum = accum * momentum + grad -// var -= lr * accum +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. -// momentum: Momentum. Must be a scalar. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns the created operation. -func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { +// Returns The reduced tensor. +func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -11418,56 +11167,68 @@ func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf. a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyMomentum", + Type: "Prod", Input: []tf.Input{ - var_, accum, lr, grad, momentum, + input, axis, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. -type MaxPoolGradGradAttr func(optionalAttr) +// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. +type FusedResizeAndPadConv2DAttr func(optionalAttr) -// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. +// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { +// value: If true, rescale input by (new_height - 1) / (height - 1), +// which exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { return func(m optionalAttr) { - m["data_format"] = value + m["resize_align_corners"] = value } } -// Computes second-order gradients of the maxpooling function. +// Performs a resize and padding as a preprocess during a convolution. +// +// It's often possible to do spatial transformations more efficiently as part of +// the packing stage of a convolution, so this op allows for an optimized +// implementation where these stages are fused together. This prevents the need to +// write out the intermediate results as whole tensors, reducing memory pressure, +// and we can get some latency gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and defaults to +// 'NHWC' order. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. +func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolGradGrad", + Type: "FusedResizeAndPadConv2D", Input: []tf.Input{ - orig_input, orig_output, grad, + input, size, paddings, filter, }, Attrs: attrs, } @@ -11475,140 +11236,165 @@ func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, return op.Output(0) } -// Returns the truth value of (x >= y) element-wise. +// Inverse 3D fast Fourier transform. // -// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func GreaterEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Computes the inverse 3-dimensional discrete Fourier transform over the +// inner-most 3 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their inverse 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifftn with 3 dimensions. +// @end_compatibility +func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GreaterEqual", + Type: "IFFT3D", Input: []tf.Input{ - x, y, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv3DAttr is an optional argument to Conv3D. -type Conv3DAttr func(optionalAttr) - -// Conv3DDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DDataFormat(value string) Conv3DAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv3DDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DDilations(value []int64) Conv3DAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 3-D convolution given 5-D `input` and `filter` tensors. +// Adds `bias` to `value`. // -// In signal processing, cross-correlation is a measure of similarity of -// two waveforms as a function of a time-lag applied to one of them. This -// is also known as a sliding dot product or sliding inner-product. +// This is a deprecated version of BiasAdd and will be soon removed. // -// Our Conv3D implements a form of cross-correlation. +// This is a special case of `tf.add` where `bias` is restricted to be 1-D. +// Broadcasting is supported, so `value` may have any number of dimensions. // // Arguments: -// input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. -// filter: Shape `[filter_depth, filter_height, filter_width, in_channels, -// out_channels]`. `in_channels` must match between `input` and `filter`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv3DAttr) (output tf.Output) { +// value: Any number of dimensions. +// bias: 1-D with size the last dimension of `value`. +// +// Returns Broadcasted sum of `value` and `bias`. +func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Conv3D", + Type: "BiasAddV1", Input: []tf.Input{ - input, filter, + value, bias, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Adds up a SparseTensor and a dense Tensor, using these special rules: +// Reverses specific dimensions of a tensor. +// +// NOTE `tf.reverse` has now changed behavior in preparation for 1.0. +// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0. +// +// Given a `tensor`, and a `int32` tensor `axis` representing the set of +// dimensions of `tensor` to reverse. This operation reverses each dimension +// `i` for which there exists `j` s.t. `axis[j] == i`. +// +// `tensor` can have up to 8 dimensions. The number of dimensions specified +// in `axis` may be 0 or more entries. If an index is specified more than +// once, a InvalidArgument error is raised. +// +// For example: +// +// ``` +// # tensor 't' is [[[[ 0, 1, 2, 3], +// # [ 4, 5, 6, 7], +// # [ 8, 9, 10, 11]], +// # [[12, 13, 14, 15], +// # [16, 17, 18, 19], +// # [20, 21, 22, 23]]]] +// # tensor 't' shape is [1, 2, 3, 4] // -// (1) Broadcasts the dense side to have the same shape as the sparse side, if -// eligible; -// (2) Then, only the dense values pointed to by the indices of the SparseTensor -// participate in the cwise addition. +// # 'dims' is [3] or 'dims' is [-1] +// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], +// [ 7, 6, 5, 4], +// [ 11, 10, 9, 8]], +// [[15, 14, 13, 12], +// [19, 18, 17, 16], +// [23, 22, 21, 20]]]] // -// By these rules, the result is a logical SparseTensor with exactly the same -// indices and shape, but possibly with different non-zero values. The output of -// this Op is the resultant non-zero values. +// # 'dims' is '[1]' (or 'dims' is '[-3]') +// reverse(t, dims) ==> [[[[12, 13, 14, 15], +// [16, 17, 18, 19], +// [20, 21, 22, 23] +// [[ 0, 1, 2, 3], +// [ 4, 5, 6, 7], +// [ 8, 9, 10, 11]]]] +// +// # 'dims' is '[2]' (or 'dims' is '[-2]') +// reverse(t, dims) ==> [[[[8, 9, 10, 11], +// [4, 5, 6, 7], +// [0, 1, 2, 3]] +// [[20, 21, 22, 23], +// [16, 17, 18, 19], +// [12, 13, 14, 15]]]] +// ``` // // Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. +// tensor: Up to 8-D. +// axis: 1-D. The indices of the dimensions to reverse. Must be in the range +// `[-rank(tensor), rank(tensor))`. // -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { +// Returns The same shape as `tensor`. +func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseDenseCwiseAdd", + Type: "ReverseV2", Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, + tensor, axis, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Read an element from the TensorArray into output `value`. +// RealAttr is an optional argument to Real. +type RealAttr func(optionalAttr) + +// RealTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func RealTout(value tf.DataType) RealAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Returns the real part of a complex number. // -// Arguments: -// handle: The handle to a TensorArray. +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the real part of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real +// part returned by this operation and *b* is the imaginary part. // -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// For example: // -// Returns The tensor that is read from the TensorArray. -func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.real(input) ==> [-2.25, 3.25] +// ``` +func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayReadV3", + Type: "Real", Input: []tf.Input{ - handle, index, flow_in, + input, }, Attrs: attrs, } @@ -11616,49 +11402,55 @@ func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in return op.Output(0) } -// EncodePngAttr is an optional argument to EncodePng. -type EncodePngAttr func(optionalAttr) +// AudioSummaryAttr is an optional argument to AudioSummary. +type AudioSummaryAttr func(optionalAttr) -// EncodePngCompression sets the optional compression attribute to value. +// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value. // -// value: Compression level. -// If not specified, defaults to -1 -func EncodePngCompression(value int64) EncodePngAttr { +// value: Max number of batch elements to generate audio for. +// If not specified, defaults to 3 +// +// REQUIRES: value >= 1 +func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr { return func(m optionalAttr) { - m["compression"] = value + m["max_outputs"] = value } } -// PNG-encode an image. +// Outputs a `Summary` protocol buffer with audio. // -// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` -// where `channels` is: +// DEPRECATED at GraphDef version 15: Use AudioSummaryV2. // -// * 1: for grayscale. -// * 2: for grayscale + alpha. -// * 3: for RGB. -// * 4: for RGBA. +// The summary has up to `max_outputs` summary values containing audio. The +// audio is built from `tensor` which must be 3-D with shape `[batch_size, +// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. // -// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder -// default or a value from 0 to 9. 9 is the highest compression level, generating -// the smallest output, but is slower. +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: +// +// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +// * If `max_outputs` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. // // Arguments: -// image: 3-D with shape `[height, width, channels]`. +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 2-D of shape `[batch_size, frames]`. +// sample_rate: The sample rate of the signal in hertz. // -// Returns 0-D. PNG-encoded image. -func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"sample_rate": sample_rate} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "EncodePng", + Type: "AudioSummary", Input: []tf.Input{ - image, + tag, tensor, }, Attrs: attrs, } @@ -11666,38 +11458,42 @@ func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (conten return op.Output(0) } -// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. -type DataFormatVecPermuteAttr func(optionalAttr) +// QrAttr is an optional argument to Qr. +type QrAttr func(optionalAttr) -// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. +// QrFullMatrices sets the optional full_matrices attribute to value. // -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { +// value: If true, compute full-sized `q` and `r`. If false +// (the default), compute only the leading `P` columns of `q`. +// If not specified, defaults to false +func QrFullMatrices(value bool) QrAttr { return func(m optionalAttr) { - m["src_format"] = value + m["full_matrices"] = value } } -// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. +// Computes the QR decompositions of one or more matrices. // -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { - return func(m optionalAttr) { - m["dst_format"] = value - } -} - -// Returns the permuted vector/tensor in the destination data format given the +// Computes the QR decomposition of each inner matrix in `tensor` such that +// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` // -// one in the source data format. +// ```python +// # a is a tensor. +// # q is a tensor of orthonormal matrices. +// # r is a tensor of upper triangular matrices. +// q, r = qr(a) +// q_full, r_full = qr(a, full_matrices=True) +// ``` // // Arguments: -// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. +// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. // -// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. -func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { +// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then +// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is +// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is +// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`. +func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) { if scope.Err() != nil { return } @@ -11706,155 +11502,118 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe a(attrs) } opspec := tf.OpSpec{ - Type: "DataFormatVecPermute", + Type: "Qr", Input: []tf.Input{ - x, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Returns element-wise integer closest to x. -// -// If the result is midway between two representable values, -// the even representable is chosen. -// For example: -// -// ``` -// rint(-1.5) ==> -2.0 -// rint(0.5000001) ==> 1.0 -// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] -// ``` -func Rint(scope *Scope, x tf.Output) (y tf.Output) { +// Records the bytes size of each element of `input_dataset` in a StatsAggregator. +func BytesProducedStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Rint", + Type: "BytesProducedStatsDataset", Input: []tf.Input{ - x, + input_dataset, tag, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// OrderedMapUnstageNoKeyAttr is an optional argument to OrderedMapUnstageNoKey. -type OrderedMapUnstageNoKeyAttr func(optionalAttr) +// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. +type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) -// OrderedMapUnstageNoKeyCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapUnstageNoKeyCapacity(value int64) OrderedMapUnstageNoKeyAttr { +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { return func(m optionalAttr) { - m["capacity"] = value + m["use_locking"] = value } } -// OrderedMapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Sparse update '*var' as FOBOS algorithm with fixed learning rate. // -// REQUIRES: value >= 0 -func OrderedMapUnstageNoKeyMemoryLimit(value int64) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapUnstageNoKeyContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageNoKeyContainer(value string) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapUnstageNoKeySharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageNoKeySharedName(value string) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes and returns the (key, value) element with the smallest +// That is for rows we have grad for, we update var as follows: +// prox_v = var - alpha * grad +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. // -// key from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { +// Returns the created operation. +func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapUnstageNoKey", + Type: "ResourceSparseApplyProximalGradientDescent", Input: []tf.Input{ - indices, + var_, alpha, l1, l2, grad, indices, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - key = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapUnstageNoKey", err) - return - } - return key, values + return scope.AddOperation(opspec) } -// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. -type MaxPool3DGradGradAttr func(optionalAttr) +// MeanAttr is an optional argument to Mean. +type MeanAttr func(optionalAttr) -// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// MeanKeepDims sets the optional keep_dims attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MeanKeepDims(value bool) MeanAttr { return func(m optionalAttr) { - m["data_format"] = value + m["keep_dims"] = value } } -// Computes second-order gradients of the maxpooling function. +// Computes the mean of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { +// Returns The reduced tensor. +func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3DGradGrad", + Type: "Mean", Input: []tf.Input{ - orig_input, orig_output, grad, + input, axis, }, Attrs: attrs, } @@ -11862,223 +11621,219 @@ func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. -type Conv3DBackpropFilterV2Attr func(optionalAttr) +// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. +type InitializeTableFromTextFileV2Attr func(optionalAttr) -// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { +// value: Number of elements of the file, use -1 if unknown. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["vocab_size"] = value } } -// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. +// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. // -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { +// value: Delimiter to separate fields in a line. +// If not specified, defaults to "\t" +func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["delimiter"] = value } } -// Computes the gradients of 3-D convolution with respect to the filter. +// Initializes a table from a text file. +// +// It inserts one key-value pair into the table for each line of the file. +// The key and value is extracted from the whole line content, elements from the +// split line based on `delimiter` or the line number (starting from zero). +// Where to extract the key and value from a line is specified by `key_index` and +// `value_index`. +// +// - A value of -1 means use the line number(starting from zero), expects `int64`. +// - A value of -2 means use the whole line content, expects `string`. +// - A value >= 0 means use the index (starting at zero) of the split line based +// on `delimiter`. // // Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 5-D -// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` -// tensor. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { +// table_handle: Handle to a table which will be initialized. +// filename: Filename of a vocabulary text file. +// key_index: Column index in a line to get the table `key` values from. +// value_index: Column index that represents information of a line to get the table +// `value` values from. +// +// Returns the created operation. +func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilterV2", + Type: "InitializeTableFromTextFileV2", Input: []tf.Input{ - input, filter_sizes, out_backprop, + table_handle, filename, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Execute a sub graph on a remote processor. -// -// The graph specifications(such as graph itself, input tensors and output names) -// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo -// as serialized_remote_fused_graph_execute_info. -// The specifications will be passed to a dedicated registered -// remote fused graph executor. The executor will send the graph specifications -// to a remote processor and execute that graph. The execution results -// will be passed to consumer nodes as outputs of this node. +// QuantizedReluAttr is an optional argument to QuantizedRelu. +type QuantizedReluAttr func(optionalAttr) + +// QuantizedReluOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedReluOutType(value tf.DataType) QuantizedReluAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear: `max(features, 0)` // // Arguments: -// inputs: Arbitrary number of tensors with arbitrary data types // -// serialized_remote_fused_graph_execute_info: Serialized protocol buffer -// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. // -// Returns Arbitrary number of tensors with arbitrary data types -func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RemoteFusedGraphExecute", + Type: "QuantizedRelu", Input: []tf.Input{ - tf.OutputList(inputs), + features, min_features, max_features, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RemoteFusedGraphExecute", err) - return - } - return outputs + return op.Output(0), op.Output(1), op.Output(2) } -// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. -type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) - -// ThreadUnsafeUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// Reshapes a SparseTensor to represent values in a new dense shape. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ThreadUnsafeUnigramCandidateSamplerSeed(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// ThreadUnsafeUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// This operation has the same semantics as reshape on the represented dense +// tensor. The `input_indices` are recomputed based on the requested `new_shape`. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ThreadUnsafeUnigramCandidateSamplerSeed2(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed2"] = value +// If one component of `new_shape` is the special value -1, the size of that +// dimension is computed so that the total dense size remains constant. At +// most one component of `new_shape` can be -1. The number of dense elements +// implied by `new_shape` must be the same as the number of dense elements +// originally implied by `input_shape`. +// +// Reshaping does not affect the order of values in the SparseTensor. +// +// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` +// has length `R_out`, then `input_indices` has shape `[N, R_in]`, +// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and +// `output_shape` has length `R_out`. +// +// Arguments: +// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a +// SparseTensor. +// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. +// new_shape: 1-D. `R_out` vector with the requested new dense shape. +// +// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty +// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output +// SparseTensor. This is the same as `new_shape` but with any -1 dimensions +// filled in. +func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseReshape", + Input: []tf.Input{ + input_indices, input_shape, new_shape, + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. -// -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// Deprecated. Use TensorArraySplitV3 // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func ThreadUnsafeUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...ThreadUnsafeUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 +func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ThreadUnsafeUnigramCandidateSampler", + Type: "TensorArraySplitV2", Input: []tf.Input{ - true_classes, + handle, value, lengths, flow_in, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// MaxPoolV2Attr is an optional argument to MaxPoolV2. -type MaxPoolV2Attr func(optionalAttr) +// PackAttr is an optional argument to Pack. +type PackAttr func(optionalAttr) -// MaxPoolV2DataFormat sets the optional data_format attribute to value. +// PackAxis sets the optional axis attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolV2DataFormat(value string) MaxPoolV2Attr { +// value: Dimension along which to pack. Negative values wrap around, so the +// valid range is `[-(R+1), R+1)`. +// If not specified, defaults to 0 +func PackAxis(value int64) PackAttr { return func(m optionalAttr) { - m["data_format"] = value + m["axis"] = value } } -// Performs max pooling on the input. +// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. +// +// Packs the `N` tensors in `values` into a tensor with rank one higher than each +// tensor in `values`, by packing them along the `axis` dimension. +// Given a list of tensors of shape `(A, B, C)`; +// +// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. +// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. +// Etc. +// +// For example: +// +// ``` +// # 'x' is [1, 4] +// # 'y' is [2, 5] +// # 'z' is [3, 6] +// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] +// ``` +// +// This is the opposite of `unpack`. // // Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// values: Must be of same shape and type. // -// Returns The max pooled output tensor. -func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolV2Attr) (output tf.Output) { +// Returns The packed tensor. +func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolV2", + Type: "Pack", Input: []tf.Input{ - input, ksize, strides, + tf.OutputList(values), }, Attrs: attrs, } @@ -12086,170 +11841,149 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output return op.Output(0) } -// Deprecated. Use TensorArrayReadV3 +// Reorders a SparseTensor into the canonical, row-major ordering. // -// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 -func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { +// Note that by convention, all sparse ops preserve the canonical ordering along +// increasing dimension number. The only time ordering can be violated is during +// manual manipulation of the indices and values vectors to add entries. +// +// Reordering does not affect the shape of the SparseTensor. +// +// If the tensor has rank `R` and `N` non-empty values, `input_indices` has +// shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// +// Returns 2-D. `N x R` matrix with the same indices as input_indices, but +// in canonical row-major ordering.1-D. `N` non-empty values corresponding to `output_indices`. +func SparseReorder(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "TensorArrayReadV2", + Type: "SparseReorder", Input: []tf.Input{ - handle, index, flow_in, + input_indices, input_values, input_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Does nothing. Serves as a control trigger for scheduling. -// -// Only useful as a placeholder for control edges. -// -// Returns the created operation. -func ControlTrigger(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ControlTrigger", - } - return scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. Prefer `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { +// Computes rectified linear: `max(features, 0)`. +func Relu(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalization", + Type: "Relu", Input: []tf.Input{ - t, m, v, beta, gamma, + features, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. -type MutableDenseHashTableV2Attr func(optionalAttr) - -// MutableDenseHashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} +// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. +type ResourceApplyAddSignAttr func(optionalAttr) -// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. +// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. // -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { +func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { return func(m optionalAttr) { - m["use_node_name_sharing"] = value + m["use_locking"] = value } } -// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. +// Update '*var' according to the AddSign update. // -// value: The shape of each value. -// If not specified, defaults to <> -func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- (alpha + sign_decay * sign(g) *sign(m)) * g +// variable <- variable - lr_t * update // -// value: The initial number of hash table buckets. Must be a power -// to 2. -// If not specified, defaults to 131072 -func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["initial_num_buckets"] = value +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// alpha: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAddSign", + Input: []tf.Input{ + var_, m, lr, alpha, sign_decay, beta, grad, + }, + Attrs: attrs, } + return scope.AddOperation(opspec) } -// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. +// FractionalMaxPoolGradAttr is an optional argument to FractionalMaxPoolGrad. +type FractionalMaxPoolGradAttr func(optionalAttr) + +// FractionalMaxPoolGradOverlapping sets the optional overlapping attribute to value. // -// value: The maximum ratio between number of entries and number of -// buckets before growing the table. Must be between 0 and 1. -// If not specified, defaults to 0.8 -func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [20, 16] for fractional max pooling. +// If not specified, defaults to false +func FractionalMaxPoolGradOverlapping(value bool) FractionalMaxPoolGradAttr { return func(m optionalAttr) { - m["max_load_factor"] = value + m["overlapping"] = value } } -// Creates an empty hash table that uses tensors as the backing store. -// -// It uses "open addressing" with quadratic reprobing to resolve -// collisions. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// Computes gradient of the FractionalMaxPool function. // // Arguments: -// empty_key: The key used to represent empty key buckets internally. Must not -// be used in insert or lookup operations. -// value_dtype: Type of the table values. +// orig_input: Original input for `fractional_max_pool` +// orig_output: Original output for `fractional_max_pool` +// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients +// w.r.t. the output of `fractional_max_pool`. +// row_pooling_sequence: row pooling sequence, form pooling region with +// col_pooling_sequence. +// col_pooling_sequence: column pooling sequence, form pooling region with +// row_pooling sequence. // -// Returns Handle to a table. -func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { +// Returns 4-D. Gradients w.r.t. the input of `fractional_max_pool`. +func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalMaxPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"value_dtype": value_dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableDenseHashTableV2", + Type: "FractionalMaxPoolGrad", Input: []tf.Input{ - empty_key, + orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence, }, Attrs: attrs, } @@ -12257,248 +11991,272 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D return op.Output(0) } -// Produces the max pool of the input tensor for quantized types. +// ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. +type ResourceApplyAdagradDAAttr func(optionalAttr) + +// ResourceApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyAdagradDAUseLocking(value bool) ResourceApplyAdagradDAAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the proximal adagrad scheme. // // Arguments: -// input: The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// ksize: The size of the window for each dimension of the input tensor. -// The length must be 4 to match the number of dimensions of the input. -// strides: The stride of the sliding window for each dimension of the input -// tensor. The length must be 4 to match the number of dimensions of the input. -// padding: The type of padding algorithm to use. +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMaxPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { +// Returns the created operation. +func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceApplyAdagradDAAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QuantizedMaxPool", + Type: "ResourceApplyAdagradDA", Input: []tf.Input{ - input, min_input, max_input, + var_, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Computes softplus: `log(exp(features) + 1)`. -func Softplus(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Softplus", - Input: []tf.Input{ - features, - }, +// SparseReduceMaxSparseAttr is an optional argument to SparseReduceMaxSparse. +type SparseReduceMaxSparseAttr func(optionalAttr) + +// SparseReduceMaxSparseKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceMaxSparseKeepDims(value bool) SparseReduceMaxSparseAttr { + return func(m optionalAttr) { + m["keep_dims"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes exponential of x - 1 element-wise. +// Computes the max of elements across dimensions of a SparseTensor. // -// I.e., \\(y = (\exp x) - 1\\). -func Expm1(scope *Scope, x tf.Output) (y tf.Output) { +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In contrast to SparseReduceMax, this Op returns a +// SparseTensor. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +func SparseReduceMaxSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Expm1", + Type: "SparseReduceMaxSparse", Input: []tf.Input{ - x, + input_indices, input_values, input_shape, reduction_axes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Returns the number of records this Reader has produced. -// -// This is the same as the number of ReaderRead executions that have -// succeeded. +// Creates a dataset that emits the outputs of `input_dataset` `count` times. // // Arguments: -// reader_handle: Handle to a Reader. -func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { +// +// count: A scalar representing the number of times that `input_dataset` should +// be repeated. A value of `-1` indicates that it should be repeated infinitely. +// +// +func RepeatDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ReaderNumRecordsProducedV2", + Type: "RepeatDataset", Input: []tf.Input{ - reader_handle, + input_dataset, count, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the sum along segments of a tensor. +// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. +type AddManySparseToTensorsMapAttr func(optionalAttr) + +// AddManySparseToTensorsMapContainer sets the optional container attribute to value. // -// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of -// segments. +// value: The container name for the `SparseTensorsMap` created by this op. +// If not specified, defaults to "" +func AddManySparseToTensorsMapContainer(value string) AddManySparseToTensorsMapAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// AddManySparseToTensorsMapSharedName sets the optional shared_name attribute to value. // -// Computes a tensor such that -// \\(output_i = \sum_j data_j\\) where sum is over `j` such -// that `segment_ids[j] == i`. +// value: The shared name for the `SparseTensorsMap` created by this op. +// If blank, the new Operation's unique name is used. +// If not specified, defaults to "" +func AddManySparseToTensorsMapSharedName(value string) AddManySparseToTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. // -// If the sum is empty for a given segment ID `i`, `output[i] = 0`. +// A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, +// `sparse_values`, and `sparse_shape`, where // -//
-// -//
+// ```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` // -// Arguments: +// An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` +// having a first `sparse_indices` column taking values between `[0, N)`, where +// the minibatch size `N == sparse_shape[0]`. // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. +// The input `SparseTensor` must have rank `R` greater than 1, and the first +// dimension is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The stored +// `SparseTensor` objects pointed to by each row of the output `sparse_handles` +// will have rank `R-1`. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SegmentSum", - Input: []tf.Input{ - data, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that emits the lines of one or more text files. +// The `SparseTensor` values can then be read out as part of a minibatch by passing +// the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure +// the correct `SparseTensorsMap` is accessed, ensure that the same +// `container` and `shared_name` are passed to that Op. If no `shared_name` +// is provided here, instead use the *name* of the Operation created by calling +// `AddManySparseToTensorsMap` as the `shared_name` passed to +// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. // // Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar containing the number of bytes to buffer. -func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// `sparse_indices[:, 0]` must be ordered values in `[0, N)`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +// The minibatch size `N == sparse_shape[0]`. +// +// Returns 1-D. The handles of the `SparseTensor` now stored in the +// `SparseTensorsMap`. Shape: `[N]`. +func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddManySparseToTensorsMapAttr) (sparse_handles tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TextLineDataset", + Type: "AddManySparseToTensorsMap", Input: []tf.Input{ - filenames, compression_type, buffer_size, + sparse_indices, sparse_values, sparse_shape, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Checks whether a resource handle-based variable has been initialized. +// Concatenates tensors along one dimension. // // Arguments: -// resource: the input resource handle. +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). // -// Returns a scalar boolean which is true if the variable has been -// initialized. -func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "VarIsInitializedOp", + Type: "ConcatV2", Input: []tf.Input{ - resource, + tf.OutputList(values), axis, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Pads a tensor with zeros. -// -// This operation pads a `input` with zeros according to the `paddings` you -// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the -// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many zeros to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` -// in that dimension. -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` -func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Pad", + Type: "ReadFile", Input: []tf.Input{ - input, paddings, + filename, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseTensorDenseMatMulAttr is an optional argument to SparseTensorDenseMatMul. -type SparseTensorDenseMatMulAttr func(optionalAttr) - -// SparseTensorDenseMatMulAdjointA sets the optional adjoint_a attribute to value. -// -// value: Use the adjoint of A in the matrix multiply. If A is complex, this -// is transpose(conj(A)). Otherwise it's transpose(A). -// If not specified, defaults to false -func SparseTensorDenseMatMulAdjointA(value bool) SparseTensorDenseMatMulAttr { - return func(m optionalAttr) { - m["adjoint_a"] = value - } -} +// MinAttr is an optional argument to Min. +type MinAttr func(optionalAttr) -// SparseTensorDenseMatMulAdjointB sets the optional adjoint_b attribute to value. +// MinKeepDims sets the optional keep_dims attribute to value. // -// value: Use the adjoint of B in the matrix multiply. If B is complex, this -// is transpose(conj(B)). Otherwise it's transpose(B). +// value: If true, retain reduced dimensions with length 1. // If not specified, defaults to false -func SparseTensorDenseMatMulAdjointB(value bool) SparseTensorDenseMatMulAttr { +func MinKeepDims(value bool) MinAttr { return func(m optionalAttr) { - m["adjoint_b"] = value + m["keep_dims"] = value } } -// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". -// -// No validity checking is performed on the indices of A. However, the following -// input format is recommended for optimal behavior: +// Computes the minimum of elements across dimensions of a tensor. // -// if adjoint_a == false: -// A should be sorted in lexicographically increasing order. Use SparseReorder -// if you're not sure. -// if adjoint_a == true: -// A should be sorted in order of increasing dimension 1 (i.e., "column major" -// order instead of "row major" order). +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, size `[nnz, 2]` Matrix. -// a_values: 1-D. The `values` of the `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the `SparseTensor`, size `[2]` Vector. -// b: 2-D. A dense Matrix. -func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output, optional ...SparseTensorDenseMatMulAttr) (product tf.Output) { +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -12507,9 +12265,9 @@ func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "SparseTensorDenseMatMul", + Type: "Min", Input: []tf.Input{ - a_indices, a_values, a_shape, b, + input, axis, }, Attrs: attrs, } @@ -12517,354 +12275,233 @@ func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Outp return op.Output(0) } -// Deserialize and concatenate `SparseTensors` from a serialized minibatch. -// -// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where -// `N` is the minibatch size and the rows correspond to packed outputs of -// `SerializeSparse`. The ranks of the original `SparseTensor` objects -// must all match. When the final `SparseTensor` is created, it has rank one -// higher than the ranks of the incoming `SparseTensor` objects -// (they have been concatenated along a new row dimension). -// -// The output `SparseTensor` object's shape values for all dimensions but the -// first are the max across the input `SparseTensor` objects' shape values -// for the corresponding dimensions. Its first shape value is `N`, the minibatch -// size. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: -// -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// -// and -// -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// -// then the final deserialized `SparseTensor` will be: -// -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] +// Shuffle dimensions of x according to a permutation. // -// Arguments: -// serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects. -// Must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "DeserializeManySparse", + Type: "Transpose", Input: []tf.Input{ - serialized_sparse, + x, perm, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// StringJoinAttr is an optional argument to StringJoin. -type StringJoinAttr func(optionalAttr) - -// StringJoinSeparator sets the optional separator attribute to value. -// -// value: string, an optional join separator. -// If not specified, defaults to "" -func StringJoinSeparator(value string) StringJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } + return op.Output(0) } -// Joins the strings in the given list of string tensors into one tensor; -// -// with the given separator (default is an empty separator). +// Computes sigmoid of `x` element-wise. // -// Arguments: -// inputs: A list of string tensors. The tensors must all have the same shape, -// or be scalars. Scalars may be mixed in; these will be broadcast to the shape -// of non-scalar inputs. -func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StringJoin", + Type: "Sigmoid", Input: []tf.Input{ - tf.OutputList(inputs), + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns immutable tensor from memory region. -// -// The current implementation memmaps the tensor from a file. +// FusedBatchNormAttr is an optional argument to FusedBatchNorm. +type FusedBatchNormAttr func(optionalAttr) + +// FusedBatchNormEpsilon sets the optional epsilon attribute to value. // -// Arguments: -// dtype: Type of the returned tensor. -// shape: Shape of the returned tensor. -// memory_region_name: Name of readonly memory region used by the tensor, see -// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. -func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { - if scope.Err() != nil { - return +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { + return func(m optionalAttr) { + m["epsilon"] = value } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} - opspec := tf.OpSpec{ - Type: "ImmutableConst", +} - Attrs: attrs, +// FusedBatchNormDataFormat sets the optional data_format attribute to value. +// +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormIsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { + return func(m optionalAttr) { + m["is_training"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Inverse real-valued fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most dimension of `input`. -// -// The inner-most dimension of `input` is assumed to be the result of `RFFT`: the -// `fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If -// `fft_length` is not provided, it is computed from the size of the inner-most -// dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to -// compute `input` is odd, it should be provided since it cannot be inferred -// properly. +// Batch normalization. // -// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller -// than the corresponding dimension of `input`, the dimension is cropped. If it is -// larger, the dimension is padded with zeros. +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. // // Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [1]. The FFT length. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most -// dimension of `input` is replaced with the `fft_length` samples of its inverse -// 1D Fourier transform. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. // -// @compatibility(numpy) -// Equivalent to np.fft.irfft -// @end_compatibility -func IRFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. +func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IRFFT", + Type: "FusedBatchNorm", Input: []tf.Input{ - input, fft_length, + x, scale, offset, mean, variance, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Concatenates a list of `SparseTensor` along the specified dimension. -// -// Concatenation is with respect to the dense versions of these sparse tensors. -// It is assumed that each input is a `SparseTensor` whose elements are ordered -// along increasing dimension number. -// -// All inputs' shapes must match, except for the concat dimension. The -// `indices`, `values`, and `shapes` lists must have the same length. -// -// The output shape is identical to the inputs', except along the concat -// dimension, where it is the sum of the inputs' sizes along that dimension. -// -// The output elements will be resorted to preserve the sort order along -// increasing dimension number. -// -// This op runs in `O(M log M)` time, where `M` is the total number of non-empty -// values across all inputs. This is due to the need for an internal sort in -// order to concatenate efficiently across an arbitrary dimension. -// -// For example, if `concat_dim = 1` and the inputs are -// -// sp_inputs[0]: shape = [2, 3] -// [0, 2]: "a" -// [1, 0]: "b" -// [1, 1]: "c" -// -// sp_inputs[1]: shape = [2, 4] -// [0, 1]: "d" -// [0, 2]: "e" -// -// then the output will be +// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. +type RandomStandardNormalAttr func(optionalAttr) + +// RandomStandardNormalSeed sets the optional seed attribute to value. // -// shape = [2, 7] -// [0, 2]: "a" -// [0, 4]: "d" -// [0, 5]: "e" -// [1, 0]: "b" -// [1, 1]: "c" +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. // -// Graphically this is equivalent to doing +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a normal distribution. // -// [ a] concat [ d e ] = [ a d e ] -// [b c ] [ ] [b c ] +// The generated values will have mean 0 and standard deviation 1. // // Arguments: -// indices: 2-D. Indices of each input `SparseTensor`. -// values: 1-D. Non-empty values of each `SparseTensor`. -// shapes: 1-D. Shapes of each `SparseTensor`. -// concat_dim: Dimension to concatenate along. Must be in range [-rank, rank), -// where rank is the number of dimensions in each input `SparseTensor`. +// shape: The shape of the output tensor. +// dtype: The type of the output. // -// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. -func SparseConcat(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, concat_dim int64) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Returns A tensor of the specified shape filled with random normal values. +func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"concat_dim": concat_dim} + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseConcat", + Type: "RandomStandardNormal", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), + shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Generates sparse cross from a list of sparse and dense tensors. -// -// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each -// representing features of one feature column. It outputs a 2D `SparseTensor` with -// the batchwise crosses of these features. -// -// For example, if the inputs are -// -// inputs[0]: SparseTensor with shape = [2, 2] -// [0, 0]: "a" -// [1, 0]: "b" -// [1, 1]: "c" -// -// inputs[1]: SparseTensor with shape = [2, 1] -// [0, 0]: "d" -// [1, 0]: "e" -// -// inputs[2]: Tensor [["f"], ["g"]] -// -// then the output will be -// -// shape = [2, 2] -// [0, 0]: "a_X_d_X_f" -// [1, 0]: "b_X_e_X_g" -// [1, 1]: "c_X_e_X_g" -// -// if hashed_output=true then the output will be +// Component-wise divides a SparseTensor by a dense Tensor. // -// shape = [2, 2] -// [0, 0]: FingerprintCat64( -// Fingerprint64("f"), FingerprintCat64( -// Fingerprint64("d"), Fingerprint64("a"))) -// [1, 0]: FingerprintCat64( -// Fingerprint64("g"), FingerprintCat64( -// Fingerprint64("e"), Fingerprint64("b"))) -// [1, 1]: FingerprintCat64( -// Fingerprint64("g"), FingerprintCat64( -// Fingerprint64("e"), Fingerprint64("c"))) +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. // // Arguments: -// indices: 2-D. Indices of each input `SparseTensor`. -// values: 1-D. values of each `SparseTensor`. -// shapes: 1-D. Shapes of each `SparseTensor`. -// dense_inputs: 2-D. Columns represented by dense `Tensor`. -// hashed_output: If true, returns the hash of the cross instead of the string. -// This will allow us avoiding string manipulations. -// num_buckets: It is used if hashed_output is true. -// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. -// hash_key: Specify the hash_key that will be used by the `FingerprintCat64` -// function to combine the crosses fingerprints. -// -// +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. // -// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated or hashed -// `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. -func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, hashed_output bool, num_buckets int64, hash_key int64, out_type tf.DataType, internal_type tf.DataType) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_type": out_type, "internal_type": internal_type} opspec := tf.OpSpec{ - Type: "SparseCross", + Type: "SparseDenseCwiseDiv", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), + sp_indices, sp_values, sp_shape, dense, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// ListDiffAttr is an optional argument to ListDiff. -type ListDiffAttr func(optionalAttr) - -// ListDiffOutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func ListDiffOutIdx(value tf.DataType) ListDiffAttr { - return func(m optionalAttr) { - m["out_idx"] = value - } -} +// FractionalAvgPoolGradAttr is an optional argument to FractionalAvgPoolGrad. +type FractionalAvgPoolGradAttr func(optionalAttr) -// Computes the difference between two lists of numbers or strings. -// -// Given a list `x` and a list `y`, this operation returns a list `out` that -// represents all values that are in `x` but not in `y`. The returned list `out` -// is sorted in the same order that the numbers appear in `x` (duplicates are -// preserved). This operation also returns a list `idx` that represents the -// position of each `out` element in `x`. In other words: +// FractionalAvgPoolGradOverlapping sets the optional overlapping attribute to value. // -// `out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: // -// For example, given this input: +// `index 0 1 2 3 4` // -// ``` -// x = [1, 2, 3, 4, 5, 6] -// y = [1, 3, 5] -// ``` +// `value 20 5 16 3 7` // -// This operation would return: +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [41/3, 26/3] for fractional avg pooling. +// If not specified, defaults to false +func FractionalAvgPoolGradOverlapping(value bool) FractionalAvgPoolGradAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } +} + +// Computes gradient of the FractionalAvgPool function. // -// ``` -// out ==> [2, 4, 6] -// idx ==> [1, 3, 5] -// ``` +// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for +// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of +// out_backprop to those indices that form the same pooling cell. Therefore, we +// just need to know the shape of original input tensor, instead of the whole +// tensor. // // Arguments: -// x: 1-D. Values to keep. -// y: 1-D. Values to remove. +// orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` +// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients +// w.r.t. the output of `fractional_avg_pool`. +// row_pooling_sequence: row pooling sequence, form pooling region with +// col_pooling_sequence. +// col_pooling_sequence: column pooling sequence, form pooling region with +// row_pooling sequence. // -// Returns 1-D. Values present in `x` but not in `y`.1-D. Positions of `x` values preserved in `out`. -func ListDiff(scope *Scope, x tf.Output, y tf.Output, optional ...ListDiffAttr) (out tf.Output, idx tf.Output) { +// Returns 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. +func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalAvgPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -12873,244 +12510,218 @@ func ListDiff(scope *Scope, x tf.Output, y tf.Output, optional ...ListDiffAttr) a(attrs) } opspec := tf.OpSpec{ - Type: "ListDiff", + Type: "FractionalAvgPoolGrad", Input: []tf.Input{ - x, y, + orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. -// -// This Op does not require `a_indices` be sorted in standard lexicographic order. +// Concatenates tensors along one dimension. // // Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. -// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. -// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. -// b: `ndims`-D Tensor. With shape `a_shape`. -func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseTensorDenseAdd", + Type: "Concat", Input: []tf.Input{ - a_indices, a_values, a_shape, b, + concat_dim, tf.OutputList(values), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. -type SparseToSparseSetOperationAttr func(optionalAttr) +// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. +type ResourceApplyMomentumAttr func(optionalAttr) -// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { +// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["use_locking"] = value } } -// Applies set operation along last dimension of 2 `SparseTensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the -// order and range of `set1` and `set2` indices. -// -// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, -// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same -// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. +// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. // -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. +// If not specified, defaults to false +func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the momentum scheme. Set use_nesterov = True if you // -// If `validate_indices` is `True`, this op validates the order and range of `set1` -// and `set2` indices. +// want to use Nesterov momentum. // -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// accum = accum * momentum + grad +// var -= lr * accum // // Arguments: -// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must -// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the -// max set size across `0...n-1` dimensions. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the -// max set size across `0...n-1` dimensions. -// +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// momentum: Momentum. Must be a scalar. // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// Returns the created operation. +func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SparseToSparseSetOperation", + Type: "ResourceApplyMomentum", Input: []tf.Input{ - set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, + var_, accum, lr, grad, momentum, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Computes numerical negative value element-wise. +// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. +type MaxPoolGradGradAttr func(optionalAttr) + +// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. // -// I.e., \\(y = -x\\). -func Neg(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Neg", - Input: []tf.Input{ - x, - }, +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Writes a `Summary` protocol buffer with a histogram. -// -// The generated -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// has one summary value containing a histogram for `values`. -// -// This op reports an `InvalidArgument` error if any value is not finite. +// Computes second-order gradients of the maxpooling function. // // Arguments: -// writer: A handle to a summary writer. -// step: The step to write the summary for. -// tag: Scalar. Tag to use for the `Summary.Value`. -// values: Any shape. Values to use to build the histogram. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns the created operation. -func WriteHistogramSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Output, values tf.Output) (o *tf.Operation) { +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "WriteHistogramSummary", + Type: "MaxPoolGradGrad", Input: []tf.Input{ - writer, step, tag, values, + orig_input, orig_output, grad, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Adds two `SparseTensor` objects to produce another `SparseTensor`. -// -// The input `SparseTensor` objects' indices are assumed ordered in standard -// lexicographic order. If this is not the case, before this step run -// `SparseReorder` to restore index ordering. -// -// By default, if two values sum to zero at some index, the output `SparseTensor` -// would still include that particular location in its index, storing a zero in the -// corresponding value slot. To override this, callers can specify `thresh`, -// indicating that if the sum has a magnitude strictly smaller than `thresh`, its -// corresponding value and index would then not be included. In particular, -// `thresh == 0` (default) means everything is kept and actual thresholding happens -// only for a positive value. +// Returns element-wise integer closest to x. // -// In the following shapes, `nnz` is the count after taking `thresh` into account. +// If the result is midway between two representable values, +// the even representable is chosen. +// For example: // -// Arguments: -// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. -// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. -// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. -// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. -// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. -// thresh: 0-D. The magnitude threshold that determines if an output value/index -// pair takes space. -func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { +// ``` +// rint(-1.5) ==> -2.0 +// rint(0.5000001) ==> 1.0 +// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] +// ``` +func Rint(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseAdd", + Type: "Rint", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. -type OrderedMapPeekAttr func(optionalAttr) +// OrderedMapUnstageNoKeyAttr is an optional argument to OrderedMapUnstageNoKey. +type OrderedMapUnstageNoKeyAttr func(optionalAttr) -// OrderedMapPeekCapacity sets the optional capacity attribute to value. +// OrderedMapUnstageNoKeyCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { +func OrderedMapUnstageNoKeyCapacity(value int64) OrderedMapUnstageNoKeyAttr { return func(m optionalAttr) { m["capacity"] = value } } -// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. +// OrderedMapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { +func OrderedMapUnstageNoKeyMemoryLimit(value int64) OrderedMapUnstageNoKeyAttr { return func(m optionalAttr) { m["memory_limit"] = value } } -// OrderedMapPeekContainer sets the optional container attribute to value. +// OrderedMapUnstageNoKeyContainer sets the optional container attribute to value. // If not specified, defaults to "" -func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { +func OrderedMapUnstageNoKeyContainer(value string) OrderedMapUnstageNoKeyAttr { return func(m optionalAttr) { m["container"] = value } } -// OrderedMapPeekSharedName sets the optional shared_name attribute to value. +// OrderedMapUnstageNoKeySharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { +func OrderedMapUnstageNoKeySharedName(value string) OrderedMapUnstageNoKeyAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op peeks at the values at the specified key. If the +// Op removes and returns the (key, value) element with the smallest // -// underlying container does not contain this key -// this op will block until it does. This Op is optimized for -// performance. -func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { +// key from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { if scope.Err() != nil { return } @@ -13119,9 +12730,9 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf. a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapPeek", + Type: "OrderedMapUnstageNoKey", Input: []tf.Input{ - key, indices, + indices, }, Attrs: attrs, } @@ -13131,122 +12742,56 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf. } var idx int var err error + key = op.Output(idx) if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapPeek", err) + scope.UpdateErr("OrderedMapUnstageNoKey", err) return } - return values -} - -// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. -type DecodeAndCropJpegAttr func(optionalAttr) - -// DecodeAndCropJpegChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeAndCropJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } + return key, values } -// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value - } -} +// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. +type MaxPool3DGradGradAttr func(optionalAttr) -// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. +// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. // -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { return func(m optionalAttr) { - m["dct_method"] = value + m["data_format"] = value } } -// Decode and Crop a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. -// -// -// It is equivalent to a combination of decode and crop, but much faster by only -// decoding partial jpeg image. +// Computes second-order gradients of the maxpooling function. // // Arguments: -// contents: 0-D. The JPEG-encoded image. -// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeAndCropJpeg", + Type: "MaxPool3DGradGrad", Input: []tf.Input{ - contents, crop_window, + orig_input, orig_output, grad, }, Attrs: attrs, } @@ -13254,263 +12799,212 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, return op.Output(0) } -// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. -type AllCandidateSamplerAttr func(optionalAttr) +// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. +type Conv3DBackpropFilterV2Attr func(optionalAttr) -// AllCandidateSamplerSeed sets the optional seed attribute to value. +// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { - m["seed"] = value + m["data_format"] = value } } -// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { return func(m optionalAttr) { - m["seed2"] = value + m["dilations"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// Computes the gradients of 3-D convolution with respect to the filter. // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to produce. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 5-D +// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` +// tensor. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AllCandidateSampler", + Type: "Conv3DBackpropFilterV2", Input: []tf.Input{ - true_classes, + input, filter_sizes, out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns the element-wise min of two SparseTensors. +// Execute a sub graph on a remote processor. // -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// The graph specifications(such as graph itself, input tensors and output names) +// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +// as serialized_remote_fused_graph_execute_info. +// The specifications will be passed to a dedicated registered +// remote fused graph executor. The executor will send the graph specifications +// to a remote processor and execute that graph. The execution results +// will be passed to consumer nodes as outputs of this node. // // Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. -// -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. -func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSparseMinimum", - Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Constructs a tensor by tiling a given tensor. +// inputs: Arbitrary number of tensors with arbitrary data types // -// This operation creates a new tensor by replicating `input` `multiples` times. -// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, -// and the values of `input` are replicated `multiples[i]` times along the 'i'th -// dimension. For example, tiling `[a b c d]` by `[2]` produces -// `[a b c d a b c d]`. +// serialized_remote_fused_graph_execute_info: Serialized protocol buffer +// of RemoteFusedGraphExecuteInfo which contains graph specifications. // -// Arguments: -// input: 1-D or higher. -// multiples: 1-D. Length must be the same as the number of dimensions in `input` -func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) { +// Returns Arbitrary number of tensors with arbitrary data types +func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} opspec := tf.OpSpec{ - Type: "Tile", + Type: "RemoteFusedGraphExecute", Input: []tf.Input{ - input, multiples, + tf.OutputList(inputs), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("RemoteFusedGraphExecute", err) + return + } + return outputs } -// Saves the input tensors to disk. +// SerializeManySparseAttr is an optional argument to SerializeManySparse. +type SerializeManySparseAttr func(optionalAttr) + +// SerializeManySparseOutType sets the optional out_type attribute to value. // -// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` -// is written to `filename` with name `tensor_names[i]`. +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. // -// See also `SaveSlices`. +// The `SparseTensor` must have rank `R` greater than 1, and the first dimension +// is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The serialized +// `SparseTensor` objects going into each row of `serialized_sparse` will have +// rank `R-1`. // -// Arguments: -// filename: Must have a single element. The name of the file to which we write -// the tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. -// data: `N` tensors to save. +// The minibatch size `N` is extracted from `sparse_shape[0]`. // -// Returns the created operation. -func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) { +// Arguments: +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Save", + Type: "SerializeManySparse", Input: []tf.Input{ - filename, tensor_names, tf.OutputList(data), + sparse_indices, sparse_values, sparse_shape, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns element-wise remainder of division. When `x < 0` xor `y < 0` is -// -// true, this follows Python semantics in that the result here is consistent -// with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. -// -// *NOTE*: `FloorMod` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func FloorMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Computes inverse hyperbolic cosine of x element-wise. +func Acosh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FloorMod", + Type: "Acosh", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap. -type TakeManySparseFromTensorsMapAttr func(optionalAttr) +// TensorArrayV2Attr is an optional argument to TensorArrayV2. +type TensorArrayV2Attr func(optionalAttr) -// TakeManySparseFromTensorsMapContainer sets the optional container attribute to value. -// -// value: The container name for the `SparseTensorsMap` read by this op. -// If not specified, defaults to "" -func TakeManySparseFromTensorsMapContainer(value string) TakeManySparseFromTensorsMapAttr { +// TensorArrayV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayV2ElementShape(value tf.Shape) TensorArrayV2Attr { return func(m optionalAttr) { - m["container"] = value + m["element_shape"] = value } } -// TakeManySparseFromTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` read by this op. -// It should not be blank; rather the `shared_name` or unique Operation name -// of the Op that created the original `SparseTensorsMap` should be used. +// TensorArrayV2DynamicSize sets the optional dynamic_size attribute to value. +// If not specified, defaults to false +func TensorArrayV2DynamicSize(value bool) TensorArrayV2Attr { + return func(m optionalAttr) { + m["dynamic_size"] = value + } +} + +// TensorArrayV2ClearAfterRead sets the optional clear_after_read attribute to value. +// If not specified, defaults to true +func TensorArrayV2ClearAfterRead(value bool) TensorArrayV2Attr { + return func(m optionalAttr) { + m["clear_after_read"] = value + } +} + +// TensorArrayV2TensorArrayName sets the optional tensor_array_name attribute to value. // If not specified, defaults to "" -func TakeManySparseFromTensorsMapSharedName(value string) TakeManySparseFromTensorsMapAttr { +func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["tensor_array_name"] = value } } -// Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. -// -// The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where -// `N` is the minibatch size and the rows correspond to the output handles of -// `AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the -// original `SparseTensor` objects that went into the given input ops must all -// match. When the final `SparseTensor` is created, it has rank one -// higher than the ranks of the incoming `SparseTensor` objects -// (they have been concatenated along a new row dimension on the left). -// -// The output `SparseTensor` object's shape values for all dimensions but the -// first are the max across the input `SparseTensor` objects' shape values -// for the corresponding dimensions. Its first shape value is `N`, the minibatch -// size. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the handles represent an input, which is a `[2, 3]` matrix -// representing two original `SparseTensor` objects: -// -// ``` -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// ``` -// -// and -// -// ``` -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// ``` -// -// then the final `SparseTensor` will be: -// -// ``` -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] -// ``` -// -// Arguments: -// sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. -// Shape: `[N]`. -// dtype: The `dtype` of the `SparseTensor` objects stored in the -// `SparseTensorsMap`. +// Deprecated. Use TensorArrayV3 // -// Returns 2-D. The `indices` of the minibatch `SparseTensor`.1-D. The `values` of the minibatch `SparseTensor`.1-D. The `shape` of the minibatch `SparseTensor`. -func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype tf.DataType, optional ...TakeManySparseFromTensorsMapAttr) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayV3 +func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) { if scope.Err() != nil { return } @@ -13519,261 +13013,262 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype a(attrs) } opspec := tf.OpSpec{ - Type: "TakeManySparseFromTensorsMap", + Type: "TensorArrayV2", Input: []tf.Input{ - sparse_handles, + size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. +// DecodeCSVAttr is an optional argument to DecodeCSV. +type DecodeCSVAttr func(optionalAttr) + +// DecodeCSVFieldDelim sets the optional field_delim attribute to value. // -// Returns Computed precision at `k` as a `bool Tensor`. -func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InTopKV2", - Input: []tf.Input{ - predictions, targets, k, - }, +// value: char delimiter to separate fields in a record. +// If not specified, defaults to "," +func DecodeCSVFieldDelim(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["field_delim"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Assigns a new value to a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to return -// this value or a subsequent newer value of the variable. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value to set the new tensor to use. +// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. // -// Returns the created operation. -func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return +// value: If false, treats double quotation marks as regular +// characters inside of the string fields (ignoring RFC 4180, Section 2, +// Bullet 5). +// If not specified, defaults to true +func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { + return func(m optionalAttr) { + m["use_quote_delim"] = value } - opspec := tf.OpSpec{ - Type: "AssignVariableOp", - Input: []tf.Input{ - resource, value, - }, +} + +// DecodeCSVNaValue sets the optional na_value attribute to value. +// +// value: Additional string to recognize as NA/NaN. +// If not specified, defaults to "" +func DecodeCSVNaValue(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["na_value"] = value } - return scope.AddOperation(opspec) } -// Returns a tensor of ones with the same shape and type as x. +// Convert CSV records to tensors. Each column maps to one tensor. +// +// RFC 4180 format is expected for the CSV records. +// (https://tools.ietf.org/html/rfc4180) +// Note that we allow leading and trailing spaces with int or float field. // // Arguments: -// x: a tensor of type T. +// records: Each string is a record/row in the csv and all records should have +// the same format. +// record_defaults: One tensor per column of the input record, with either a +// scalar default value for that column or empty if the column is required. // -// Returns a tensor of the same shape and type as x but filled with ones. -func OnesLike(scope *Scope, x tf.Output) (y tf.Output) { +// Returns Each tensor will have the same shape as records. +func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "OnesLike", + Type: "DecodeCSV", Input: []tf.Input{ - x, + records, tf.OutputList(record_defaults), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// The gradient of SparseFillEmptyRows. -// -// Takes vectors reverse_index_map, shaped `[N]`, and grad_values, -// shaped `[N_full]`, where `N_full >= N` and copies data into either -// `d_values` or `d_default_value`. Here `d_values` is shaped `[N]` and -// `d_default_value` is a scalar. -// -// d_values[j] = grad_values[reverse_index_map[j]] -// d_default_value = sum_{k : 0 .. N_full - 1} ( -// grad_values[k] * 1{k not in reverse_index_map}) -// -// Arguments: -// reverse_index_map: 1-D. The reverse index map from SparseFillEmptyRows. -// grad_values: 1-D. The gradients from backprop. -// -// Returns 1-D. The backprop into values.0-D. The backprop into default_value. -func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_values tf.Output) (d_values tf.Output, d_default_value tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "SparseFillEmptyRowsGrad", - Input: []tf.Input{ - reverse_index_map, grad_values, - }, + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("DecodeCSV", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return output } -// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` -// -// if < 0, `scale * features` otherwise. +// MapClearAttr is an optional argument to MapClear. +type MapClearAttr func(optionalAttr) + +// MapClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) -func Selu(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func MapClearCapacity(value int64) MapClearAttr { + return func(m optionalAttr) { + m["capacity"] = value } - opspec := tf.OpSpec{ - Type: "Selu", - Input: []tf.Input{ - features, - }, +} + +// MapClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapClearMemoryLimit(value int64) MapClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// SetSizeAttr is an optional argument to SetSize. -type SetSizeAttr func(optionalAttr) +// MapClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapClearContainer(value string) MapClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} -// SetSizeValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func SetSizeValidateIndices(value bool) SetSizeAttr { +// MapClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapClearSharedName(value string) MapClearAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["shared_name"] = value } } -// Number of unique elements along last dimension of input `set`. -// -// Input `set` is a `SparseTensor` represented by `set_indices`, `set_values`, -// and `set_shape`. The last dimension contains values in a set, duplicates are -// allowed but ignored. -// -// If `validate_indices` is `True`, this op validates the order and range of `set` -// indices. -// -// Arguments: -// set_indices: 2D `Tensor`, indices of a `SparseTensor`. -// set_values: 1D `Tensor`, values of a `SparseTensor`. -// set_shape: 1D `Tensor`, shape of a `SparseTensor`. +// Op removes all elements in the underlying container. // -// Returns For `set` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st -// `n-1` dimensions as `set`. Each value is the number of unique elements in -// the corresponding `[0...n-1]` dimension of `set`. -func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shape tf.Output, optional ...SetSizeAttr) (size tf.Output) { +// Returns the created operation. +func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SetSize", - Input: []tf.Input{ - set_indices, set_values, set_shape, - }, + Type: "MapClear", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Computes the sign and the log of the absolute value of the determinant of +// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. +type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) + +// ThreadUnsafeUnigramCandidateSamplerSeed sets the optional seed attribute to value. // -// one or more square matrices. +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ThreadUnsafeUnigramCandidateSamplerSeed(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// ThreadUnsafeUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions -// form square matrices. The outputs are two tensors containing the signs and -// absolute values of the log determinants for all N input submatrices -// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant). -// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU -// is the LU decomposition of the input and P is the corresponding -// permutation matrix. +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func ThreadUnsafeUnigramCandidateSamplerSeed2(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. +// +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// input: Shape is `[N, M, M]`. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). // -// Returns The signs of the log determinants of the inputs. Shape is `[N]`.The logs of the absolute values of the determinants -// of the N input matrices. Shape is `[N]`. -func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_abs_determinant tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func ThreadUnsafeUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...ThreadUnsafeUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LogMatrixDeterminant", + Type: "ThreadUnsafeUnigramCandidateSampler", Input: []tf.Input{ - input, + true_classes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// SumAttr is an optional argument to Sum. -type SumAttr func(optionalAttr) +// MaxPoolV2Attr is an optional argument to MaxPoolV2. +type MaxPoolV2Attr func(optionalAttr) -// SumKeepDims sets the optional keep_dims attribute to value. +// MaxPoolV2DataFormat sets the optional data_format attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SumKeepDims(value bool) SumAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolV2DataFormat(value string) MaxPoolV2Attr { return func(m optionalAttr) { - m["keep_dims"] = value + m["data_format"] = value } } -// Computes the sum of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// Performs max pooling on the input. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns The reduced tensor. -func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (output tf.Output) { +// Returns The max pooled output tensor. +func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Sum", + Type: "MaxPoolV2", Input: []tf.Input{ - input, axis, + input, ksize, strides, }, Attrs: attrs, } @@ -13781,317 +13276,257 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou return op.Output(0) } -// Delete the tensor specified by its handle in the session. -// -// Arguments: -// handle: The handle for a tensor stored in the session state. +// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. +type MutableDenseHashTableV2Attr func(optionalAttr) + +// MutableDenseHashTableV2Container sets the optional container attribute to value. // -// Returns the created operation. -func DeleteSessionTensor(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DeleteSessionTensor", - Input: []tf.Input{ - handle, - }, +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value } - return scope.AddOperation(opspec) } -// L2 Loss. -// -// Computes half the L2 norm of a tensor without the `sqrt`: -// -// output = sum(t ** 2) / 2 -// -// Arguments: -// t: Typically 2-D, but may have any dimensions. +// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. // -// Returns 0-D. -func L2Loss(scope *Scope, t tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "L2Loss", - Input: []tf.Input{ - t, - }, +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// DenseToSparseSetOperationAttr is an optional argument to DenseToSparseSetOperation. -type DenseToSparseSetOperationAttr func(optionalAttr) +// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} -// DenseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func DenseToSparseSetOperationValidateIndices(value bool) DenseToSparseSetOperationAttr { +// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. +// +// value: The shape of each value. +// If not specified, defaults to <> +func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { return func(m optionalAttr) { - m["validate_indices"] = value + m["value_shape"] = value } } -// Applies set operation along last dimension of `Tensor` and `SparseTensor`. +// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. // -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// value: The initial number of hash table buckets. Must be a power +// to 2. +// If not specified, defaults to 131072 +func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["initial_num_buckets"] = value + } +} + +// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. // -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. +// value: The maximum ratio between number of entries and number of +// buckets before growing the table. Must be between 0 and 1. +// If not specified, defaults to 0.8 +func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["max_load_factor"] = value + } +} + +// Creates an empty hash table that uses tensors as the backing store. // -// If `validate_indices` is `True`, this op validates the order and range of `set2` -// indices. +// It uses "open addressing" with quadratic reprobing to resolve +// collisions. // -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. // // Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as the 1st `n-1` dimensions of `set1`, `result_shape[n]` is the -// max set size across `n-1` dimensions. -// +// empty_key: The key used to represent empty key buckets internally. Must not +// be used in insert or lookup operations. +// value_dtype: Type of the table values. // -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToSparseSetOperation(scope *Scope, set1 tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...DenseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { +// Returns Handle to a table. +func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"set_operation": set_operation} + attrs := map[string]interface{}{"value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DenseToSparseSetOperation", + Type: "MutableDenseHashTableV2", Input: []tf.Input{ - set1, set2_indices, set2_values, set2_shape, + empty_key, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. -type FusedResizeAndPadConv2DAttr func(optionalAttr) +// StageSizeAttr is an optional argument to StageSize. +type StageSizeAttr func(optionalAttr) -// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. +// StageSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If true, rescale input by (new_height - 1) / (height - 1), -// which exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { +// REQUIRES: value >= 0 +func StageSizeCapacity(value int64) StageSizeAttr { return func(m optionalAttr) { - m["resize_align_corners"] = value + m["capacity"] = value } } -// Performs a resize and padding as a preprocess during a convolution. -// -// It's often possible to do spatial transformations more efficiently as part of -// the packing stage of a convolution, so this op allows for an optimized -// implementation where these stages are fused together. This prevents the need to -// write out the intermediate results as whole tensors, reducing memory pressure, -// and we can get some latency gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and defaults to -// 'NHWC' order. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. +// StageSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { +// REQUIRES: value >= 0 +func StageSizeMemoryLimit(value int64) StageSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StageSizeContainer(value string) StageSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StageSizeSharedName(value string) StageSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of elements in the underlying container. +func StageSize(scope *Scope, dtypes []tf.DataType, optional ...StageSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FusedResizeAndPadConv2D", - Input: []tf.Input{ - input, size, paddings, filter, - }, + Type: "StageSize", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Subtracts a value from the current value of a variable. -// -// Any ReadVariableOp which depends directly or indirectly on this assign is -// guaranteed to see the incremented value or a subsequent newer one. -// -// Outputs the incremented value, which can be used to totally order the -// increments to this variable. +// Produces the max pool of the input tensor for quantized types. // // Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. +// input: The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// ksize: The size of the window for each dimension of the input tensor. +// The length must be 4 to match the number of dimensions of the input. +// strides: The stride of the sliding window for each dimension of the input +// tensor. The length must be 4 to match the number of dimensions of the input. +// padding: The type of padding algorithm to use. // -// Returns the created operation. -func AssignSubVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMaxPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + opspec := tf.OpSpec{ + Type: "QuantizedMaxPool", + Input: []tf.Input{ + input, min_input, max_input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes softplus: `log(exp(features) + 1)`. +func Softplus(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AssignSubVariableOp", + Type: "Softplus", Input: []tf.Input{ - resource, value, + features, }, } - return scope.AddOperation(opspec) -} - -// RestoreAttr is an optional argument to Restore. -type RestoreAttr func(optionalAttr) - -// RestorePreferredShard sets the optional preferred_shard attribute to value. -// -// value: Index of file to open first if multiple files match -// `file_pattern`. -// If not specified, defaults to -1 -func RestorePreferredShard(value int64) RestoreAttr { - return func(m optionalAttr) { - m["preferred_shard"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Restores a tensor from checkpoint files. -// -// Reads a tensor stored in one or several files. If there are several files (for -// instance because a tensor was saved as slices), `file_pattern` may contain -// wildcard symbols (`*` and `?`) in the filename portion only, not in the -// directory portion. -// -// If a `file_pattern` matches several files, `preferred_shard` can be used to hint -// in which file the requested tensor is likely to be found. This op will first -// open the file at index `preferred_shard` in the list of matching files and try -// to restore tensors from that file. Only if some tensors or tensor slices are -// not found in that first file, then the Op opens all the files. Setting -// `preferred_shard` to match the value passed as the `shard` input -// of a matching `Save` Op may speed up Restore. This attribute only affects -// performance, not correctness. The default value -1 means files are processed in -// order. -// -// See also `RestoreSlice`. -// -// Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// dt: The type of the tensor to be restored. +// Computes exponential of x - 1 element-wise. // -// Returns The restored tensor. -func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf.DataType, optional ...RestoreAttr) (tensor tf.Output) { +// I.e., \\(y = (\exp x) - 1\\). +func Expm1(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dt": dt} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Restore", + Type: "Expm1", Input: []tf.Input{ - file_pattern, tensor_name, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedResizeBilinearAttr is an optional argument to QuantizedResizeBilinear. -type QuantizedResizeBilinearAttr func(optionalAttr) - -// QuantizedResizeBilinearAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func QuantizedResizeBilinearAlignCorners(value bool) QuantizedResizeBilinearAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize quantized `images` to `size` using quantized bilinear interpolation. +// Returns the number of records this Reader has produced. // -// Input images and output images must be quantized types. +// This is the same as the number of ReaderRead executions that have +// succeeded. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min tf.Output, max tf.Output, optional ...QuantizedResizeBilinearAttr) (resized_images tf.Output, out_min tf.Output, out_max tf.Output) { +// reader_handle: Handle to a Reader. +func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "QuantizedResizeBilinear", + Type: "ReaderNumRecordsProducedV2", Input: []tf.Input{ - images, size, min, max, + reader_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Computes the minimum along segments of a tensor. +// Computes the sum along segments of a tensor. // // Read @{$math_ops#segmentation$the section on segmentation} for an explanation of // segments. // // Computes a tensor such that -// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such +// \\(output_i = \sum_j data_j\\) where sum is over `j` such // that `segment_ids[j] == i`. // -// If the min is empty for a given segment ID `i`, `output[i] = 0`. +// If the sum is empty for a given segment ID `i`, `output[i] = 0`. // //
-// +// //
// // Arguments: @@ -14101,12 +13536,12 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min // // Returns Has same shape as data, except for dimension 0 which // has size `k`, the number of segments. -func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +func SegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentMin", + Type: "SegmentSum", Input: []tf.Input{ data, segment_ids, }, @@ -14115,180 +13550,62 @@ func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf. return op.Output(0) } -// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. -type SdcaOptimizerAttr func(optionalAttr) - -// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. -// -// value: Whether to use Adapative SDCA for the inner loop. -// If not specified, defaults to false -func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { - return func(m optionalAttr) { - m["adaptative"] = value - } -} - -// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for -// -// linear models with L1 + L2 regularization. As global optimization objective is -// strongly-convex, the optimizer optimizes the dual objective at each step. The -// optimizer applies each update one example at a time. Examples are sampled -// uniformly, and the optimizer is learning rate free and enjoys linear convergence -// rate. -// -// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
-// Shai Shalev-Shwartz, Tong Zhang. 2012 -// -// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ -// -// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
-// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, -// Peter Richtarik, Martin Takac. 2015 -// -// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
-// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 +// Creates a dataset that emits the lines of one or more text files. // // Arguments: -// sparse_example_indices: a list of vectors which contain example indices. -// sparse_feature_indices: a list of vectors which contain feature indices. -// sparse_feature_values: a list of vectors which contains feature value -// associated with each feature group. -// dense_features: a list of matrices which contains the dense feature values. -// example_weights: a vector which contains the weight associated with each -// example. -// example_labels: a vector which contains the label/target associated with each -// example. -// sparse_indices: a list of vectors where each value is the indices which has -// corresponding weights in sparse_weights. This field maybe omitted for the -// dense approach. -// sparse_weights: a list of vectors where each value is the weight associated with -// a sparse feature group. -// dense_weights: a list of vectors where the values are the weights associated -// with a dense feature group. -// example_state_data: a list of vectors containing the example state data. -// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, -// squared and hinge losses. -// l1: Symmetric l1 regularization strength. -// l2: Symmetric l2 regularization strength. -// num_loss_partitions: Number of partitions of the global loss function. -// num_inner_iterations: Number of iterations per mini-batch. -// -// Returns a list of vectors containing the updated example state -// data.a list of vectors where each value is the delta -// weights associated with a sparse feature group.a list of vectors where the values are the delta -// weights associated with a dense feature group. -func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// buffer_size: A scalar containing the number of bytes to buffer. +func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SdcaOptimizer", + Type: "TextLineDataset", Input: []tf.Input{ - tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, + filenames, compression_type, buffer_size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - out_example_state_data = op.Output(idx) - if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights -} - -// SparseMatMulAttr is an optional argument to SparseMatMul. -type SparseMatMulAttr func(optionalAttr) - -// SparseMatMulTransposeA sets the optional transpose_a attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeA(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// SparseMatMulTransposeB sets the optional transpose_b attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeB(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["a_is_sparse"] = value - } -} - -// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["b_is_sparse"] = value - } + return op.Output(0) } -// Multiply matrix "a" by matrix "b". +// Computes gradients for SparseSegmentMean. // -// The inputs must be two-dimensional matrices and the inner dimension of "a" must -// match the outer dimension of "b". This op is optimized for the case where at -// least one of "a" or "b" is sparse. The breakeven for using this versus a dense -// matrix multiply on one platform was 30% zero values in the sparse matrix. +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. // -// The gradient computation of this operation will only take advantage of sparsity -// in the input gradient when that gradient comes from a Relu. -func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { +// Arguments: +// grad: gradient propagated to the SparseSegmentMean op. +// indices: indices passed to the corresponding SparseSegmentMean op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. +func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SparseMatMul", + Type: "SparseSegmentMeanGrad", Input: []tf.Input{ - a, b, + grad, indices, segment_ids, output_dim0, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the power of one value to another. -// -// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for -// corresponding elements in `x` and `y`. For example: +// Returns the truth value of (x >= y) element-wise. // -// ``` -// # tensor 'x' is [[2, 2]], [3, 3]] -// # tensor 'y' is [[8, 16], [2, 3]] -// tf.pow(x, y) ==> [[256, 65536], [9, 27]] -// ``` -func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func GreaterEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Pow", + Type: "GreaterEqual", Input: []tf.Input{ x, y, }, @@ -14297,39 +13614,64 @@ func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// ShapeAttr is an optional argument to Shape. -type ShapeAttr func(optionalAttr) +// Conv3DAttr is an optional argument to Conv3D. +type Conv3DAttr func(optionalAttr) -// ShapeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeOutType(value tf.DataType) ShapeAttr { +// Conv3DDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DDataFormat(value string) Conv3DAttr { return func(m optionalAttr) { - m["out_type"] = value + m["data_format"] = value } } -// Returns the shape of a tensor. +// Conv3DDilations sets the optional dilations attribute to value. // -// This operation returns a 1-D integer tensor representing the shape of `input`. +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DDilations(value []int64) Conv3DAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 3-D convolution given 5-D `input` and `filter` tensors. // -// For example: +// In signal processing, cross-correlation is a measure of similarity of +// two waveforms as a function of a time-lag applied to one of them. This +// is also known as a sliding dot product or sliding inner-product. // -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] -// ``` -func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { +// Our Conv3D implements a form of cross-correlation. +// +// Arguments: +// input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. +// filter: Shape `[filter_depth, filter_height, filter_width, in_channels, +// out_channels]`. `in_channels` must match between `input` and `filter`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv3DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Shape", + Type: "Conv3D", Input: []tf.Input{ - input, + input, filter, }, Attrs: attrs, } @@ -14337,93 +13679,57 @@ func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Outp return op.Output(0) } -// Computes fingerprints of the input strings. +// Adds up a SparseTensor and a dense Tensor, using these special rules: +// +// (1) Broadcasts the dense side to have the same shape as the sparse side, if +// eligible; +// (2) Then, only the dense values pointed to by the indices of the SparseTensor +// participate in the cwise addition. +// +// By these rules, the result is a logical SparseTensor with exactly the same +// indices and shape, but possibly with different non-zero values. The output of +// this Op is the resultant non-zero values. // // Arguments: -// input: vector of strings to compute fingerprints on. +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. // -// Returns a (N,2) shaped matrix where N is the number of elements in the input -// vector. Each row contains the low and high parts of the fingerprint. -func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) { +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SdcaFprint", + Type: "SparseDenseCwiseAdd", Input: []tf.Input{ - input, + sp_indices, sp_values, sp_shape, dense, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// RandomPoissonV2Attr is an optional argument to RandomPoissonV2. -type RandomPoissonV2Attr func(optionalAttr) - -// RandomPoissonV2Seed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomPoissonV2Seed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// RandomPoissonV2Dtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_INT64 -func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs random values from the Poisson distribution(s) described by rate. -// -// This op uses two algorithms, depending on rate. If rate >= 10, then -// the algorithm by Hormann is used to acquire samples via -// transformation-rejection. -// See http://www.sciencedirect.com/science/article/pii/0167668793909974. -// -// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform -// random variables. -// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer -// Programming, Volume 2. Addison Wesley +// Read an element from the TensorArray into output `value`. // // Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in rate. -// rate: A tensor in which each scalar is a "rate" parameter describing the -// associated poisson distribution. +// handle: The handle to a TensorArray. // -// Returns A tensor with shape `shape + shape(rate)`. Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `rate[i0, i1, ...iN]`. -func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) { +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns The tensor that is read from the TensorArray. +func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "RandomPoissonV2", + Type: "TensorArrayReadV3", Input: []tf.Input{ - shape, rate, + handle, index, flow_in, }, Attrs: attrs, } @@ -14431,333 +13737,422 @@ func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ... return op.Output(0) } -// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve. -type MatrixTriangularSolveAttr func(optionalAttr) +// QuantizeV2Attr is an optional argument to QuantizeV2. +type QuantizeV2Attr func(optionalAttr) -// MatrixTriangularSolveLower sets the optional lower attribute to value. +// QuantizeV2Mode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func QuantizeV2Mode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// QuantizeV2RoundMode sets the optional round_mode attribute to value. +// If not specified, defaults to "HALF_AWAY_FROM_ZERO" +func QuantizeV2RoundMode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["round_mode"] = value + } +} + +// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. The +// 'round_mode' attribute controls which rounding tie-breaking algorithm is used +// when rounding float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) +// if T == qint8, out[i] -= (range(T) + 1) / 2.0 +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// Assume the input is type float and has a possible range of [0.0, 6.0] and the +// output type is quint8 ([0, 255]). The min_range and max_range values should be +// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each +// value of the input by 255/6 and cast to quint8. +// +// If the output type was qint8 ([-128, 127]), the operation will additionally +// subtract each value by 128 prior to casting, so that the range of values aligns +// with the range of qint8. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ``` +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = num_discrete_values / range +// quantized = round(input * range_scale) - round(range_min * range_scale) + +// numeric_limits::min() +// quantized = max(quantized, numeric_limits::min()) +// quantized = min(quantized, numeric_limits::max()) +// ``` +// +// The biggest difference between this and MIN_COMBINED is that the minimum range +// is rounded first, before it's subtracted from the rounded value. With +// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing +// and dequantizing will introduce a larger and larger error. +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. // -// value: Boolean indicating whether the innermost matrices in `matrix` are -// lower or upper triangular. -// If not specified, defaults to true -func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr { - return func(m optionalAttr) { - m["lower"] = value - } -} - -// MatrixTriangularSolveAdjoint sets the optional adjoint attribute to value. +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` // -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. +// Our input tensor range is then `[-m, m]`. // -// @compatibility(numpy) -// Equivalent to np.linalg.triangular_solve -// @end_compatibility -// If not specified, defaults to false -func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Solves systems of linear equations with upper or lower triangular matrices by +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` // -// backsubstitution. +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` // -// `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form -// square matrices. If `lower` is `True` then the strictly upper triangular part -// of each inner-most matrix is assumed to be zero and not accessed. -// If `lower` is False then the strictly lower triangular part of each inner-most -// matrix is assumed to be zero and not accessed. -// `rhs` is a tensor of shape `[..., M, K]`. +// From this we compute our scaling factor, s: +// ```c++ +// s = (max_fixed - min_fixed) / (2 * m) +// ``` // -// The output is a tensor of shape `[..., M, K]`. If `adjoint` is -// `True` then the innermost matrices in `output` satisfy matrix equations -// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `False` then the strictly then the innermost matrices in -// `output` satisfy matrix equations -// `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. +// Now we can quantize the elements of our tensor: +// ```c++ +// result = round(input * s) +// ``` +// +// One thing to watch out for is that the operator may choose to adjust the +// requested minimum and maximum values slightly during the quantization process, +// so you should always use the output ports as the range for further calculations. +// For example, if the requested minimum and maximum values are close to equal, +// they will be separated by a small epsilon value to prevent ill-formed quantized +// buffers from being created. Otherwise, you can end up with buffers where all the +// quantized values map to the same float value, which causes problems for +// operations that have to perform further calculations on them. // // Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. // -// Returns Shape is `[..., M, K]`. -func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixTriangularSolveAttr) (output tf.Output) { +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +// +// +// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. +func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"T": T} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MatrixTriangularSolve", + Type: "QuantizeV2", Input: []tf.Input{ - matrix, rhs, + input, min_range, max_range, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes inverse hyperbolic sine of x element-wise. -func Asinh(scope *Scope, x tf.Output) (y tf.Output) { +// Returns the truth value of (x < y) element-wise. +// +// *NOTE*: `Less` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Asinh", + Type: "Less", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset with a range of values. Corresponds to python's xrange. +// QuantizedReluXAttr is an optional argument to QuantizedReluX. +type QuantizedReluXAttr func(optionalAttr) + +// QuantizedReluXOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedReluXOutType(value tf.DataType) QuantizedReluXAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` // // Arguments: -// start: corresponds to start in python's xrange(). -// stop: corresponds to stop in python's xrange(). -// step: corresponds to step in python's xrange(). // // -func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluXAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RangeDataset", + Type: "QuantizedReluX", Input: []tf.Input{ - start, stop, step, + features, max_value, min_features, max_features, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput. -type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr) +// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. +type WholeFileReaderV2Attr func(optionalAttr) -// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value. +// WholeFileReaderV2Container sets the optional container attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr { +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["container"] = value } } -// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value. +// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["shared_name"] = value } } -// Computes the gradients of depthwise convolution with respect to the input. +// A Reader that outputs the entire contents of a file as a value. // -// Arguments: -// input_sizes: An integer vector representing the shape of `input`, based -// on `data_format`. For example, if `data_format` is 'NHWC' then -// `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, depthwise_multiplier]`. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. +// To use, enqueue filenames in a Queue. The output of ReaderRead will +// be a filename (key) and the contents of that file (value). // -// Returns 4-D with shape according to `data_format`. For example, if -// `data_format` is 'NHWC', output shape is `[batch, in_height, -// in_width, in_channels]`. Gradient w.r.t. the input of the -// convolution. -func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) { +// Returns The handle to reference the Reader. +func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropInput", - Input: []tf.Input{ - input_sizes, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds sparse updates to the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] += updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] += updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions add. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterAdd", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} + Type: "WholeFileReaderV2", -// Computes the gradient for the inverse of `x` wrt its input. -// -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReciprocalGrad", - Input: []tf.Input{ - y, dy, - }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the min of x and y (i.e. x < y ? x : y) element-wise. +// Transforms a tf.Example proto (as a string) into typed tensors. // -// *NOTE*: `Minimum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// serialized: A vector containing a batch of binary serialized Example protos. +// dense_defaults: A list of Tensors (some may be empty), whose length matches +// the length of `dense_keys`. dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// num_sparse: The number of sparse features to be parsed from the example. This +// must match the lengths of `sparse_keys` and `sparse_types`. +// sparse_keys: A list of `num_sparse` strings. +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: The keys expected in the Examples' features associated with dense +// values. +// sparse_types: A list of `num_sparse` types; the data types of data in each +// Feature given in sparse_keys. +// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: The shapes of data in each Feature given in dense_keys. +// The length of this list must match the length of `dense_keys`. The +// number of elements in the Feature corresponding to dense_key[j] must +// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == +// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] +// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, +// ..., DN), the shape of the output Tensor dense_values[j] will be (M, +// D1, .., DN), where M is the number of blocks of elements of length +// D1 * .... * DN, in the input. +func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} opspec := tf.OpSpec{ - Type: "Minimum", + Type: "ParseSingleExample", Input: []tf.Input{ - x, y, + serialized, tf.OutputList(dense_defaults), }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values } -// MfccAttr is an optional argument to Mfcc. -type MfccAttr func(optionalAttr) +// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. +type QuantizedConv2DAttr func(optionalAttr) -// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. -// -// value: The highest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 4000 -func MfccUpperFrequencyLimit(value float32) MfccAttr { +// QuantizedConv2DOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { return func(m optionalAttr) { - m["upper_frequency_limit"] = value + m["out_type"] = value } } -// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// QuantizedConv2DDilations sets the optional dilations attribute to value. // -// value: The lowest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 20 -func MfccLowerFrequencyLimit(value float32) MfccAttr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { return func(m optionalAttr) { - m["lower_frequency_limit"] = value + m["dilations"] = value } } -// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. +// Computes a 2D convolution given quantized 4D input and filter tensors. // -// value: Resolution of the Mel bank used internally. -// If not specified, defaults to 40 -func MfccFilterbankChannelCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["filterbank_channel_count"] = value +// The inputs are quantized tensors where the lowest value represents the real +// number of the associated minimum, and the highest represents the maximum. +// This means that you can only interpret the quantized output in the same way, by +// taking the returned minimum and maximum values into account. +// +// Arguments: +// +// filter: filter's input_depth dimension must match input's depth dimensions. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// min_filter: The float value that the lowest quantized filter value represents. +// max_filter: The float value that the highest quantized filter value represents. +// strides: The stride of the sliding window for each dimension of the input +// tensor. +// padding: The type of padding algorithm to use. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) } + opspec := tf.OpSpec{ + Type: "QuantizedConv2D", + Input: []tf.Input{ + input, filter, min_input, max_input, min_filter, max_filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. -// -// value: How many output channels to produce per time slice. -// If not specified, defaults to 13 -func MfccDctCoefficientCount(value int64) MfccAttr { +// ResourceGatherAttr is an optional argument to ResourceGather. +type ResourceGatherAttr func(optionalAttr) + +// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { return func(m optionalAttr) { - m["dct_coefficient_count"] = value + m["validate_indices"] = value } } -// Transforms a spectrogram into a form that's useful for speech recognition. +// Gather slices from the variable pointed to by `resource` according to `indices`. // -// Mel Frequency Cepstral Coefficients are a way of representing audio data that's -// been effective as an input feature for machine learning. They are created by -// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the -// higher frequencies that are less significant to the human ear. They have a long -// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum -// is a good resource to learn more. +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: // -// Arguments: -// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared -// set to true. -// sample_rate: How many samples per second the source audio used. -func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { +// ```python +// # Scalar indices +// output[:, ..., :] = params[indices, :, ... :] +// +// # Vector indices +// output[i, :, ..., :] = params[indices[i], :, ... :] +// +// # Higher rank indices +// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] +// ``` +func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype tf.DataType, optional ...ResourceGatherAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Mfcc", + Type: "ResourceGather", Input: []tf.Input{ - spectrogram, sample_rate, + resource, indices, }, Attrs: attrs, } @@ -14765,470 +14160,522 @@ func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional . return op.Output(0) } -// Returns the element-wise sum of a list of tensors. -// -// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not -// wait for all of its inputs to be ready before beginning to sum. This can -// save memory if inputs are ready at different times, since minimum temporary -// storage is proportional to the output size rather than the inputs size. -// -// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. +// Delete the TensorArray from its resource container. // -// Returns a `Tensor` of same shape and type as the elements of `inputs`. +// This enables the user to close and release the resource in the middle +// of a step/run. // // Arguments: -// inputs: A list of `Tensor` objects, each with same shape and type. -// shape: Shape of elements of `inputs`. -func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// +// Returns the created operation. +func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "AccumulateNV2", + Type: "TensorArrayCloseV3", Input: []tf.Input{ - tf.OutputList(inputs), + handle, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Convert the quantized 'input' tensor into a lower-precision 'output', using the -// -// actual distribution of the values to maximize the usage of the lower bit depth -// and adjusting the output min and max ranges accordingly. +// Adds two `SparseTensor` objects to produce another `SparseTensor`. // -// [input_min, input_max] are scalar floats that specify the range for the float -// interpretation of the 'input' data. For example, if input_min is -1.0f and -// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 -// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. +// The input `SparseTensor` objects' indices are assumed ordered in standard +// lexicographic order. If this is not the case, before this step run +// `SparseReorder` to restore index ordering. // -// This operator tries to squeeze as much precision as possible into an output with -// a lower bit depth by calculating the actual min and max values found in the -// data. For example, maybe that quint16 input has no values lower than 16,384 and -// none higher than 49,152. That means only half the range is actually needed, all -// the float interpretations are between -0.5f and 0.5f, so if we want to compress -// the data into a quint8 output, we can use that range rather than the theoretical -// -1.0f to 1.0f that is suggested by the input min and max. +// By default, if two values sum to zero at some index, the output `SparseTensor` +// would still include that particular location in its index, storing a zero in the +// corresponding value slot. To override this, callers can specify `thresh`, +// indicating that if the sum has a magnitude strictly smaller than `thresh`, its +// corresponding value and index would then not be included. In particular, +// `thresh == 0` (default) means everything is kept and actual thresholding happens +// only for a positive value. // -// In practice, this is most useful for taking output from operations like -// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and -// may have large potential output ranges, but in practice have a distribution of -// input values that only uses a small fraction of the possible range. By feeding -// that output into this operator, we can reduce it from 32 bits down to 8 with -// minimal loss of accuracy. +// In the following shapes, `nnz` is the count after taking `thresh` into account. // // Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// out_type: The type of the output. Should be a lower bit depth than Tinput. -// -// Returns The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. +// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. +// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. +// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. +// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. +// thresh: 0-D. The magnitude threshold that determines if an output value/index +// pair takes space. +func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "QuantizeDownAndShrinkRange", + Type: "SparseAdd", Input: []tf.Input{ - input, input_min, input_max, + a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0), op.Output(1), op.Output(2) } -// RandomGammaAttr is an optional argument to RandomGamma. -type RandomGammaAttr func(optionalAttr) +// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. +type OrderedMapPeekAttr func(optionalAttr) -// RandomGammaSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// OrderedMapPeekCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 -func RandomGammaSeed(value int64) RandomGammaAttr { +// +// REQUIRES: value >= 0 +func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { return func(m optionalAttr) { - m["seed"] = value + m["capacity"] = value } } -// RandomGammaSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. +// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 -func RandomGammaSeed2(value int64) RandomGammaAttr { +// +// REQUIRES: value >= 0 +func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { return func(m optionalAttr) { - m["seed2"] = value + m["memory_limit"] = value } } -// Outputs random values from the Gamma distribution(s) described by alpha. -// -// This op uses the algorithm by Marsaglia et al. to acquire samples via -// transformation-rejection from pairs of uniform and normal random variables. -// See http://dl.acm.org/citation.cfm?id=358414 -// -// Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in alpha. -// alpha: A tensor in which each scalar is a "shape" parameter describing the -// associated gamma distribution. +// OrderedMapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the // -// Returns A tensor with shape `shape + shape(alpha)`. Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. -func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { +// underlying container does not contain this key +// this op will block until it does. This Op is optimized for +// performance. +func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomGamma", + Type: "OrderedMapPeek", Input: []tf.Input{ - shape, alpha, + key, indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapPeek", err) + return + } + return values } -// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. -type QuantizedConv2DAttr func(optionalAttr) +// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. +type DecodeAndCropJpegAttr func(optionalAttr) -// QuantizedConv2DOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { +// DecodeAndCropJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["out_type"] = value + m["channels"] = value } } -// QuantizedConv2DDilations sets the optional dilations attribute to value. +// DecodeAndCropJpegRatio sets the optional ratio attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { return func(m optionalAttr) { - m["dilations"] = value + m["ratio"] = value } } -// Computes a 2D convolution given quantized 4D input and filter tensors. +// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. // -// The inputs are quantized tensors where the lowest value represents the real -// number of the associated minimum, and the highest represents the maximum. -// This means that you can only interpret the quantized output in the same way, by -// taking the returned minimum and maximum values into account. +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. // -// Arguments: +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. // -// filter: filter's input_depth dimension must match input's depth dimensions. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// min_filter: The float value that the lowest quantized filter value represents. -// max_filter: The float value that the highest quantized filter value represents. -// strides: The stride of the sliding window for each dimension of the input -// tensor. -// padding: The type of padding algorithm to use. +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode and Crop a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// It is equivalent to a combination of decode and crop, but much faster by only +// decoding partial jpeg image. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedConv2D", + Type: "DecodeAndCropJpeg", Input: []tf.Input{ - input, filter, min_input, max_input, min_filter, max_filter, + contents, crop_window, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// ResourceGatherAttr is an optional argument to ResourceGather. -type ResourceGatherAttr func(optionalAttr) +// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. +type AllCandidateSamplerAttr func(optionalAttr) -// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { +// AllCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { return func(m optionalAttr) { - m["validate_indices"] = value + m["seed"] = value } } -// Gather slices from the variable pointed to by `resource` according to `indices`. +// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. // -// ```python -// # Scalar indices -// output[:, ..., :] = params[indices, :, ... :] +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // -// # Vector indices -// output[i, :, ..., :] = params[indices[i], :, ... :] +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to produce. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. // -// # Higher rank indices -// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] -// ``` -func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype tf.DataType, optional ...ResourceGatherAttr) (output tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceGather", + Type: "AllCandidateSampler", Input: []tf.Input{ - resource, indices, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Delete the TensorArray from its resource container. +// Saves the input tensors to disk. // -// This enables the user to close and release the resource in the middle -// of a step/run. +// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` +// is written to `filename` with name `tensor_names[i]`. +// +// See also `SaveSlices`. // // Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// filename: Must have a single element. The name of the file to which we write +// the tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// data: `N` tensors to save. // // Returns the created operation. -func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { +func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayCloseV3", + Type: "Save", Input: []tf.Input{ - handle, + filename, tensor_names, tf.OutputList(data), }, } return scope.AddOperation(opspec) } -// RandomUniformIntAttr is an optional argument to RandomUniformInt. -type RandomUniformIntAttr func(optionalAttr) - -// RandomUniformIntSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformIntSeed(value int64) RandomUniformIntAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomUniformIntSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random integers from a uniform distribution. -// -// The generated values are uniform integers in the range `[minval, maxval)`. -// The lower bound `minval` is included in the range, while the upper bound -// `maxval` is excluded. -// -// The random integers are slightly biased unless `maxval - minval` is an exact -// power of two. The bias is small for values of `maxval - minval` significantly -// smaller than the range of the output (either `2^32` or `2^64`). +// Returns element-wise remainder of division. When `x < 0` xor `y < 0` is // -// Arguments: -// shape: The shape of the output tensor. -// minval: 0-D. Inclusive lower bound on the generated integers. -// maxval: 0-D. Exclusive upper bound on the generated integers. +// true, this follows Python semantics in that the result here is consistent +// with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. // -// Returns A tensor of the specified shape filled with uniform random integers. -func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { +// *NOTE*: `FloorMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func FloorMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomUniformInt", + Type: "FloorMod", Input: []tf.Input{ - shape, minval, maxval, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SkipgramAttr is an optional argument to Skipgram. -type SkipgramAttr func(optionalAttr) +// SparseTensorDenseMatMulAttr is an optional argument to SparseTensorDenseMatMul. +type SparseTensorDenseMatMulAttr func(optionalAttr) -// SkipgramWindowSize sets the optional window_size attribute to value. +// SparseTensorDenseMatMulAdjointA sets the optional adjoint_a attribute to value. // -// value: The number of words to predict to the left and right of the target. -// If not specified, defaults to 5 -func SkipgramWindowSize(value int64) SkipgramAttr { +// value: Use the adjoint of A in the matrix multiply. If A is complex, this +// is transpose(conj(A)). Otherwise it's transpose(A). +// If not specified, defaults to false +func SparseTensorDenseMatMulAdjointA(value bool) SparseTensorDenseMatMulAttr { return func(m optionalAttr) { - m["window_size"] = value + m["adjoint_a"] = value } } -// SkipgramMinCount sets the optional min_count attribute to value. +// SparseTensorDenseMatMulAdjointB sets the optional adjoint_b attribute to value. // -// value: The minimum number of word occurrences for it to be included in the -// vocabulary. -// If not specified, defaults to 5 -func SkipgramMinCount(value int64) SkipgramAttr { +// value: Use the adjoint of B in the matrix multiply. If B is complex, this +// is transpose(conj(B)). Otherwise it's transpose(B). +// If not specified, defaults to false +func SparseTensorDenseMatMulAdjointB(value bool) SparseTensorDenseMatMulAttr { return func(m optionalAttr) { - m["min_count"] = value + m["adjoint_b"] = value } } -// SkipgramSubsample sets the optional subsample attribute to value. +// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". // -// value: Threshold for word occurrence. Words that appear with higher -// frequency will be randomly down-sampled. Set to 0 to disable. -// If not specified, defaults to 0.001 -func SkipgramSubsample(value float32) SkipgramAttr { - return func(m optionalAttr) { - m["subsample"] = value - } -} - -// Parses a text file and creates a batch of examples. +// No validity checking is performed on the indices of A. However, the following +// input format is recommended for optimal behavior: // -// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result +// if adjoint_a == false: +// A should be sorted in lexicographically increasing order. Use SparseReorder +// if you're not sure. +// if adjoint_a == true: +// A should be sorted in order of increasing dimension 1 (i.e., "column major" +// order instead of "row major" order). // // Arguments: -// filename: The corpus's text file name. -// batch_size: The size of produced batch. -// -// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids. -func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) { +// a_indices: 2-D. The `indices` of the `SparseTensor`, size `[nnz, 2]` Matrix. +// a_values: 1-D. The `values` of the `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the `SparseTensor`, size `[2]` Vector. +// b: 2-D. A dense Matrix. +func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output, optional ...SparseTensorDenseMatMulAttr) (product tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Skipgram", - + Type: "SparseTensorDenseMatMul", + Input: []tf.Input{ + a_indices, a_values, a_shape, b, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) + return op.Output(0) } -// StringToNumberAttr is an optional argument to StringToNumber. -type StringToNumberAttr func(optionalAttr) - -// StringToNumberOutType sets the optional out_type attribute to value. +// Deserialize and concatenate `SparseTensors` from a serialized minibatch. // -// value: The numeric type to interpret each string in `string_tensor` as. -// If not specified, defaults to DT_FLOAT -func StringToNumberOutType(value tf.DataType) StringToNumberAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Converts each string in the input Tensor to the specified numeric type. +// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where +// `N` is the minibatch size and the rows correspond to packed outputs of +// `SerializeSparse`. The ranks of the original `SparseTensor` objects +// must all match. When the final `SparseTensor` is created, it has rank one +// higher than the ranks of the incoming `SparseTensor` objects +// (they have been concatenated along a new row dimension). // -// (Note that int32 overflow results in an error while float overflow -// results in a rounded value.) +// The output `SparseTensor` object's shape values for all dimensions but the +// first are the max across the input `SparseTensor` objects' shape values +// for the corresponding dimensions. Its first shape value is `N`, the minibatch +// size. // -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) { +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. +// +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: +// +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects. +// Must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "StringToNumber", + Type: "DeserializeManySparse", Input: []tf.Input{ - string_tensor, + serialized_sparse, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2. -type ResourceApplyFtrlV2Attr func(optionalAttr) +// StringJoinAttr is an optional argument to StringJoin. +type StringJoinAttr func(optionalAttr) -// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// StringJoinSeparator sets the optional separator attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr { +// value: string, an optional join separator. +// If not specified, defaults to "" +func StringJoinSeparator(value string) StringJoinAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["separator"] = value } } -// Update '*var' according to the Ftrl-proximal scheme. +// Joins the strings in the given list of string tensors into one tensor; // -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// with the given separator (default is an empty separator). // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. -// -// lr_power: Scaling factor. Must be a scalar. -// -// Returns the created operation. -func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) { +// inputs: A list of string tensors. The tensors must all have the same shape, +// or be scalars. Scalars may be mixed in; these will be broadcast to the shape +// of non-scalar inputs. +func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -15237,325 +14684,395 @@ func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear t a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyFtrlV2", + Type: "StringJoin", Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power, + tf.OutputList(inputs), }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// TruncatedNormalAttr is an optional argument to TruncatedNormal. -type TruncatedNormalAttr func(optionalAttr) - -// TruncatedNormalSeed sets the optional seed attribute to value. +// Returns immutable tensor from memory region. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func TruncatedNormalSeed(value int64) TruncatedNormalAttr { - return func(m optionalAttr) { - m["seed"] = value +// The current implementation memmaps the tensor from a file. +// +// Arguments: +// dtype: Type of the returned tensor. +// shape: Shape of the returned tensor. +// memory_region_name: Name of readonly memory region used by the tensor, see +// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. +func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { + if scope.Err() != nil { + return } -} + attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} + opspec := tf.OpSpec{ + Type: "ImmutableConst", -// TruncatedNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func TruncatedNormalSeed2(value int64) TruncatedNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Outputs random values from a truncated normal distribution. +// Inverse real-valued fast Fourier transform. // -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. +// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most dimension of `input`. +// +// The inner-most dimension of `input` is assumed to be the result of `RFFT`: the +// `fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If +// `fft_length` is not provided, it is computed from the size of the inner-most +// dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to +// compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller +// than the corresponding dimension of `input`, the dimension is cropped. If it is +// larger, the dimension is padded with zeros. // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [1]. The FFT length. // -// Returns A tensor of the specified shape filled with random truncated normal -// values. -func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) { +// Returns A float32 tensor of the same rank as `input`. The inner-most +// dimension of `input` is replaced with the `fft_length` samples of its inverse +// 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.irfft +// @end_compatibility +func IRFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TruncatedNormal", + Type: "IRFFT", Input: []tf.Input{ - shape, + input, fft_length, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// RandomShuffleAttr is an optional argument to RandomShuffle. -type RandomShuffleAttr func(optionalAttr) - -// RandomShuffleSeed sets the optional seed attribute to value. +// Concatenates a list of `SparseTensor` along the specified dimension. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomShuffleSeed(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleSeed2 sets the optional seed2 attribute to value. +// Concatenation is with respect to the dense versions of these sparse tensors. +// It is assumed that each input is a `SparseTensor` whose elements are ordered +// along increasing dimension number. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleSeed2(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed2"] = value +// All inputs' shapes must match, except for the concat dimension. The +// `indices`, `values`, and `shapes` lists must have the same length. +// +// The output shape is identical to the inputs', except along the concat +// dimension, where it is the sum of the inputs' sizes along that dimension. +// +// The output elements will be resorted to preserve the sort order along +// increasing dimension number. +// +// This op runs in `O(M log M)` time, where `M` is the total number of non-empty +// values across all inputs. This is due to the need for an internal sort in +// order to concatenate efficiently across an arbitrary dimension. +// +// For example, if `concat_dim = 1` and the inputs are +// +// sp_inputs[0]: shape = [2, 3] +// [0, 2]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// sp_inputs[1]: shape = [2, 4] +// [0, 1]: "d" +// [0, 2]: "e" +// +// then the output will be +// +// shape = [2, 7] +// [0, 2]: "a" +// [0, 4]: "d" +// [0, 5]: "e" +// [1, 0]: "b" +// [1, 1]: "c" +// +// Graphically this is equivalent to doing +// +// [ a] concat [ d e ] = [ a d e ] +// [b c ] [ ] [b c ] +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. Non-empty values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// concat_dim: Dimension to concatenate along. Must be in range [-rank, rank), +// where rank is the number of dimensions in each input `SparseTensor`. +// +// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. +func SparseConcat(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, concat_dim int64) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"concat_dim": concat_dim} + opspec := tf.OpSpec{ + Type: "SparseConcat", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// Randomly shuffles a tensor along its first dimension. +// Generates sparse cross from a list of sparse and dense tensors. // -// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped -// to one and only one `output[i]`. For example, a mapping that might occur for a -// 3x2 tensor is: +// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +// representing features of one feature column. It outputs a 2D `SparseTensor` with +// the batchwise crosses of these features. // -// ``` -// [[1, 2], [[5, 6], -// [3, 4], ==> [1, 2], -// [5, 6]] [3, 4]] -// ``` +// For example, if the inputs are +// +// inputs[0]: SparseTensor with shape = [2, 2] +// [0, 0]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// inputs[1]: SparseTensor with shape = [2, 1] +// [0, 0]: "d" +// [1, 0]: "e" +// +// inputs[2]: Tensor [["f"], ["g"]] +// +// then the output will be +// +// shape = [2, 2] +// [0, 0]: "a_X_d_X_f" +// [1, 0]: "b_X_e_X_g" +// [1, 1]: "c_X_e_X_g" +// +// if hashed_output=true then the output will be +// +// shape = [2, 2] +// [0, 0]: FingerprintCat64( +// Fingerprint64("f"), FingerprintCat64( +// Fingerprint64("d"), Fingerprint64("a"))) +// [1, 0]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("b"))) +// [1, 1]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("c"))) // // Arguments: -// value: The tensor to be shuffled. +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// dense_inputs: 2-D. Columns represented by dense `Tensor`. +// hashed_output: If true, returns the hash of the cross instead of the string. +// This will allow us avoiding string manipulations. +// num_buckets: It is used if hashed_output is true. +// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. +// hash_key: Specify the hash_key that will be used by the `FingerprintCat64` +// function to combine the crosses fingerprints. // -// Returns A tensor of same shape and type as `value`, shuffled along its first -// dimension. -func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { +// +// +// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated or hashed +// `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. +func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, hashed_output bool, num_buckets int64, hash_key int64, out_type tf.DataType, internal_type tf.DataType) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_type": out_type, "internal_type": internal_type} opspec := tf.OpSpec{ - Type: "RandomShuffle", + Type: "SparseCross", Input: []tf.Input{ - value, + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize. -type OrderedMapIncompleteSizeAttr func(optionalAttr) - -// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Concatenates quantized tensors along one dimension. // -// REQUIRES: value >= 0 -func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// input_mins: The minimum scalar values for each of the input tensors. +// input_maxes: The maximum scalar values for each of the input tensors. // -// REQUIRES: value >= 0 -func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapIncompleteSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["container"] = value +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return } -} - -// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "QuantizedConcat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// Op returns the number of incomplete elements in the underlying container. -func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) { +// Slice a `SparseTensor` based on the `start` and `size`. +// +// For example, if the input is +// +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] +// +// Graphically the output tensors are: +// +// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] +// [ a ] +// [b c ] +// +// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] +// [ d e ] +// [ ] +// +// Arguments: +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// start: 1-D. tensor represents the start of the slice. +// size: 1-D. tensor represents the size of the slice. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "OrderedMapIncompleteSize", - - Attrs: attrs, + Type: "SparseSlice", + Input: []tf.Input{ + indices, values, shape, start, size, + }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// DecodeRawAttr is an optional argument to DecodeRaw. -type DecodeRawAttr func(optionalAttr) - -// DecodeRawLittleEndian sets the optional little_endian attribute to value. +// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. // -// value: Whether the input `bytes` are in little-endian order. -// Ignored for `out_type` values that are stored in a single byte like -// `uint8`. -// If not specified, defaults to true -func DecodeRawLittleEndian(value bool) DecodeRawAttr { - return func(m optionalAttr) { - m["little_endian"] = value - } -} - -// Reinterpret the bytes of a string as a vector of numbers. +// This Op does not require `a_indices` be sorted in standard lexicographic order. // // Arguments: -// bytes: All the elements must have the same length. -// -// -// Returns A Tensor with one more dimension than the input `bytes`. The -// added dimension will have size equal to the length of the elements -// of `bytes` divided by the number of bytes to represent `out_type`. -func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { +// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. +// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. +// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. +// b: `ndims`-D Tensor. With shape `a_shape`. +func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DecodeRaw", + Type: "SparseTensorDenseAdd", Input: []tf.Input{ - bytes, + a_indices, a_values, a_shape, b, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Copy a tensor setting everything outside a central band in each innermost matrix -// -// to zero. -// -// The `band` part is computed as follows: -// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor with the same shape where -// -// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. -// -// The indicator function -// -// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && -// (num_upper < 0 || (n-m) <= num_upper)`. -// -// For example: -// -// ``` -// # if 'input' is [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [-2, -1, 0, 1] -// [-3, -2, -1, 0]], -// -// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [ 0, -1, 0, 1] -// [ 0, 0, -1, 0]], -// -// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] -// [-1, 0, 1, 0] -// [-2, -1, 0, 1] -// [ 0, -2, -1, 0]] -// ``` -// -// Useful special cases: +// Returns the set of files matching one or more glob patterns. // -// ``` -// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. -// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. -// tf.matrix_band_part(input, 0, 0) ==> Diagonal. -// ``` +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. // // Arguments: -// input: Rank `k` tensor. -// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire -// lower triangle. -// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep -// entire upper triangle. +// pattern: Shell wildcard pattern(s). Scalar or vector of type string. // -// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. -func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { +// Returns A vector of matching filenames. +func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixBandPart", + Type: "MatchingFiles", Input: []tf.Input{ - input, num_lower, num_upper, + pattern, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeCompressedAttr is an optional argument to DecodeCompressed. -type DecodeCompressedAttr func(optionalAttr) +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) -// DecodeCompressedCompressionType sets the optional compression_type attribute to value. -// -// value: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// If not specified, defaults to "" -func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { +// MatrixSolveLsFast sets the optional fast attribute to value. +// If not specified, defaults to true +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { return func(m optionalAttr) { - m["compression_type"] = value + m["fast"] = value } } -// Decompress strings. +// Solves one or more linear least-squares problems. // -// This op decompresses each element of the `bytes` input `Tensor`, which -// is assumed to be compressed using the given `compression_type`. +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same +// type as `matrix` and shape `[..., M, K]`. +// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations +// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` +// in the least squares sense. +// +// We use the following notation for (complex) matrix and right-hand sides +// in the batch: +// +// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), +// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), +// `output`=\\(X \in \mathbb{C}^{n \times k}\\), +// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + +// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), +// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable +// when \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is +// sufficiently large. // -// The `output` is a string `Tensor` of the same shape as `bytes`, -// each element containing the decompressed data from the corresponding -// element in `bytes`. +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. // // Arguments: -// bytes: A Tensor of string which is compressed. +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. // -// Returns A Tensor with the same shape as input `bytes`, uncompressed -// from bytes. -func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. +func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -15564,9 +15081,9 @@ func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompresse a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeCompressed", + Type: "MatrixSolveLs", Input: []tf.Input{ - bytes, + matrix, rhs, l2_regularizer, }, Attrs: attrs, } @@ -15574,522 +15091,537 @@ func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompresse return op.Output(0) } -// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. -type WholeFileReaderV2Attr func(optionalAttr) - -// WholeFileReaderV2Container sets the optional container attribute to value. +// Elementwise computes the bitwise OR of `x` and `y`. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value +// The result will have those bits set, that are set in `x`, `y` or both. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseOr", + Input: []tf.Input{ + x, y, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { +// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. +type SparseToSparseSetOperationAttr func(optionalAttr) + +// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["validate_indices"] = value } } -// A Reader that outputs the entire contents of a file as a value. +// Applies set operation along last dimension of 2 `SparseTensor` inputs. // -// To use, enqueue filenames in a Queue. The output of ReaderRead will -// be a filename (key) and the contents of that file (value). +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. // -// Returns The handle to reference the Reader. -func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { +// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the +// order and range of `set1` and `set2` indices. +// +// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, +// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same +// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set1` +// and `set2` indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must +// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the +// max set size across `0...n-1` dimensions. +// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the +// max set size across `0...n-1` dimensions. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"set_operation": set_operation} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "WholeFileReaderV2", - + Type: "SparseToSparseSetOperation", + Input: []tf.Input{ + set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Transforms a tf.Example proto (as a string) into typed tensors. +// Computes numerical negative value element-wise. // -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// dense_defaults: A list of Tensors (some may be empty), whose length matches -// the length of `dense_keys`. dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// num_sparse: The number of sparse features to be parsed from the example. This -// must match the lengths of `sparse_keys` and `sparse_types`. -// sparse_keys: A list of `num_sparse` strings. -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: The keys expected in the Examples' features associated with dense -// values. -// sparse_types: A list of `num_sparse` types; the data types of data in each -// Feature given in sparse_keys. -// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: The shapes of data in each Feature given in dense_keys. -// The length of this list must match the length of `dense_keys`. The -// number of elements in the Feature corresponding to dense_key[j] must -// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == -// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] -// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, -// ..., DN), the shape of the output Tensor dense_values[j] will be (M, -// D1, .., DN), where M is the number of blocks of elements of length -// D1 * .... * DN, in the input. -func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { +// I.e., \\(y = -x\\). +func Neg(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} opspec := tf.OpSpec{ - Type: "ParseSingleExample", + Type: "Neg", Input: []tf.Input{ - serialized, tf.OutputList(dense_defaults), + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return + return op.Output(0) +} + +// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. +type FakeQuantWithMinMaxVarsAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["num_bits"] = value } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return +} + +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value } - return sparse_indices, sparse_values, sparse_shapes, dense_values } -// Computes acos of x element-wise. -func Acos(scope *Scope, x tf.Output) (y tf.Output) { +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` +// +// and `max` to 'outputs' tensor of same shape as `inputs`. +// +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. +// +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Acos", + Type: "FakeQuantWithMinMaxVars", Input: []tf.Input{ - x, + inputs, min, max, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. -type MaxPoolWithArgmaxAttr func(optionalAttr) - -// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. -// If not specified, defaults to DT_INT64 -func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { - return func(m optionalAttr) { - m["Targmax"] = value - } -} - -// Performs max pooling on the input and outputs both max values and indices. -// -// The indices in `argmax` are flattened, so that a maximum value at position -// `[b, y, x, c]` becomes flattened index -// `((b * height + y) * width + x) * channels + c`. +// Returns the element-wise min of two SparseTensors. // -// The indices returned are always in `[0, height) x [0, width)` before flattening, -// even if padding is involved and the mathematically correct answer is outside -// (either negative or too large). This is a bug, but fixing it is difficult to do -// in a safe backwards compatible way, especially due to flattening. +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. // // Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. // -// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. -func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MaxPoolWithArgmax", + Type: "SparseSparseMinimum", Input: []tf.Input{ - input, + a_indices, a_values, a_shape, b_indices, b_values, b_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0), op.Output(1) } -// Transforms a serialized tensorflow.TensorProto proto into a Tensor. +// Constructs a tensor by tiling a given tensor. // -// Arguments: -// serialized: A scalar string containing a serialized TensorProto proto. -// out_type: The type of the serialized tensor. The provided type must match the -// type of the serialized tensor and no implicit conversion will take place. +// This operation creates a new tensor by replicating `input` `multiples` times. +// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, +// and the values of `input` are replicated `multiples[i]` times along the 'i'th +// dimension. For example, tiling `[a b c d]` by `[2]` produces +// `[a b c d a b c d]`. // -// Returns A Tensor of type `out_type`. -func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { +// Arguments: +// input: 1-D or higher. +// multiples: 1-D. Length must be the same as the number of dimensions in `input` +func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "ParseTensor", + Type: "Tile", Input: []tf.Input{ - serialized, + input, multiples, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MapClearAttr is an optional argument to MapClear. -type MapClearAttr func(optionalAttr) - -// MapClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapClearCapacity(value int64) MapClearAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} +// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap. +type TakeManySparseFromTensorsMapAttr func(optionalAttr) -// MapClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// TakeManySparseFromTensorsMapContainer sets the optional container attribute to value. // -// REQUIRES: value >= 0 -func MapClearMemoryLimit(value int64) MapClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapClearContainer sets the optional container attribute to value. +// value: The container name for the `SparseTensorsMap` read by this op. // If not specified, defaults to "" -func MapClearContainer(value string) MapClearAttr { +func TakeManySparseFromTensorsMapContainer(value string) TakeManySparseFromTensorsMapAttr { return func(m optionalAttr) { m["container"] = value } } -// MapClearSharedName sets the optional shared_name attribute to value. +// TakeManySparseFromTensorsMapSharedName sets the optional shared_name attribute to value. +// +// value: The shared name for the `SparseTensorsMap` read by this op. +// It should not be blank; rather the `shared_name` or unique Operation name +// of the Op that created the original `SparseTensorsMap` should be used. // If not specified, defaults to "" -func MapClearSharedName(value string) MapClearAttr { +func TakeManySparseFromTensorsMapSharedName(value string) TakeManySparseFromTensorsMapAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op removes all elements in the underlying container. +// Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. // -// Returns the created operation. -func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { +// The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where +// `N` is the minibatch size and the rows correspond to the output handles of +// `AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the +// original `SparseTensor` objects that went into the given input ops must all +// match. When the final `SparseTensor` is created, it has rank one +// higher than the ranks of the incoming `SparseTensor` objects +// (they have been concatenated along a new row dimension on the left). +// +// The output `SparseTensor` object's shape values for all dimensions but the +// first are the max across the input `SparseTensor` objects' shape values +// for the corresponding dimensions. Its first shape value is `N`, the minibatch +// size. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. +// +// For example, if the handles represent an input, which is a `[2, 3]` matrix +// representing two original `SparseTensor` objects: +// +// ``` +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// ``` +// +// and +// +// ``` +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// ``` +// +// then the final `SparseTensor` will be: +// +// ``` +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// ``` +// +// Arguments: +// sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. +// Shape: `[N]`. +// dtype: The `dtype` of the `SparseTensor` objects stored in the +// `SparseTensorsMap`. +// +// Returns 2-D. The `indices` of the minibatch `SparseTensor`.1-D. The `values` of the minibatch `SparseTensor`.1-D. The `shape` of the minibatch `SparseTensor`. +func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype tf.DataType, optional ...TakeManySparseFromTensorsMapAttr) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapClear", - + Type: "TakeManySparseFromTensorsMap", + Input: []tf.Input{ + sparse_handles, + }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// DecodeCSVAttr is an optional argument to DecodeCSV. -type DecodeCSVAttr func(optionalAttr) - -// DecodeCSVFieldDelim sets the optional field_delim attribute to value. -// -// value: char delimiter to separate fields in a record. -// If not specified, defaults to "," -func DecodeCSVFieldDelim(value string) DecodeCSVAttr { - return func(m optionalAttr) { - m["field_delim"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. -// -// value: If false, treats double quotation marks as regular -// characters inside of the string fields (ignoring RFC 4180, Section 2, -// Bullet 5). -// If not specified, defaults to true -func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { - return func(m optionalAttr) { - m["use_quote_delim"] = value - } -} +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) -// DecodeCSVNaValue sets the optional na_value attribute to value. +// MaxPoolDataFormat sets the optional data_format attribute to value. // -// value: Additional string to recognize as NA/NaN. -// If not specified, defaults to "" -func DecodeCSVNaValue(value string) DecodeCSVAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { return func(m optionalAttr) { - m["na_value"] = value + m["data_format"] = value } } -// Convert CSV records to tensors. Each column maps to one tensor. -// -// RFC 4180 format is expected for the CSV records. -// (https://tools.ietf.org/html/rfc4180) -// Note that we allow leading and trailing spaces with int or float field. +// Performs max pooling on the input. // // Arguments: -// records: Each string is a record/row in the csv and all records should have -// the same format. -// record_defaults: One tensor per column of the input record, with either a -// scalar default value for that column or empty if the column is required. +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns Each tensor will have the same shape as records. -func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { +// Returns The max pooled output tensor. +func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeCSV", + Type: "MaxPool", Input: []tf.Input{ - records, tf.OutputList(record_defaults), + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("DecodeCSV", err) - return - } - return output + return op.Output(0) } -// Returns the rank of a tensor. +// Says whether the targets are in the top `K` predictions. // -// This operation returns an integer representing the rank of `input`. +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. // -// For example: +// More formally, let // -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// # shape of tensor 't' is [2, 2, 3] -// rank(t) ==> 3 -// ``` +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, // -// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank -// of a tensor is the number of indices required to uniquely select each element -// of the tensor. Rank is also known as "order", "degree", or "ndims." -func Rank(scope *Scope, input tf.Output) (output tf.Output) { +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed precision at `k` as a `bool Tensor`. +func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Rank", + Type: "InTopKV2", Input: []tf.Input{ - input, + predictions, targets, k, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Fact", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Makes its input available to the next iteration. +// Assigns a new value to a variable. +// +// Any ReadVariableOp with a control dependency on this op is guaranteed to return +// this value or a subsequent newer value of the variable. // // Arguments: -// data: The tensor to be made available to the next iteration. +// resource: handle to the resource in which to store the variable. +// value: the value to set the new tensor to use. // -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { +// Returns the created operation. +func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NextIteration", + Type: "AssignVariableOp", Input: []tf.Input{ - data, + resource, value, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset that skips `count` elements from the `input_dataset`. +// Returns a tensor of ones with the same shape and type as x. // // Arguments: +// x: a tensor of type T. // -// count: A scalar representing the number of elements from the `input_dataset` -// that should be skipped. If count is -1, skips everything. -// -// -func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns a tensor of the same shape and type as x but filled with ones. +func OnesLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "SkipDataset", + Type: "OnesLike", Input: []tf.Input{ - input_dataset, count, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes hyperbolic tangent of `x` element-wise. -func Tanh(scope *Scope, x tf.Output) (y tf.Output) { +// The gradient of SparseFillEmptyRows. +// +// Takes vectors reverse_index_map, shaped `[N]`, and grad_values, +// shaped `[N_full]`, where `N_full >= N` and copies data into either +// `d_values` or `d_default_value`. Here `d_values` is shaped `[N]` and +// `d_default_value` is a scalar. +// +// d_values[j] = grad_values[reverse_index_map[j]] +// d_default_value = sum_{k : 0 .. N_full - 1} ( +// grad_values[k] * 1{k not in reverse_index_map}) +// +// Arguments: +// reverse_index_map: 1-D. The reverse index map from SparseFillEmptyRows. +// grad_values: 1-D. The gradients from backprop. +// +// Returns 1-D. The backprop into values.0-D. The backprop into default_value. +func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_values tf.Output) (d_values tf.Output, d_default_value tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Tanh", + Type: "SparseFillEmptyRowsGrad", Input: []tf.Input{ - x, + reverse_index_map, grad_values, }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Computes the maximum along segments of a tensor. -// -// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of -// segments. -// -// Computes a tensor such that -// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such -// that `segment_ids[j] == i`. -// -// If the max is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -//
-// -// Arguments: +// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. +// if < 0, `scale * features` otherwise. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) +func Selu(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentMax", + Type: "Selu", Input: []tf.Input{ - data, segment_ids, + features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// AvgPoolGradAttr is an optional argument to AvgPoolGrad. -type AvgPoolGradAttr func(optionalAttr) +// SetSizeAttr is an optional argument to SetSize. +type SetSizeAttr func(optionalAttr) -// AvgPoolGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { +// SetSizeValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func SetSizeValidateIndices(value bool) SetSizeAttr { return func(m optionalAttr) { - m["data_format"] = value + m["validate_indices"] = value } } -// Computes gradients of the average pooling function. +// Number of unique elements along last dimension of input `set`. +// +// Input `set` is a `SparseTensor` represented by `set_indices`, `set_values`, +// and `set_shape`. The last dimension contains values in a set, duplicates are +// allowed but ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set` +// indices. // // Arguments: -// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. -// the output of `avg_pool`. -// ksize: The size of the sliding window for each dimension of the input. -// strides: The stride of the sliding window for each dimension of the input. -// padding: The type of padding algorithm to use. +// set_indices: 2D `Tensor`, indices of a `SparseTensor`. +// set_values: 1D `Tensor`, values of a `SparseTensor`. +// set_shape: 1D `Tensor`, shape of a `SparseTensor`. // -// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. -func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { +// Returns For `set` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st +// `n-1` dimensions as `set`. Each value is the number of unique elements in +// the corresponding `[0...n-1]` dimension of `set`. +func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shape tf.Output, optional ...SetSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPoolGrad", + Type: "SetSize", Input: []tf.Input{ - orig_input_shape, grad, + set_indices, set_values, set_shape, }, Attrs: attrs, } @@ -16097,334 +15629,303 @@ func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize return op.Output(0) } -// StageClearAttr is an optional argument to StageClear. -type StageClearAttr func(optionalAttr) - -// StageClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Computes the sign and the log of the absolute value of the determinant of // -// REQUIRES: value >= 0 -func StageClearCapacity(value int64) StageClearAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// StageClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// one or more square matrices. // -// REQUIRES: value >= 0 -func StageClearMemoryLimit(value int64) StageClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageClearContainer(value string) StageClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageClearSharedName(value string) StageClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. +// The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions +// form square matrices. The outputs are two tensors containing the signs and +// absolute values of the log determinants for all N input submatrices +// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant). +// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU +// is the LU decomposition of the input and P is the corresponding +// permutation matrix. // -// Returns the created operation. -func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { +// Arguments: +// input: Shape is `[N, M, M]`. +// +// Returns The signs of the log determinants of the inputs. Shape is `[N]`.The logs of the absolute values of the determinants +// of the N input matrices. Shape is `[N]`. +func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_abs_determinant tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StageClear", - - Attrs: attrs, + Type: "LogMatrixDeterminant", + Input: []tf.Input{ + input, + }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. -type ComputeAccidentalHitsAttr func(optionalAttr) - -// ComputeAccidentalHitsSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// SumAttr is an optional argument to Sum. +type SumAttr func(optionalAttr) -// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. +// SumKeepDims sets the optional keep_dims attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SumKeepDims(value bool) SumAttr { return func(m optionalAttr) { - m["seed2"] = value + m["keep_dims"] = value } } -// Computes the ids of the positions in sampled_candidates that match true_labels. -// -// When doing log-odds NCE, the result of this op should be passed through a -// SparseToDense op, then added to the logits of the sampled candidates. This has -// the effect of 'removing' the sampled labels that match the true labels by -// making the classifier sure that they are sampled labels. +// Computes the sum of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// true_classes: The true_classes output of UnpackSparseLabels. -// sampled_candidates: The sampled_candidates output of CandidateSampler. -// num_true: Number of true labels per context. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label -// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element -// is -FLOAT_MAX. -func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { +// Returns The reduced tensor. +func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ComputeAccidentalHits", + Type: "Sum", Input: []tf.Input{ - true_classes, sampled_candidates, + input, axis, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Computes sigmoid of `x` element-wise. +// Delete the tensor specified by its handle in the session. // -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// handle: The handle for a tensor stored in the session state. +// +// Returns the created operation. +func DeleteSessionTensor(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Sigmoid", + Type: "DeleteSessionTensor", Input: []tf.Input{ - x, + handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. -type RandomStandardNormalAttr func(optionalAttr) - -// RandomStandardNormalSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } + return scope.AddOperation(opspec) } -// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. +// L2 Loss. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a normal distribution. +// Computes half the L2 norm of a tensor without the `sqrt`: // -// The generated values will have mean 0 and standard deviation 1. +// output = sum(t ** 2) / 2 // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// t: Typically 2-D, but may have any dimensions. // -// Returns A tensor of the specified shape filled with random normal values. -func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { +// Returns 0-D. +func L2Loss(scope *Scope, t tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomStandardNormal", + Type: "L2Loss", Input: []tf.Input{ - shape, + t, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// FusedBatchNormAttr is an optional argument to FusedBatchNorm. -type FusedBatchNormAttr func(optionalAttr) +// DenseToSparseSetOperationAttr is an optional argument to DenseToSparseSetOperation. +type DenseToSparseSetOperationAttr func(optionalAttr) -// FusedBatchNormEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { +// DenseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func DenseToSparseSetOperationValidateIndices(value bool) DenseToSparseSetOperationAttr { return func(m optionalAttr) { - m["epsilon"] = value + m["validate_indices"] = value } } -// FusedBatchNormDataFormat sets the optional data_format attribute to value. +// Applies set operation along last dimension of `Tensor` and `SparseTensor`. // -// value: The data format for x and y. Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormIsTraining sets the optional is_training attribute to value. +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Batch normalization. +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. // -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// If `validate_indices` is `True`, this op validates the order and range of `set2` +// indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. // // Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. Used for inference only; -// must be empty for training. +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as the 1st `n-1` dimensions of `set1`, `result_shape[n]` is the +// max set size across `n-1` dimensions. // -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToSparseSetOperation(scope *Scope, set1 tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...DenseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"set_operation": set_operation} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FusedBatchNorm", + Type: "DenseToSparseSetOperation", Input: []tf.Input{ - x, scale, offset, mean, variance, + set1, set2_indices, set2_values, set2_shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes tan of x element-wise. -func Tan(scope *Scope, x tf.Output) (y tf.Output) { +// Subtracts a value from the current value of a variable. +// +// Any ReadVariableOp which depends directly or indirectly on this assign is +// guaranteed to see the incremented value or a subsequent newer one. +// +// Outputs the incremented value, which can be used to totally order the +// increments to this variable. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. +// +// Returns the created operation. +func AssignSubVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Tan", + Type: "AssignSubVariableOp", Input: []tf.Input{ - x, + resource, value, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2. -type FusedBatchNormV2Attr func(optionalAttr) +// RestoreAttr is an optional argument to Restore. +type RestoreAttr func(optionalAttr) -// FusedBatchNormV2Epsilon sets the optional epsilon attribute to value. +// RestorePreferredShard sets the optional preferred_shard attribute to value. // -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormV2Epsilon(value float32) FusedBatchNormV2Attr { +// value: Index of file to open first if multiple files match +// `file_pattern`. +// If not specified, defaults to -1 +func RestorePreferredShard(value int64) RestoreAttr { return func(m optionalAttr) { - m["epsilon"] = value + m["preferred_shard"] = value } } -// FusedBatchNormV2DataFormat sets the optional data_format attribute to value. +// Restores a tensor from checkpoint files. // -// value: The data format for x and y. Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormV2DataFormat(value string) FusedBatchNormV2Attr { - return func(m optionalAttr) { - m["data_format"] = value +// Reads a tensor stored in one or several files. If there are several files (for +// instance because a tensor was saved as slices), `file_pattern` may contain +// wildcard symbols (`*` and `?`) in the filename portion only, not in the +// directory portion. +// +// If a `file_pattern` matches several files, `preferred_shard` can be used to hint +// in which file the requested tensor is likely to be found. This op will first +// open the file at index `preferred_shard` in the list of matching files and try +// to restore tensors from that file. Only if some tensors or tensor slices are +// not found in that first file, then the Op opens all the files. Setting +// `preferred_shard` to match the value passed as the `shard` input +// of a matching `Save` Op may speed up Restore. This attribute only affects +// performance, not correctness. The default value -1 means files are processed in +// order. +// +// See also `RestoreSlice`. +// +// Arguments: +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// dt: The type of the tensor to be restored. +// +// Returns The restored tensor. +func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf.DataType, optional ...RestoreAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Restore", + Input: []tf.Input{ + file_pattern, tensor_name, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FusedBatchNormV2IsTraining sets the optional is_training attribute to value. +// QuantizedResizeBilinearAttr is an optional argument to QuantizedResizeBilinear. +type QuantizedResizeBilinearAttr func(optionalAttr) + +// QuantizedResizeBilinearAlignCorners sets the optional align_corners attribute to value. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormV2IsTraining(value bool) FusedBatchNormV2Attr { +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func QuantizedResizeBilinearAlignCorners(value bool) QuantizedResizeBilinearAttr { return func(m optionalAttr) { - m["is_training"] = value + m["align_corners"] = value } } -// Batch normalization. +// Resize quantized `images` to `size` using quantized bilinear interpolation. // -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// Input images and output images must be quantized types. // // Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. Used for inference only; -// must be empty for training. +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormV2Attr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { +// +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min tf.Output, max tf.Output, optional ...QuantizedResizeBilinearAttr) (resized_images tf.Output, out_min tf.Output, out_max tf.Output) { if scope.Err() != nil { return } @@ -16433,194 +15934,192 @@ func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "FusedBatchNormV2", + Type: "QuantizedResizeBilinear", Input: []tf.Input{ - x, scale, offset, mean, variance, + images, size, min, max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0), op.Output(1), op.Output(2) } -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) - -// MultinomialSeed sets the optional seed attribute to value. +// Computes the minimum along segments of a tensor. // -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// MultinomialSeed2 sets the optional seed2 attribute to value. +// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +// segments. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed2"] = value +// Computes a tensor such that +// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such +// that `segment_ids[j] == i`. +// +// If the min is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
+// +// Arguments: +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SegmentMin", + Input: []tf.Input{ + data, segment_ids, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { +// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. +type SdcaOptimizerAttr func(optionalAttr) + +// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. +// +// value: Whether to use Adapative SDCA for the inner loop. +// If not specified, defaults to false +func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { return func(m optionalAttr) { - m["output_dtype"] = value + m["adaptative"] = value } } -// Draws samples from a multinomial distribution. +// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for +// +// linear models with L1 + L2 regularization. As global optimization objective is +// strongly-convex, the optimizer optimizes the dual objective at each step. The +// optimizer applies each update one example at a time. Examples are sampled +// uniformly, and the optimizer is learning rate free and enjoys linear convergence +// rate. +// +// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
+// Shai Shalev-Shwartz, Tong Zhang. 2012 +// +// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ +// +// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
+// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +// Peter Richtarik, Martin Takac. 2015 +// +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
+// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 // // Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. +// sparse_example_indices: a list of vectors which contain example indices. +// sparse_feature_indices: a list of vectors which contain feature indices. +// sparse_feature_values: a list of vectors which contains feature value +// associated with each feature group. +// dense_features: a list of matrices which contains the dense feature values. +// example_weights: a vector which contains the weight associated with each +// example. +// example_labels: a vector which contains the label/target associated with each +// example. +// sparse_indices: a list of vectors where each value is the indices which has +// corresponding weights in sparse_weights. This field maybe omitted for the +// dense approach. +// sparse_weights: a list of vectors where each value is the weight associated with +// a sparse feature group. +// dense_weights: a list of vectors where the values are the weights associated +// with a dense feature group. +// example_state_data: a list of vectors containing the example state data. +// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, +// squared and hinge losses. +// l1: Symmetric l1 regularization strength. +// l2: Symmetric l2 regularization strength. +// num_loss_partitions: Number of partitions of the global loss function. +// num_inner_iterations: Number of iterations per mini-batch. // -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { +// Returns a list of vectors containing the updated example state +// data.a list of vectors where each value is the delta +// weights associated with a sparse feature group.a list of vectors where the values are the delta +// weights associated with a dense feature group. +func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Multinomial", + Type: "SdcaOptimizer", Input: []tf.Input{ - logits, num_samples, + tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) - -// EncodeJpegFormat sets the optional format attribute to value. -// -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["format"] = value - } -} - -// EncodeJpegQuality sets the optional quality attribute to value. -// -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value + if scope.Err() != nil { + return } -} - -// EncodeJpegProgressive sets the optional progressive attribute to value. -// -// value: If True, create a JPEG that loads progressively (coarse to fine). -// If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["progressive"] = value + var idx int + var err error + out_example_state_data = op.Output(idx) + if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return } -} - -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. -// -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["optimize_size"] = value + if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return } + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights } -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} +// SparseMatMulAttr is an optional argument to SparseMatMul. +type SparseMatMulAttr func(optionalAttr) -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. -// -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { +// SparseMatMulTransposeA sets the optional transpose_a attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeA(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["density_unit"] = value + m["transpose_a"] = value } } -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { +// SparseMatMulTransposeB sets the optional transpose_b attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeB(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["x_density"] = value + m["transpose_b"] = value } } -// EncodeJpegYDensity sets the optional y_density attribute to value. -// -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { +// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["y_density"] = value + m["a_is_sparse"] = value } } -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. -// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { +// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { return func(m optionalAttr) { - m["xmp_metadata"] = value + m["b_is_sparse"] = value } } -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. -// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. -// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. +// Multiply matrix "a" by matrix "b". // -// Arguments: -// image: 3-D with shape `[height, width, channels]`. +// The inputs must be two-dimensional matrices and the inner dimension of "a" must +// match the outer dimension of "b". This op is optimized for the case where at +// least one of "a" or "b" is sparse. The breakeven for using this versus a dense +// matrix multiply on one platform was 30% zero values in the sparse matrix. // -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { +// The gradient computation of this operation will only take advantage of sparsity +// in the input gradient when that gradient comes from a Relu. +func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { if scope.Err() != nil { return } @@ -16629,9 +16128,9 @@ func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (cont a(attrs) } opspec := tf.OpSpec{ - Type: "EncodeJpeg", + Type: "SparseMatMul", Input: []tf.Input{ - image, + a, b, }, Attrs: attrs, } @@ -16639,114 +16138,52 @@ func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (cont return op.Output(0) } -// MaxPoolGradAttr is an optional argument to MaxPoolGrad. -type MaxPoolGradAttr func(optionalAttr) - -// MaxPoolGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradDataFormat(value string) MaxPoolGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the maxpooling function. +// Computes the power of one value to another. // -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients w.r.t. the output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for +// corresponding elements in `x` and `y`. For example: // -// Returns Gradients w.r.t. the input to `max_pool`. -func MaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradAttr) (output tf.Output) { +// ``` +// # tensor 'x' is [[2, 2]], [3, 3]] +// # tensor 'y' is [[8, 16], [2, 3]] +// tf.pow(x, y) ==> [[256, 65536], [9, 27]] +// ``` +func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MaxPoolGrad", + Type: "Pow", Input: []tf.Input{ - orig_input, orig_output, grad, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CropAndResizeAttr is an optional argument to CropAndResize. -type CropAndResizeAttr func(optionalAttr) - -// CropAndResizeMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeMethod(value string) CropAndResizeAttr { - return func(m optionalAttr) { - m["method"] = value - } -} +// ShapeAttr is an optional argument to Shape. +type ShapeAttr func(optionalAttr) -// CropAndResizeExtrapolationValue sets the optional extrapolation_value attribute to value. -// -// value: Value used for extrapolation, when applicable. -// If not specified, defaults to 0 -func CropAndResizeExtrapolationValue(value float32) CropAndResizeAttr { +// ShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeOutType(value tf.DataType) ShapeAttr { return func(m optionalAttr) { - m["extrapolation_value"] = value + m["out_type"] = value } } -// Extracts crops from the input image tensor and bilinearly resizes them (possibly -// -// with aspect ratio change) to a common output size specified by `crop_size`. This -// is more general than the `crop_to_bounding_box` op which extracts a fixed size -// slice from the input image and does not allow resizing or aspect ratio change. +// Returns the shape of a tensor. // -// Returns a tensor with `crops` from the input `image` at positions defined at the -// bounding box locations in `boxes`. The cropped boxes are all resized (with -// bilinear interpolation) to a fixed `size = [crop_height, crop_width]`. The -// result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`. The -// resizing is corner aligned. In particular, if `boxes = [[0, 0, 1, 1]]`, the -// method will give identical results to using `tf.image.resize_bilinear()` -// with `align_corners=True`. +// This operation returns a 1-D integer tensor representing the shape of `input`. // -// Arguments: -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1]` in image height coordinates. We do allow `y1` > `y2`, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`. All -// cropped image patches are resized to this size. The aspect ratio of the image -// content is not preserved. Both `crop_height` and `crop_width` need to be -// positive. +// For example: // -// Returns A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Output, crop_size tf.Output, optional ...CropAndResizeAttr) (crops tf.Output) { +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] +// ``` +func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -16755,9 +16192,9 @@ func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Ou a(attrs) } opspec := tf.OpSpec{ - Type: "CropAndResize", + Type: "Shape", Input: []tf.Input{ - image, boxes, box_ind, crop_size, + input, }, Attrs: attrs, } @@ -16765,136 +16202,153 @@ func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Ou return op.Output(0) } -// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. -type ResourceApplyPowerSignAttr func(optionalAttr) - -// ResourceApplyPowerSignUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyPowerSignUseLocking(value bool) ResourceApplyPowerSignAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the AddSign update. -// -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g -// variable <- variable - lr_t * update +// Computes fingerprints of the input strings. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// logbase: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. +// input: vector of strings to compute fingerprints on. // -// Returns the created operation. -func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, logbase tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyPowerSignAttr) (o *tf.Operation) { +// Returns a (N,2) shaped matrix where N is the number of elements in the input +// vector. Each row contains the low and high parts of the fingerprint. +func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyPowerSign", + Type: "SdcaFprint", Input: []tf.Input{ - var_, m, lr, logbase, sign_decay, beta, grad, + input, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Deprecated. Disallowed in GraphDef version >= 2. +// RandomPoissonV2Attr is an optional argument to RandomPoissonV2. +type RandomPoissonV2Attr func(optionalAttr) + +// RandomPoissonV2Seed sets the optional seed attribute to value. // -// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead -func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["seed"] = value } - opspec := tf.OpSpec{ - Type: "AdjustContrast", - Input: []tf.Input{ - images, contrast_factor, min_value, max_value, - }, +} + +// RandomPoissonV2Seed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// RandomPoissonV2Dtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_INT64 +func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["dtype"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Table initializer that takes two tensors for keys and values respectively. +// Outputs random values from the Poisson distribution(s) described by rate. +// +// This op uses two algorithms, depending on rate. If rate >= 10, then +// the algorithm by Hormann is used to acquire samples via +// transformation-rejection. +// See http://www.sciencedirect.com/science/article/pii/0167668793909974. +// +// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform +// random variables. +// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer +// Programming, Volume 2. Addison Wesley // // Arguments: -// table_handle: Handle to a table which will be initialized. -// keys: Keys of type Tkey. -// values: Values of type Tval. +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in rate. +// rate: A tensor in which each scalar is a "rate" parameter describing the +// associated poisson distribution. // -// Returns the created operation. -func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// Returns A tensor with shape `shape + shape(rate)`. Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `rate[i0, i1, ...iN]`. +func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "InitializeTableV2", + Type: "RandomPoissonV2", Input: []tf.Input{ - table_handle, keys, values, + shape, rate, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// PrintAttr is an optional argument to Print. -type PrintAttr func(optionalAttr) +// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve. +type MatrixTriangularSolveAttr func(optionalAttr) -// PrintMessage sets the optional message attribute to value. +// MatrixTriangularSolveLower sets the optional lower attribute to value. // -// value: A string, prefix of the error message. -// If not specified, defaults to "" -func PrintMessage(value string) PrintAttr { +// value: Boolean indicating whether the innermost matrices in `matrix` are +// lower or upper triangular. +// If not specified, defaults to true +func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr { return func(m optionalAttr) { - m["message"] = value + m["lower"] = value } } -// PrintFirstN sets the optional first_n attribute to value. +// MatrixTriangularSolveAdjoint sets the optional adjoint attribute to value. // -// value: Only log `first_n` number of times. -1 disables logging. -// If not specified, defaults to -1 -func PrintFirstN(value int64) PrintAttr { - return func(m optionalAttr) { - m["first_n"] = value - } -} - -// PrintSummarize sets the optional summarize attribute to value. +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. // -// value: Only print this many entries of each tensor. -// If not specified, defaults to 3 -func PrintSummarize(value int64) PrintAttr { +// @compatibility(numpy) +// Equivalent to np.linalg.triangular_solve +// @end_compatibility +// If not specified, defaults to false +func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr { return func(m optionalAttr) { - m["summarize"] = value + m["adjoint"] = value } } -// Prints a list of tensors. +// Solves systems of linear equations with upper or lower triangular matrices by // -// Passes `input` through to `output` and prints `data` when evaluating. +// backsubstitution. +// +// `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +// square matrices. If `lower` is `True` then the strictly upper triangular part +// of each inner-most matrix is assumed to be zero and not accessed. +// If `lower` is False then the strictly lower triangular part of each inner-most +// matrix is assumed to be zero and not accessed. +// `rhs` is a tensor of shape `[..., M, K]`. +// +// The output is a tensor of shape `[..., M, K]`. If `adjoint` is +// `True` then the innermost matrices in `output` satisfy matrix equations +// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +// If `adjoint` is `False` then the strictly then the innermost matrices in +// `output` satisfy matrix equations +// `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. // // Arguments: -// input: The tensor passed to `output` -// data: A list of tensors to print out when op is evaluated. +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. // -// Returns = The unmodified `input` tensor -func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAttr) (output tf.Output) { +// Returns Shape is `[..., M, K]`. +func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixTriangularSolveAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -16903,9 +16357,9 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt a(attrs) } opspec := tf.OpSpec{ - Type: "Print", + Type: "MatrixTriangularSolve", Input: []tf.Input{ - input, tf.OutputList(data), + matrix, rhs, }, Attrs: attrs, } @@ -16913,44 +16367,38 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt return op.Output(0) } -// Outputs a `Summary` protocol buffer with a tensor and per-plugin data. -// -// Arguments: -// tag: A string attached to this summary. Used for organization in TensorBoard. -// tensor: A tensor to serialize. -// serialized_summary_metadata: A serialized SummaryMetadata proto. Contains plugin -// data. -func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_summary_metadata tf.Output) (summary tf.Output) { +// Computes inverse hyperbolic sine of x element-wise. +func Asinh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorSummaryV2", + Type: "Asinh", Input: []tf.Input{ - tag, tensor, serialized_summary_metadata, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that asynchronously prefetches elements from `input_dataset`. +// Creates a dataset with a range of values. Corresponds to python's xrange. // // Arguments: -// -// buffer_size: The maximum number of elements to buffer in an iterator over -// this dataset. +// start: corresponds to start in python's xrange(). +// stop: corresponds to stop in python's xrange(). +// step: corresponds to step in python's xrange(). // // -func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "PrefetchDataset", + Type: "RangeDataset", Input: []tf.Input{ - input_dataset, buffer_size, + start, stop, step, }, Attrs: attrs, } @@ -16958,59 +16406,69 @@ func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Outpu return op.Output(0) } -// TensorSummaryAttr is an optional argument to TensorSummary. -type TensorSummaryAttr func(optionalAttr) - -// TensorSummaryDescription sets the optional description attribute to value. -// -// value: A json-encoded SummaryDescription proto. -// If not specified, defaults to "" -func TensorSummaryDescription(value string) TensorSummaryAttr { - return func(m optionalAttr) { - m["description"] = value - } -} +// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput. +type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr) -// TensorSummaryLabels sets the optional labels attribute to value. +// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value. // -// value: An unused list of strings. -// If not specified, defaults to <> -func TensorSummaryLabels(value []string) TensorSummaryAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { - m["labels"] = value + m["data_format"] = value } } -// TensorSummaryDisplayName sets the optional display_name attribute to value. +// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value. // -// value: An unused string. -// If not specified, defaults to "" -func TensorSummaryDisplayName(value string) TensorSummaryAttr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { - m["display_name"] = value + m["dilations"] = value } } -// Outputs a `Summary` protocol buffer with a tensor. -// -// This op is being phased out in favor of TensorSummaryV2, which lets callers pass -// a tag as well as a serialized SummaryMetadata proto string that contains -// plugin-specific data. We will keep this op to maintain backwards compatibility. +// Computes the gradients of depthwise convolution with respect to the input. // // Arguments: -// tensor: A tensor to serialize. -func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr) (summary tf.Output) { +// input_sizes: An integer vector representing the shape of `input`, based +// on `data_format`. For example, if `data_format` is 'NHWC' then +// `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, depthwise_multiplier]`. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape according to `data_format`. For example, if +// `data_format` is 'NHWC', output shape is `[batch, in_height, +// in_width, in_channels]`. Gradient w.r.t. the input of the +// convolution. +func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorSummary", + Type: "DepthwiseConv2dNativeBackpropInput", Input: []tf.Input{ - tensor, + input_sizes, filter, out_backprop, }, Attrs: attrs, } @@ -17018,174 +16476,145 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr return op.Output(0) } -// Computes the gradient for the tanh of `x` wrt its input. +// Adds sparse updates to the variable referenced by `resource`. // -// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` -// is the corresponding input gradient. -func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// This operation computes +// +// # Scalar indices +// ref[indices, ...] += updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] += updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions add. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TanhGrad", + Type: "ResourceScatterAdd", Input: []tf.Input{ - y, dy, + resource, indices, updates, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Outputs a `Summary` protocol buffer with scalar values. +// Says whether the targets are in the top `K` predictions. +// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. // -// The input `tags` and `values` must have the same shape. The generated summary -// has a summary value for each tag-value pair in `tags` and `values`. +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ // // Arguments: -// tags: Tags for the summary. -// values: Same shape as `tags. Values for the summary. +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { +// Returns Computed Precision at `k` as a `bool Tensor`. +func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"k": k} opspec := tf.OpSpec{ - Type: "ScalarSummary", + Type: "InTopK", Input: []tf.Input{ - tags, values, + predictions, targets, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Outputs a `Summary` protocol buffer with a histogram. -// -// The generated -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// has one summary value containing a histogram for `values`. -// -// This op reports an `InvalidArgument` error if any value is not finite. -// -// Arguments: -// tag: Scalar. Tag to use for the `Summary.Value`. -// values: Any shape. Values to use to build the histogram. +// Computes the gradient for the inverse of `x` wrt its input. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "HistogramSummary", + Type: "ReciprocalGrad", Input: []tf.Input{ - tag, values, + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the number of elements in the given queue. -// -// Arguments: -// handle: The handle to a queue. +// Returns the min of x and y (i.e. x < y ? x : y) element-wise. // -// Returns The number of elements in the given queue. -func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { +// *NOTE*: `Minimum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QueueSizeV2", + Type: "Minimum", Input: []tf.Input{ - handle, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ImageSummaryAttr is an optional argument to ImageSummary. -type ImageSummaryAttr func(optionalAttr) - -// ImageSummaryMaxImages sets the optional max_images attribute to value. -// -// value: Max number of batch elements to generate images for. -// If not specified, defaults to 3 -// -// REQUIRES: value >= 1 -func ImageSummaryMaxImages(value int64) ImageSummaryAttr { - return func(m optionalAttr) { - m["max_images"] = value - } -} - -// ImageSummaryBadColor sets the optional bad_color attribute to value. -// -// value: Color to use for pixels with non-finite values. -// If not specified, defaults to > int_val:255 int_val:0 int_val:0 int_val:255 > -func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { - return func(m optionalAttr) { - m["bad_color"] = value - } -} - -// Outputs a `Summary` protocol buffer with images. -// -// The summary has up to `max_images` summary values containing images. The -// images are built from `tensor` which must be 4-D with shape `[batch_size, -// height, width, channels]` and where `channels` can be: -// -// * 1: `tensor` is interpreted as Grayscale. -// * 3: `tensor` is interpreted as RGB. -// * 4: `tensor` is interpreted as RGBA. -// -// The images have the same number of channels as the input tensor. For float -// input, the values are normalized one image at a time to fit in the range -// `[0, 255]`. `uint8` values are unchanged. The op uses two different -// normalization algorithms: -// -// * If the input values are all positive, they are rescaled so the largest one -// is 255. -// -// * If any input value is negative, the values are shifted so input value 0.0 -// is at 127. They are then rescaled so that either the smallest value is 0, -// or the largest one is 255. +// Returns the element-wise sum of a list of tensors. // -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: +// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not +// wait for all of its inputs to be ready before beginning to sum. This can +// save memory if inputs are ready at different times, since minimum temporary +// storage is proportional to the output size rather than the inputs size. // -// * If `max_images` is 1, the summary value tag is '*tag*/image'. -// * If `max_images` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. +// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. // -// The `bad_color` argument is the color to use in the generated images for -// non-finite input values. It is a `unit8` 1-D tensor of length `channels`. -// Each element must be in the range `[0, 255]` (It represents the value of a -// pixel in the output image). Non-finite values in the input tensor are -// replaced by this tensor in the output image. The default value is the color -// red. +// Returns a `Tensor` of same shape and type as the elements of `inputs`. // // Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 4-D of shape `[batch_size, height, width, channels]` where -// `channels` is 1, 3, or 4. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...ImageSummaryAttr) (summary tf.Output) { +// inputs: A list of `Tensor` objects, each with same shape and type. +// shape: Shape of elements of `inputs`. +func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "ImageSummary", + Type: "AccumulateNV2", Input: []tf.Input{ - tag, tensor, + tf.OutputList(inputs), }, Attrs: attrs, } @@ -17193,101 +16622,106 @@ func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...Ima return op.Output(0) } -// AudioSummaryV2Attr is an optional argument to AudioSummaryV2. -type AudioSummaryV2Attr func(optionalAttr) - -// AudioSummaryV2MaxOutputs sets the optional max_outputs attribute to value. -// -// value: Max number of batch elements to generate audio for. -// If not specified, defaults to 3 +// Convert the quantized 'input' tensor into a lower-precision 'output', using the // -// REQUIRES: value >= 1 -func AudioSummaryV2MaxOutputs(value int64) AudioSummaryV2Attr { - return func(m optionalAttr) { - m["max_outputs"] = value - } -} - -// Outputs a `Summary` protocol buffer with audio. +// actual distribution of the values to maximize the usage of the lower bit depth +// and adjusting the output min and max ranges accordingly. // -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. +// [input_min, input_max] are scalar floats that specify the range for the float +// interpretation of the 'input' data. For example, if input_min is -1.0f and +// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 +// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. // -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: +// This operator tries to squeeze as much precision as possible into an output with +// a lower bit depth by calculating the actual min and max values found in the +// data. For example, maybe that quint16 input has no values lower than 16,384 and +// none higher than 49,152. That means only half the range is actually needed, all +// the float interpretations are between -0.5f and 0.5f, so if we want to compress +// the data into a quint8 output, we can use that range rather than the theoretical +// -1.0f to 1.0f that is suggested by the input min and max. // -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. +// In practice, this is most useful for taking output from operations like +// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and +// may have large potential output ranges, but in practice have a distribution of +// input values that only uses a small fraction of the possible range. By feeding +// that output into this operator, we can reduce it from 32 bits down to 8 with +// minimal loss of accuracy. // // Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...AudioSummaryV2Attr) (summary tf.Output) { +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// out_type: The type of the output. Should be a lower bit depth than Tinput. +// +// Returns The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "AudioSummaryV2", + Type: "QuantizeDownAndShrinkRange", Input: []tf.Input{ - tag, tensor, sample_rate, + input, input_min, input_max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// AvgPoolAttr is an optional argument to AvgPool. -type AvgPoolAttr func(optionalAttr) +// RandomGammaAttr is an optional argument to RandomGamma. +type RandomGammaAttr func(optionalAttr) -// AvgPoolDataFormat sets the optional data_format attribute to value. +// RandomGammaSeed sets the optional seed attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolDataFormat(value string) AvgPoolAttr { +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomGammaSeed(value int64) RandomGammaAttr { return func(m optionalAttr) { - m["data_format"] = value + m["seed"] = value } } -// Performs average pooling on the input. +// RandomGammaSeed2 sets the optional seed2 attribute to value. // -// Each entry in `output` is the mean of the corresponding size `ksize` -// window in `value`. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomGammaSeed2(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from the Gamma distribution(s) described by alpha. +// +// This op uses the algorithm by Marsaglia et al. to acquire samples via +// transformation-rejection from pairs of uniform and normal random variables. +// See http://dl.acm.org/citation.cfm?id=358414 // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// ksize: The size of the sliding window for each dimension of `value`. -// strides: The stride of the sliding window for each dimension of `value`. -// padding: The type of padding algorithm to use. +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in alpha. +// alpha: A tensor in which each scalar is a "shape" parameter describing the +// associated gamma distribution. // -// Returns The average pooled output tensor. -func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { +// Returns A tensor with shape `shape + shape(alpha)`. Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. +func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool", + Type: "RandomGamma", Input: []tf.Input{ - value, + shape, alpha, }, Attrs: attrs, } @@ -17295,57 +16729,59 @@ func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padd return op.Output(0) } -// Merges summaries. -// -// This op creates a -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// protocol buffer that contains the union of all the values in the input -// summaries. -// -// When the Op is run, it reports an `InvalidArgument` error if multiple values -// in the summaries to merge use the same tag. -// -// Arguments: -// inputs: Can be of any shape. Each must contain serialized `Summary` protocol -// buffers. +// RandomUniformIntAttr is an optional argument to RandomUniformInt. +type RandomUniformIntAttr func(optionalAttr) + +// RandomUniformIntSeed sets the optional seed attribute to value. // -// Returns Scalar. Serialized `Summary` protocol buffer. -func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformIntSeed(value int64) RandomUniformIntAttr { + return func(m optionalAttr) { + m["seed"] = value } - opspec := tf.OpSpec{ - Type: "MergeSummary", - Input: []tf.Input{ - tf.OutputList(inputs), - }, +} + +// RandomUniformIntSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the gradient of morphological 2-D dilation with respect to the filter. +// Outputs random integers from a uniform distribution. +// +// The generated values are uniform integers in the range `[minval, maxval)`. +// The lower bound `minval` is included in the range, while the upper bound +// `maxval` is excluded. +// +// The random integers are slightly biased unless `maxval - minval` is an exact +// power of two. The bias is small for values of `maxval - minval` significantly +// smaller than the range of the output (either `2^32` or `2^64`). // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. -// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// shape: The shape of the output tensor. +// minval: 0-D. Inclusive lower bound on the generated integers. +// maxval: 0-D. Exclusive upper bound on the generated integers. // -// Returns 3-D with shape `[filter_height, filter_width, depth]`. -func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { +// Returns A tensor of the specified shape filled with uniform random integers. +func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Dilation2DBackpropFilter", + Type: "RandomUniformInt", Input: []tf.Input{ - input, filter, out_backprop, + shape, minval, maxval, }, Attrs: attrs, } @@ -17353,295 +16789,370 @@ func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, o return op.Output(0) } -// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. -type AddSparseToTensorsMapAttr func(optionalAttr) +// SkipgramAttr is an optional argument to Skipgram. +type SkipgramAttr func(optionalAttr) -// AddSparseToTensorsMapContainer sets the optional container attribute to value. +// SkipgramWindowSize sets the optional window_size attribute to value. // -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { +// value: The number of words to predict to the left and right of the target. +// If not specified, defaults to 5 +func SkipgramWindowSize(value int64) SkipgramAttr { return func(m optionalAttr) { - m["container"] = value + m["window_size"] = value } } -// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. +// SkipgramMinCount sets the optional min_count attribute to value. // -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. -// If not specified, defaults to "" -func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { +// value: The minimum number of word occurrences for it to be included in the +// vocabulary. +// If not specified, defaults to 5 +func SkipgramMinCount(value int64) SkipgramAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["min_count"] = value } } -// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. -// -// A `SparseTensor` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`. +// SkipgramSubsample sets the optional subsample attribute to value. // -// This operator takes the given `SparseTensor` and adds it to a container -// object (a `SparseTensorsMap`). A unique key within this container is generated -// in the form of an `int64`, and this is the value that is returned. +// value: Threshold for word occurrence. Words that appear with higher +// frequency will be randomly down-sampled. Set to 0 to disable. +// If not specified, defaults to 0.001 +func SkipgramSubsample(value float32) SkipgramAttr { + return func(m optionalAttr) { + m["subsample"] = value + } +} + +// Parses a text file and creates a batch of examples. // -// The `SparseTensor` can then be read out as part of a minibatch by passing -// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddSparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. +// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result // // Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. +// filename: The corpus's text file name. +// batch_size: The size of produced batch. // -// Returns 0-D. The handle of the `SparseTensor` now stored in the -// `SparseTensorsMap`. -func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { +// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids. +func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AddSparseToTensorsMap", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, + Type: "Skipgram", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) } -// Writes a `Summary` protocol buffer with scalar values. +// StringToNumberAttr is an optional argument to StringToNumber. +type StringToNumberAttr func(optionalAttr) + +// StringToNumberOutType sets the optional out_type attribute to value. // -// The input `tag` and `value` must have the scalars. +// value: The numeric type to interpret each string in `string_tensor` as. +// If not specified, defaults to DT_FLOAT +func StringToNumberOutType(value tf.DataType) StringToNumberAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Converts each string in the input Tensor to the specified numeric type. // -// Arguments: -// writer: A handle to a summary writer. -// step: The step to write the summary for. -// tag: Tag for the summary. -// value: Value for the summary. +// (Note that int32 overflow results in an error while float overflow +// results in a rounded value.) // -// Returns the created operation. -func WriteScalarSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Output, value tf.Output) (o *tf.Operation) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "WriteScalarSummary", + Type: "StringToNumber", Input: []tf.Input{ - writer, step, tag, value, + string_tensor, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the matrix exponential of one or more square matrices: -// -// exp(A) = \sum_{n=0}^\infty A^n/n! +// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2. +type ResourceApplyFtrlV2Attr func(optionalAttr) + +// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value. // -// The exponential is computed using a combination of the scaling and squaring -// method and the Pade approximation. Details can be founds in: -// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential -// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Ftrl-proximal scheme. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the exponential for all input submatrices `[..., :, :]`. +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// input: Shape is `[..., M, M]`. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 shrinkage regulariation. Must be a scalar. // -// Returns Shape is `[..., M, M]`. +// lr_power: Scaling factor. Must be a scalar. // -// @compatibility(scipy) -// Equivalent to scipy.linalg.expm -// @end_compatibility -func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { +// Returns the created operation. +func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "MatrixExponential", + Type: "ResourceApplyFtrlV2", Input: []tf.Input{ - input, + var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power, }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// TruncatedNormalAttr is an optional argument to TruncatedNormal. +type TruncatedNormalAttr func(optionalAttr) + +// TruncatedNormalSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func TruncatedNormalSeed(value int64) TruncatedNormalAttr { + return func(m optionalAttr) { + m["seed"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. -type QueueDequeueUpToV2Attr func(optionalAttr) - -// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. +// TruncatedNormalSeed2 sets the optional seed2 attribute to value. // -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func TruncatedNormalSeed2(value int64) TruncatedNormalAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["seed2"] = value } } -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// This operation is not supported by all queues. If a queue does not support -// DequeueUpTo, then an Unimplemented error is returned. -// -// If the queue is closed and there are more than 0 but less than `n` -// elements remaining, then instead of returning an OutOfRange error like -// QueueDequeueMany, less than `n` elements are returned immediately. If -// the queue is closed and there are 0 elements left in the queue, then -// an OutOfRange error is returned just like in QueueDequeueMany. -// Otherwise the behavior is identical to QueueDequeueMany: -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size n in the 0th dimension. +// Outputs random values from a truncated normal distribution. // -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. // // Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. +// shape: The shape of the output tensor. +// dtype: The type of the output. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { +// Returns A tensor of the specified shape filled with random truncated normal +// values. +func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueDequeueUpToV2", + Type: "TruncatedNormal", Input: []tf.Input{ - handle, n, + shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueUpToV2", err) - return - } - return components + return op.Output(0) } -// Computes the Cholesky decomposition of one or more square matrices. +// RandomShuffleAttr is an optional argument to RandomShuffle. +type RandomShuffleAttr func(optionalAttr) + +// RandomShuffleSeed sets the optional seed attribute to value. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomShuffleSeed(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleSeed2 sets the optional seed2 attribute to value. // -// The input has to be symmetric and positive definite. Only the lower-triangular -// part of the input will be used for this operation. The upper-triangular part -// will not be read. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleSeed2(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Randomly shuffles a tensor along its first dimension. // -// The output is a tensor of the same shape as the input -// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. +// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped +// to one and only one `output[i]`. For example, a mapping that might occur for a +// 3x2 tensor is: // -// **Note**: The gradient computation on GPU is faster for large matrices but -// not for large batch dimensions when the submatrices are small. In this -// case it might be faster to use the CPU. +// ``` +// [[1, 2], [[5, 6], +// [3, 4], ==> [1, 2], +// [5, 6]] [3, 4]] +// ``` // // Arguments: -// input: Shape is `[..., M, M]`. +// value: The tensor to be shuffled. // -// Returns Shape is `[..., M, M]`. -func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { +// Returns A tensor of same shape and type as `value`, shuffled along its first +// dimension. +func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Cholesky", + Type: "RandomShuffle", Input: []tf.Input{ - input, + value, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Writes contents to the file at input filename. Creates file and recursively -// -// creates directory if not existing. +// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize. +type OrderedMapIncompleteSizeAttr func(optionalAttr) + +// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// filename: scalar. The name of the file to which we write the contents. -// contents: scalar. The content to be written to the output file. +// REQUIRES: value >= 0 +func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Returns the created operation. -func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { +// REQUIRES: value >= 0 +func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapIncompleteSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of incomplete elements in the underlying container. +func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "WriteFile", - Input: []tf.Input{ - filename, contents, - }, + Type: "OrderedMapIncompleteSize", + + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// AllAttr is an optional argument to All. -type AllAttr func(optionalAttr) +// DecodeRawAttr is an optional argument to DecodeRaw. +type DecodeRawAttr func(optionalAttr) -// AllKeepDims sets the optional keep_dims attribute to value. +// DecodeRawLittleEndian sets the optional little_endian attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AllKeepDims(value bool) AllAttr { +// value: Whether the input `bytes` are in little-endian order. +// Ignored for `out_type` values that are stored in a single byte like +// `uint8`. +// If not specified, defaults to true +func DecodeRawLittleEndian(value bool) DecodeRawAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["little_endian"] = value } } -// Computes the "logical and" of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// Reinterpret the bytes of a string as a vector of numbers. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// bytes: All the elements must have the same length. // -// Returns The reduced tensor. -func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { +// +// Returns A Tensor with one more dimension than the input `bytes`. The +// added dimension will have size equal to the length of the elements +// of `bytes` divided by the number of bytes to represent `out_type`. +func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"out_type": out_type} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "All", + Type: "DecodeRaw", Input: []tf.Input{ - input, axis, + bytes, }, Attrs: attrs, } @@ -17649,88 +17160,129 @@ func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (ou return op.Output(0) } -// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. +// Copy a tensor setting everything outside a central band in each innermost matrix +// +// to zero. +// +// The `band` part is computed as follows: +// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +// tensor with the same shape where +// +// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. +// +// The indicator function +// +// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && +// (num_upper < 0 || (n-m) <= num_upper)`. +// +// For example: +// +// ``` +// # if 'input' is [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [-2, -1, 0, 1] +// [-3, -2, -1, 0]], +// +// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [ 0, -1, 0, 1] +// [ 0, 0, -1, 0]], // -// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. +// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] +// [-1, 0, 1, 0] +// [-2, -1, 0, 1] +// [ 0, -2, -1, 0]] +// ``` // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices, with the same constraints as the single matrix -// SelfAdjointEig. +// Useful special cases: // -// The result is a [..., M+1, M] matrix with [..., 0,:] containing the -// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. +// ``` +// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. +// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. +// tf.matrix_band_part(input, 0, 0) ==> Diagonal. +// ``` // // Arguments: -// input: Shape is `[..., M, M]`. +// input: Rank `k` tensor. +// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire +// lower triangle. +// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep +// entire upper triangle. // -// Returns Shape is `[..., M+1, M]`. -func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { +// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. +func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SelfAdjointEig", + Type: "MatrixBandPart", Input: []tf.Input{ - input, + input, num_lower, num_upper, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes softplus gradients for a softplus operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding softplus operation. -// features: The features passed as input to the corresponding softplus operation. -// -// Returns The gradients: `gradients / (1 + exp(-features))`. -func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SoftplusGrad", - Input: []tf.Input{ - gradients, features, - }, +// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. +type QuantizedMatMulAttr func(optionalAttr) + +// QuantizedMatMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. -type SelfAdjointEigV2Attr func(optionalAttr) +// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// +// value: If true, `a` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} -// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. +// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. // -// value: If `True` then eigenvectors will be computed and returned in `v`. -// Otherwise, only the eigenvalues will be computed. -// If not specified, defaults to true -func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { +// value: If true, `b` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { return func(m optionalAttr) { - m["compute_v"] = value + m["transpose_b"] = value } } -// Computes the eigen decomposition of one or more square self-adjoint matrices. +// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. // -// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in -// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. +// value: The type of output produced by activation function +// following this operation. +// If not specified, defaults to DT_QUINT8 +func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Tactivation"] = value + } +} + +// Perform a quantized matrix multiplication of `a` by the matrix `b`. // -// ```python -// # a is a tensor. -// # e is a tensor of eigenvalues. -// # v is a tensor of eigenvectors. -// e, v = self_adjoint_eig(a) -// e = self_adjoint_eig(a, compute_v=False) -// ``` +// The inputs must be two-dimensional matrices and the inner dimension of +// `a` (after being transposed if `transpose_a` is non-zero) must match the +// outer dimension of `b` (after being transposed if `transposed_b` is +// non-zero). // // Arguments: -// input: `Tensor` input of shape `[N, N]`. +// a: Must be a two-dimensional tensor. +// b: Must be a two-dimensional tensor. +// min_a: The float value that the lowest quantized `a` value represents. +// max_a: The float value that the highest quantized `a` value represents. +// min_b: The float value that the lowest quantized `b` value represents. +// max_b: The float value that the highest quantized `b` value represents. // -// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. -func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { if scope.Err() != nil { return } @@ -17739,121 +17291,114 @@ func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV a(attrs) } opspec := tf.OpSpec{ - Type: "SelfAdjointEigV2", + Type: "QuantizedMatMul", Input: []tf.Input{ - input, + a, b, min_a, max_a, min_b, max_b, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// Adjust the saturation of one or more images. +// Does nothing. Serves as a control trigger for scheduling. // -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. +// Only useful as a placeholder for control edges. // -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A scale is then applied all the saturation -// values, and then remapped back to RGB colorspace. +// Returns the created operation. +func ControlTrigger(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ControlTrigger", + } + return scope.AddOperation(opspec) +} + +// Batch normalization. // -// Arguments: -// images: Images to adjust. At least 3-D. -// scale: A float scale to add to the saturation. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// Returns The hue-adjusted image or images. -func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { +// This op is deprecated. Prefer `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "AdjustSaturation", + Type: "BatchNormWithGlobalNormalization", Input: []tf.Input{ - images, scale, + t, m, v, beta, gamma, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Elementwise computes the bitwise OR of `x` and `y`. +// Deprecated. Use TensorArrayReadV3 // -// The result will have those bits set, that are set in `x`, `y` or both. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 +func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "BitwiseOr", + Type: "TensorArrayReadV2", Input: []tf.Input{ - x, y, + handle, index, flow_in, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) +// QuantizedMulAttr is an optional argument to QuantizedMul. +type QuantizedMulAttr func(optionalAttr) -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { +// QuantizedMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { return func(m optionalAttr) { - m["fast"] = value + m["Toutput"] = value } } -// Solves one or more linear least-squares problems. -// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same -// type as `matrix` and shape `[..., M, K]`. -// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations -// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` -// in the least squares sense. -// -// We use the following notation for (complex) matrix and right-hand sides -// in the batch: -// -// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), -// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), -// `output`=\\(X \in \mathbb{C}^{n \times k}\\), -// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). +// Returns x * y element-wise, working on quantized buffers. // -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + -// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as -// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), -// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable -// when \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is -// sufficiently large. +// Arguments: // -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. // -// Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. -// l2_regularizer: Scalar tensor. +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. // -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. // -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { +// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } @@ -17862,67 +17407,42 @@ func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer a(attrs) } opspec := tf.OpSpec{ - Type: "MatrixSolveLs", + Type: "QuantizedMul", Input: []tf.Input{ - matrix, rhs, l2_regularizer, + x, y, min_x, max_x, min_y, max_y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// SvdAttr is an optional argument to Svd. -type SvdAttr func(optionalAttr) +// QuantizedAddAttr is an optional argument to QuantizedAdd. +type QuantizedAddAttr func(optionalAttr) -// SvdComputeUv sets the optional compute_uv attribute to value. -// -// value: If true, left and right singular vectors will be -// computed and returned in `u` and `v`, respectively. -// If false, `u` and `v` are not set and should never referenced. -// If not specified, defaults to true -func SvdComputeUv(value bool) SvdAttr { +// QuantizedAddToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { return func(m optionalAttr) { - m["compute_uv"] = value + m["Toutput"] = value } } -// SvdFullMatrices sets the optional full_matrices attribute to value. +// Returns x + y element-wise, working on quantized buffers. // -// value: If true, compute full-sized `u` and `v`. If false -// (the default), compute only the leading `P` singular vectors. -// Ignored if `compute_uv` is `False`. -// If not specified, defaults to false -func SvdFullMatrices(value bool) SvdAttr { - return func(m optionalAttr) { - m["full_matrices"] = value - } -} - -// Computes the singular value decompositions of one or more matrices. +// Arguments: // -// Computes the SVD of each inner matrix in `input` such that -// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])` // -// ```python -// # a is a tensor containing a batch of matrices. -// # s is a tensor of singular values for each matrix. -// # u is the tensor containing of left singular vectors for each matrix. -// # v is the tensor containing of right singular vectors for each matrix. -// s, u, v = svd(a) -// s, _, _ = svd(a, compute_uv=False) -// ``` +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. // -// Arguments: -// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. // -// Returns Singular values. Shape is `[..., P]`.Left singular vectors. If `full_matrices` is `False` then shape is -// `[..., M, P]`; if `full_matrices` is `True` then shape is -// `[..., M, M]`. Undefined if `compute_uv` is `False`.Left singular vectors. If `full_matrices` is `False` then shape is -// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`. -// Undefined if `compute_uv` is false. -func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) { +// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } @@ -17931,9 +17451,9 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf. a(attrs) } opspec := tf.OpSpec{ - Type: "Svd", + Type: "QuantizedAdd", Input: []tf.Input{ - input, + x, y, min_x, max_x, min_y, max_y, }, Attrs: attrs, } @@ -17941,40 +17461,65 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf. return op.Output(0), op.Output(1), op.Output(2) } -// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. -type QueueEnqueueManyV2Attr func(optionalAttr) +// MfccAttr is an optional argument to Mfcc. +type MfccAttr func(optionalAttr) -// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. // -// value: If the queue is too full, this operation will block for up -// to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { +// value: The highest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 4000 +func MfccUpperFrequencyLimit(value float32) MfccAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["upper_frequency_limit"] = value } } -// Enqueues zero or more tuples of one or more tensors in the given queue. +// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. // -// This operation slices each component tensor along the 0th dimension to -// make multiple queue elements. All of the tuple components must have the -// same size in the 0th dimension. +// value: The lowest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 20 +func MfccLowerFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["lower_frequency_limit"] = value + } +} + +// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. // -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. +// value: Resolution of the Mel bank used internally. +// If not specified, defaults to 40 +func MfccFilterbankChannelCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["filterbank_channel_count"] = value + } +} + +// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. // -// N.B. If the queue is full, this operation will block until the given -// elements have been enqueued (or 'timeout_ms' elapses, if specified). +// value: How many output channels to produce per time slice. +// If not specified, defaults to 13 +func MfccDctCoefficientCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["dct_coefficient_count"] = value + } +} + +// Transforms a spectrogram into a form that's useful for speech recognition. // -// Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should -// be taken. +// Mel Frequency Cepstral Coefficients are a way of representing audio data that's +// been effective as an input feature for machine learning. They are created by +// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the +// higher frequencies that are less significant to the human ear. They have a long +// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum +// is a good resource to learn more. // -// Returns the created operation. -func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { +// Arguments: +// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared +// set to true. +// sample_rate: How many samples per second the source audio used. +func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -17983,123 +17528,212 @@ func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "QueueEnqueueManyV2", + Type: "Mfcc", Input: []tf.Input{ - handle, tf.OutputList(components), + spectrogram, sample_rate, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the product along segments of a tensor. +// Given a quantized tensor described by (input, input_min, input_max), outputs a // -// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of -// segments. +// range that covers the actual values present in that tensor. This op is +// typically used to produce the requested_output_min and requested_output_max for +// Requantize. // -// Computes a tensor such that -// \\(output_i = \prod_j data_j\\) where the product is over `j` such -// that `segment_ids[j] == i`. +// Arguments: // -// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. // -//
-// -//
+// Returns The computed min output.the computed max output. +func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RequantizationRange", + Input: []tf.Input{ + input, input_min, input_max, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// MapPeekAttr is an optional argument to MapPeek. +type MapPeekAttr func(optionalAttr) + +// MapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Arguments: +// REQUIRES: value >= 0 +func MapPeekCapacity(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. Values should be sorted and can be repeated. +// REQUIRES: value >= 0 +func MapPeekMemoryLimit(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapPeekContainer(value string) MapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapPeekSharedName(value string) MapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +// underlying container does not contain this key +// this op will block until it does. +func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SegmentProd", + Type: "MapPeek", Input: []tf.Input{ - data, segment_ids, + key, indices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapPeek", err) + return + } + return values } -// Converts one or more images from RGB to HSV. +// Looks up keys in a table, outputs the corresponding values. // -// Outputs a tensor of the same shape as the `images` tensor, containing the HSV -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// The tensor `keys` must of the same type as the keys of the table. +// The output `values` is of the type of the table values. // -// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and -// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 -// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// The scalar `default_value` is the value output for keys not present in the +// table. It must also be of the same type as the table values. // // Arguments: -// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. // -// Returns `images` converted to HSV. -func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { +// +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RGBToHSV", + Type: "LookupTableFindV2", Input: []tf.Input{ - images, + table_handle, keys, default_value, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Does nothing. Only useful as a placeholder for control edges. +// Bucketizes 'input' based on 'boundaries'. // -// Returns the created operation. -func NoOp(scope *Scope) (o *tf.Operation) { +// For example, if the inputs are +// boundaries = [0, 10, 100] +// input = [[-5, 10000] +// [150, 10] +// [5, 100]] +// +// then the output will be +// output = [[0, 3] +// [3, 2] +// [1, 3]] +// +// Arguments: +// input: Any shape of Tensor contains with int or float type. +// boundaries: A sorted list of floats gives the boundary of the buckets. +// +// Returns Same shape with 'input', each value of input replaced with bucket index. +// +// @compatibility(numpy) +// Equivalent to np.digitize. +// @end_compatibility +func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"boundaries": boundaries} opspec := tf.OpSpec{ - Type: "NoOp", + Type: "Bucketize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. -type MergeV2CheckpointsAttr func(optionalAttr) +// EncodePngAttr is an optional argument to EncodePng. +type EncodePngAttr func(optionalAttr) -// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. +// EncodePngCompression sets the optional compression attribute to value. // -// value: see above. -// If not specified, defaults to true -func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { +// value: Compression level. +// If not specified, defaults to -1 +func EncodePngCompression(value int64) EncodePngAttr { return func(m optionalAttr) { - m["delete_old_dirs"] = value + m["compression"] = value } } -// V2 format specific: merges the metadata files of sharded checkpoints. The +// PNG-encode an image. // -// result is one logical checkpoint, with one physical metadata file and renamed -// data files. +// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +// where `channels` is: // -// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. +// * 1: for grayscale. +// * 2: for grayscale + alpha. +// * 3: for RGB. +// * 4: for RGBA. // -// If delete_old_dirs is true, attempts to delete recursively the dirname of each -// path in the input checkpoint_prefixes. This is useful when those paths are non -// user-facing temporary locations. +// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +// default or a value from 0 to 9. 9 is the highest compression level, generating +// the smallest output, but is slower. // // Arguments: -// checkpoint_prefixes: prefixes of V2 checkpoints to merge. -// destination_prefix: scalar. The desired final prefix. Allowed to be the same -// as one of the checkpoint_prefixes. +// image: 3-D with shape `[height, width, channels]`. // -// Returns the created operation. -func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { +// Returns 0-D. PNG-encoded image. +func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { if scope.Err() != nil { return } @@ -18108,182 +17742,101 @@ func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination a(attrs) } opspec := tf.OpSpec{ - Type: "MergeV2Checkpoints", + Type: "EncodePng", Input: []tf.Input{ - checkpoint_prefixes, destination_prefix, + image, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Saves input tensors slices to disk. -// -// This is like `Save` except that tensors can be listed in the saved file as being -// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the -// larger tensor and the slice that this tensor covers. `shapes_and_slices` must -// have as many elements as `tensor_names`. -// -// Elements of the `shapes_and_slices` input must either be: -// -// * The empty string, in which case the corresponding tensor is -// saved normally. -// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the -// `dimI` are the dimensions of the larger tensor and `slice-spec` -// specifies what part is covered by the tensor to save. -// -// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` -// where each `sliceI` is either: -// -// * The string `-` meaning that the slice covers all indices of this dimension -// * `start,length` where `start` and `length` are integers. In that -// case the slice covers `length` indices starting at `start`. +// Updates the table to associates keys with values. // -// See also `Save`. +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. // // Arguments: -// filename: Must have a single element. The name of the file to which we write the -// tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. -// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when -// saving the tensors. -// data: `N` tensors to save. +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. // // Returns the created operation. -func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SaveSlices", + Type: "LookupTableInsertV2", Input: []tf.Input{ - filename, tensor_names, shapes_and_slices, tf.OutputList(data), + table_handle, keys, values, }, } return scope.AddOperation(opspec) } -// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. -type DenseToDenseSetOperationAttr func(optionalAttr) - -// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Applies set operation along last dimension of 2 `Tensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. -// -// Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// -// -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"set_operation": set_operation} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DenseToDenseSetOperation", - Input: []tf.Input{ - set1, set2, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Generate a sharded filename. The filename is printf formatted as -// -// %s-%05d-of-%05d, basename, shard, num_shards. -func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { +// Returns element-wise smallest integer in not less than x. +func Ceil(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ShardedFilename", + Type: "Ceil", Input: []tf.Input{ - basename, shard, num_shards, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Generate a glob pattern matching all sharded file names. -func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { +// Computes the number of elements in the given table. +// +// Arguments: +// table_handle: Handle to the table. +// +// Returns Scalar that contains number of elements in the table. +func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ShardedFilespec", + Type: "LookupTableSizeV2", Input: []tf.Input{ - basename, num_shards, + table_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. -type TextLineReaderV2Attr func(optionalAttr) - -// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. -// -// value: Number of lines to skip from the beginning of every file. -// If not specified, defaults to 0 -func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["skip_header_lines"] = value - } -} +// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. +type ResizeBilinearGradAttr func(optionalAttr) -// TextLineReaderV2Container sets the optional container attribute to value. +// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TextLineReaderV2Container(value string) TextLineReaderV2Attr { +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { return func(m optionalAttr) { - m["container"] = value + m["align_corners"] = value } } -// TextLineReaderV2SharedName sets the optional shared_name attribute to value. +// Computes the gradient of bilinear interpolation. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the lines of a file delimited by '\n'. +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. // -// Returns The handle to reference the Reader. -func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -18291,241 +17844,258 @@ func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_ha for _, a := range optional { a(attrs) } - opspec := tf.OpSpec{ - Type: "TextLineReaderV2", - + opspec := tf.OpSpec{ + Type: "ResizeBilinearGrad", + Input: []tf.Input{ + grads, original_image, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. -type LoadAndRemapMatrixAttr func(optionalAttr) - -// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. -// -// value: The maximum number of rows to load from the checkpoint at -// once. If less than or equal to 0, the entire matrix will be loaded into -// memory. Setting this arg trades increased disk reads for lower memory usage. -// If not specified, defaults to -1 -func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { - return func(m optionalAttr) { - m["max_rows_in_memory"] = value - } -} - -// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint -// -// at `ckpt_path` and potentially reorders its rows and columns using the -// specified remappings. -// -// Most users should use one of the wrapper initializers (such as -// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this -// function directly. -// -// The remappings are 1-D tensors with the following properties: +// Outputs all keys and values in the table. // -// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output -// matrix will be initialized from the row corresponding to index -// `row_remapping[i]` in the old `Tensor` from the checkpoint. -// * `col_remapping` must have either 0 entries (indicating that no column -// reordering is needed) or `num_cols` entries. If specified, column `j` of the -// output matrix will be initialized from the column corresponding to index -// `col_remapping[j]` in the old `Tensor` from the checkpoint. -// * A value of -1 in either of the remappings signifies a "missing" entry. In that -// case, values from the `initializing_values` tensor will be used to fill that -// missing row or column. If `row_remapping` has `r` missing entries and -// `col_remapping` has `c` missing entries, then the following condition must be -// true: +// Arguments: +// table_handle: Handle to the table. // -// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` // -// The remapping tensors can be generated using the GenerateVocabRemapping op. // -// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], -// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing -// the value from row i, column j of the old tensor in the checkpoint, the output -// matrix will look like the following: +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. +func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Replaces the contents of the table with the specified keys and values. // -// [[w(1, 0), w(1, 2), 0.5], -// [w(0, 0), w(0, 2), -0.5], -// [0.25, -0.25, 42]] +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. // // Arguments: -// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from -// which the old matrix `Tensor` will be loaded. -// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. -// row_remapping: An int `Tensor` of row remappings (generally created by -// `generate_vocab_remapping`). Even if no row remapping is needed, this must -// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted -// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). -// col_remapping: An int `Tensor` of column remappings (generally created by -// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping -// is to be done (e.g. column ordering is the same). -// initializing_values: A float `Tensor` containing values to fill in for cells -// in the output matrix that are not loaded from the checkpoint. Length must be -// exactly the same as the number of missing / new cells. -// num_rows: Number of rows (length of the 1st dimension) in the output matrix. -// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. // -// Returns Output matrix containing existing values loaded from the -// checkpoint, and with any missing values filled in from initializing_values. -func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { +// Returns the created operation. +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "LoadAndRemapMatrix", + Type: "LookupTableImportV2", Input: []tf.Input{ - ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, + table_handle, keys, values, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. -type TFRecordReaderV2Attr func(optionalAttr) +// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. +type MapUnstageNoKeyAttr func(optionalAttr) -// TFRecordReaderV2Container sets the optional container attribute to value. +// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TFRecordReaderV2Container(value string) TFRecordReaderV2Attr { +// REQUIRES: value >= 0 +func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { return func(m optionalAttr) { - m["container"] = value + m["capacity"] = value } } -// TFRecordReaderV2SharedName sets the optional shared_name attribute to value. +// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. +// REQUIRES: value >= 0 +func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapUnstageNoKeyContainer sets the optional container attribute to value. // If not specified, defaults to "" -func TFRecordReaderV2SharedName(value string) TFRecordReaderV2Attr { +func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["container"] = value } } -// TFRecordReaderV2CompressionType sets the optional compression_type attribute to value. +// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func TFRecordReaderV2CompressionType(value string) TFRecordReaderV2Attr { +func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { return func(m optionalAttr) { - m["compression_type"] = value + m["shared_name"] = value } } -// A Reader that outputs the records from a TensorFlow Records file. +// Op removes and returns a random (key, value) // -// Returns The handle to reference the Reader. -func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_handle tf.Output) { +// from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TFRecordReaderV2", - + Type: "MapUnstageNoKey", + Input: []tf.Input{ + indices, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + key = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstageNoKey", err) + return + } + return key, values } -// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. -type QuantizeAndDequantizeV3Attr func(optionalAttr) +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) -// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { +// HashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func HashTableV2Container(value string) HashTableV2Attr { return func(m optionalAttr) { - m["signed_input"] = value + m["container"] = value } } -// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { +// HashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { return func(m optionalAttr) { - m["range_given"] = value + m["shared_name"] = value } } -// Quantizes then dequantizes a tensor. +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. // -// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a -// tensor, so its value can change during training. -func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates a non-initialized hash table. +// +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantizeV3", - Input: []tf.Input{ - input, input_min, input_max, num_bits, - }, + Type: "HashTableV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. -type IdentityReaderV2Attr func(optionalAttr) +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) -// IdentityReaderV2Container sets the optional container attribute to value. +// MutableHashTableV2Container sets the optional container attribute to value. // -// value: If non-empty, this reader is placed in the given container. +// value: If non-empty, this table is placed in the given container. // Otherwise, a default container is used. // If not specified, defaults to "" -func IdentityReaderV2Container(value string) IdentityReaderV2Attr { +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { return func(m optionalAttr) { m["container"] = value } } -// IdentityReaderV2SharedName sets the optional shared_name attribute to value. +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. +// value: If non-empty, this table is shared under the given name across +// multiple sessions. // If not specified, defaults to "" -func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["use_node_name_sharing"] = value } } -// A Reader that outputs the queued work as both the key and value. +// Creates an empty hash table. // -// To use, enqueue strings in a Queue. ReaderRead will take the front -// work string and output (work, work). +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. // -// Returns The handle to reference the Reader. -func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "IdentityReaderV2", + Type: "MutableHashTableV2", Attrs: attrs, } @@ -18533,29 +18103,98 @@ func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_ha return op.Output(0) } -// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. -type ResourceApplyGradientDescentAttr func(optionalAttr) +// DequantizeAttr is an optional argument to Dequantize. +type DequantizeAttr func(optionalAttr) -// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { +// DequantizeMode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func DequantizeMode(value string) DequantizeAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["mode"] = value } } -// Update '*var' by subtracting 'alpha' * 'delta' from it. +// Dequantize the 'input' tensor into a float Tensor. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// if T == qint8, in[i] += (range(T) + 1)/ 2.0 +// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// If the input comes from a QuantizedRelu6, the output type is +// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is +// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. +// Dequantize on quint8 will take each value, cast to float, and multiply +// by 6 / 255. +// Note that if quantizedtype is qint8, the operation will additionally add +// each value by 128 prior to casting. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ```c++ +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = range / num_discrete_values +// const double offset_input = static_cast(input) - lowest_quantized; +// result = range_min + ((input - numeric_limits::min()) * range_scale) +// ``` +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// ```c++ +// s = (2 * m) / (max_fixed - min_fixed) +// ``` +// +// Now we can dequantize the elements of our tensor: +// ```c++ +// result = input * s +// ``` // // Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// delta: The change. // -// Returns the created operation. -func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -18564,132 +18203,115 @@ func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyGradientDescent", + Type: "Dequantize", Input: []tf.Input{ - var_, alpha, delta, + input, min_range, max_range, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the next record (key, value pair) produced by a Reader. -// -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// -// Arguments: -// reader_handle: Handle to a Reader. -// queue_handle: Handle to a Queue, with string work items. +// Flips all bits elementwise. // -// Returns A scalar.A scalar. -func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { +// The result will have exactly those bits set, that are not set in `x`. The +// computation is performed on the underlying representation of x. +func Invert(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderReadV2", + Type: "Invert", Input: []tf.Input{ - reader_handle, queue_handle, + x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Returns up to `num_records` (key, value) pairs produced by a Reader. -// -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). -// It may return less than `num_records` even before the last batch. -// -// Arguments: -// reader_handle: Handle to a `Reader`. -// queue_handle: Handle to a `Queue`, with string work items. -// num_records: number of records to read from `Reader`. +// Deprecated. Disallowed in GraphDef version >= 2. // -// Returns A 1-D tensor.A 1-D tensor. -func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output, num_records tf.Output) (keys tf.Output, values tf.Output) { +// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead +func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderReadUpToV2", + Type: "AdjustContrast", Input: []tf.Input{ - reader_handle, queue_handle, num_records, + images, contrast_factor, min_value, max_value, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Restore a Reader to its initial clean state. +// Table initializer that takes two tensors for keys and values respectively. // // Arguments: -// reader_handle: Handle to a Reader. +// table_handle: Handle to a table which will be initialized. +// keys: Keys of type Tkey. +// values: Values of type Tval. // // Returns the created operation. -func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { +func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderResetV2", + Type: "InitializeTableV2", Input: []tf.Input{ - reader_handle, + table_handle, keys, values, }, } return scope.AddOperation(opspec) } -// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. -type ResourceApplyAdamAttr func(optionalAttr) +// PrintAttr is an optional argument to Print. +type PrintAttr func(optionalAttr) -// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. +// PrintMessage sets the optional message attribute to value. // -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { +// value: A string, prefix of the error message. +// If not specified, defaults to "" +func PrintMessage(value string) PrintAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["message"] = value } } -// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// PrintFirstN sets the optional first_n attribute to value. // -// value: If `True`, uses the nesterov update. -// If not specified, defaults to false -func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { +// value: Only log `first_n` number of times. -1 disables logging. +// If not specified, defaults to -1 +func PrintFirstN(value int64) PrintAttr { return func(m optionalAttr) { - m["use_nesterov"] = value + m["first_n"] = value } } -// Update '*var' according to the Adam algorithm. +// PrintSummarize sets the optional summarize attribute to value. // -// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t -// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t -// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) +// value: Only print this many entries of each tensor. +// If not specified, defaults to 3 +func PrintSummarize(value int64) PrintAttr { + return func(m optionalAttr) { + m["summarize"] = value + } +} + +// Prints a list of tensors. +// +// Passes `input` through to `output` and prints `data` when evaluating. // // Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// input: The tensor passed to `output` +// data: A list of tensors to print out when op is evaluated. // -// Returns the created operation. -func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { +// Returns = The unmodified `input` tensor +func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -18698,85 +18320,103 @@ func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, b a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdam", + Type: "Print", Input: []tf.Input{ - var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + input, tf.OutputList(data), }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Store the input tensor in the state of the current session. +// Outputs a `Summary` protocol buffer with a tensor and per-plugin data. // // Arguments: -// value: The tensor to be stored. -// -// Returns The handle for the tensor stored in the session state, represented -// as a ResourceHandle object. -func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { +// tag: A string attached to this summary. Used for organization in TensorBoard. +// tensor: A tensor to serialize. +// serialized_summary_metadata: A serialized SummaryMetadata proto. Contains plugin +// data. +func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_summary_metadata tf.Output) (summary tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GetSessionHandleV2", + Type: "TensorSummaryV2", Input: []tf.Input{ - value, + tag, tensor, serialized_summary_metadata, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the set of files matching one or more glob patterns. -// -// Note that this routine only supports wildcard characters in the -// basename portion of the pattern, not in the directory portion. +// Creates a dataset that asynchronously prefetches elements from `input_dataset`. // // Arguments: -// pattern: Shell wildcard pattern(s). Scalar or vector of type string. // -// Returns A vector of matching filenames. -func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { +// buffer_size: The maximum number of elements to buffer in an iterator over +// this dataset. +// +// +func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "MatchingFiles", + Type: "PrefetchDataset", Input: []tf.Input{ - pattern, + input_dataset, buffer_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. -type ResizeBicubicGradAttr func(optionalAttr) +// TensorSummaryAttr is an optional argument to TensorSummary. +type TensorSummaryAttr func(optionalAttr) -// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// TensorSummaryDescription sets the optional description attribute to value. // -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { +// value: A json-encoded SummaryDescription proto. +// If not specified, defaults to "" +func TensorSummaryDescription(value string) TensorSummaryAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["description"] = value } } -// Computes the gradient of bicubic interpolation. +// TensorSummaryLabels sets the optional labels attribute to value. // -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. +// value: An unused list of strings. +// If not specified, defaults to <> +func TensorSummaryLabels(value []string) TensorSummaryAttr { + return func(m optionalAttr) { + m["labels"] = value + } +} + +// TensorSummaryDisplayName sets the optional display_name attribute to value. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { +// value: An unused string. +// If not specified, defaults to "" +func TensorSummaryDisplayName(value string) TensorSummaryAttr { + return func(m optionalAttr) { + m["display_name"] = value + } +} + +// Outputs a `Summary` protocol buffer with a tensor. +// +// This op is being phased out in favor of TensorSummaryV2, which lets callers pass +// a tag as well as a serialized SummaryMetadata proto string that contains +// plugin-specific data. We will keep this op to maintain backwards compatibility. +// +// Arguments: +// tensor: A tensor to serialize. +func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } @@ -18785,9 +18425,9 @@ func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBicubicGrad", + Type: "TensorSummary", Input: []tf.Input{ - grads, original_image, + tensor, }, Attrs: attrs, } @@ -18795,189 +18435,163 @@ func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, return op.Output(0) } -// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. -type ResizeNearestNeighborAttr func(optionalAttr) - -// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// Computes the gradient for the tanh of `x` wrt its input. // -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { - return func(m optionalAttr) { - m["align_corners"] = value +// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` +// is the corresponding input gradient. +func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TanhGrad", + Input: []tf.Input{ + y, dy, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Resize `images` to `size` using nearest neighbor interpolation. +// Outputs a `Summary` protocol buffer with scalar values. +// +// The input `tags` and `values` must have the same shape. The generated summary +// has a summary value for each tag-value pair in `tags` and `values`. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// tags: Tags for the summary. +// values: Same shape as `tags. Values for the summary. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResizeNearestNeighbor", + Type: "ScalarSummary", Input: []tf.Input{ - images, size, + tags, values, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. -type ResizeNearestNeighborGradAttr func(optionalAttr) - -// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. +// Outputs a `Summary` protocol buffer with a histogram. // -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of nearest neighbor interpolation. +// The generated +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// has one summary value containing a histogram for `values`. +// +// This op reports an `InvalidArgument` error if any value is not finite. // // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The -// original input size. +// tag: Scalar. Tag to use for the `Summary.Value`. +// values: Any shape. Values to use to build the histogram. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients -// with respect to the input image. -func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResizeNearestNeighborGrad", + Type: "HistogramSummary", Input: []tf.Input{ - grads, size, + tag, values, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeJpegAttr is an optional argument to DecodeJpeg. -type DecodeJpegAttr func(optionalAttr) - -// DecodeJpegChannels sets the optional channels attribute to value. +// Computes the number of elements in the given queue. // -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeJpegChannels(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeJpegRatio sets the optional ratio attribute to value. +// Arguments: +// handle: The handle to a queue. // -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeJpegRatio(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value +// Returns The number of elements in the given queue. +func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { + if scope.Err() != nil { + return } -} - -// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value + opspec := tf.OpSpec{ + Type: "QueueSizeV2", + Input: []tf.Input{ + handle, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} +// ImageSummaryAttr is an optional argument to ImageSummary. +type ImageSummaryAttr func(optionalAttr) -// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// ImageSummaryMaxImages sets the optional max_images attribute to value. // -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { +// value: Max number of batch elements to generate images for. +// If not specified, defaults to 3 +// +// REQUIRES: value >= 1 +func ImageSummaryMaxImages(value int64) ImageSummaryAttr { return func(m optionalAttr) { - m["acceptable_fraction"] = value + m["max_images"] = value } } -// DecodeJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeJpegDctMethod(value string) DecodeJpegAttr { +// ImageSummaryBadColor sets the optional bad_color attribute to value. +// +// value: Color to use for pixels with non-finite values. +// If not specified, defaults to > int_val:255 int_val:0 int_val:0 int_val:255 > +func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr { return func(m optionalAttr) { - m["dct_method"] = value + m["bad_color"] = value } } -// Decode a JPEG-encoded image to a uint8 tensor. +// Outputs a `Summary` protocol buffer with images. // -// The attr `channels` indicates the desired number of color channels for the -// decoded image. +// The summary has up to `max_images` summary values containing images. The +// images are built from `tensor` which must be 4-D with shape `[batch_size, +// height, width, channels]` and where `channels` can be: // -// Accepted values are: +// * 1: `tensor` is interpreted as Grayscale. +// * 3: `tensor` is interpreted as RGB. +// * 4: `tensor` is interpreted as RGBA. // -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. +// The images have the same number of channels as the input tensor. For float +// input, the values are normalized one image at a time to fit in the range +// `[0, 255]`. `uint8` values are unchanged. The op uses two different +// normalization algorithms: // -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. +// * If the input values are all positive, they are rescaled so the largest one +// is 255. // -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. +// * If any input value is negative, the values are shifted so input value 0.0 +// is at 127. They are then rescaled so that either the smallest value is 0, +// or the largest one is 255. // +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: // -// This op also supports decoding PNGs and non-animated GIFs since the interface is -// the same, though it is cleaner to use `tf.image.decode_image`. +// * If `max_images` is 1, the summary value tag is '*tag*/image'. +// * If `max_images` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. +// +// The `bad_color` argument is the color to use in the generated images for +// non-finite input values. It is a `unit8` 1-D tensor of length `channels`. +// Each element must be in the range `[0, 255]` (It represents the value of a +// pixel in the output image). Non-finite values in the input tensor are +// replaced by this tensor in the output image. The default value is the color +// red. // // Arguments: -// contents: 0-D. The JPEG-encoded image. +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 4-D of shape `[batch_size, height, width, channels]` where +// `channels` is 1, 3, or 4. // -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...ImageSummaryAttr) (summary tf.Output) { if scope.Err() != nil { return } @@ -18986,9 +18600,9 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeJpeg", + Type: "ImageSummary", Input: []tf.Input{ - contents, + tag, tensor, }, Attrs: attrs, } @@ -18996,29 +18610,42 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i return op.Output(0) } -// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. -type ExtractJpegShapeAttr func(optionalAttr) +// AudioSummaryV2Attr is an optional argument to AudioSummaryV2. +type AudioSummaryV2Attr func(optionalAttr) -// ExtractJpegShapeOutputType sets the optional output_type attribute to value. +// AudioSummaryV2MaxOutputs sets the optional max_outputs attribute to value. // -// value: (Optional) The output type of the operation (int32 or int64). -// Defaults to int32. -// If not specified, defaults to DT_INT32 -func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { +// value: Max number of batch elements to generate audio for. +// If not specified, defaults to 3 +// +// REQUIRES: value >= 1 +func AudioSummaryV2MaxOutputs(value int64) AudioSummaryV2Attr { return func(m optionalAttr) { - m["output_type"] = value + m["max_outputs"] = value } } -// Extract the shape information of a JPEG-encoded image. +// Outputs a `Summary` protocol buffer with audio. // -// This op only parses the image header, so it is much faster than DecodeJpeg. +// The summary has up to `max_outputs` summary values containing audio. The +// audio is built from `tensor` which must be 3-D with shape `[batch_size, +// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. +// +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: +// +// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +// * If `max_outputs` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. // // Arguments: -// contents: 0-D. The JPEG-encoded image. +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 2-D of shape `[batch_size, frames]`. +// sample_rate: The sample rate of the signal in hertz. // -// Returns 1-D. The image shape with format [height, width, channels]. -func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { +// Returns Scalar. Serialized `Summary` protocol buffer. +func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate tf.Output, optional ...AudioSummaryV2Attr) (summary tf.Output) { if scope.Err() != nil { return } @@ -19027,9 +18654,9 @@ func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegS a(attrs) } opspec := tf.OpSpec{ - Type: "ExtractJpegShape", + Type: "AudioSummaryV2", Input: []tf.Input{ - contents, + tag, tensor, sample_rate, }, Attrs: attrs, } @@ -19037,132 +18664,161 @@ func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegS return op.Output(0) } -// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. -type PaddingFIFOQueueV2Attr func(optionalAttr) +// AvgPoolAttr is an optional argument to AvgPool. +type AvgPoolAttr func(optionalAttr) -// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. -// Shapes of fixed rank but variable size are allowed by setting -// any shape dimension to -1. In this case, the inputs' shape may vary along -// the given dimension, and DequeueMany will pad the given dimension with -// zeros up to the maximum shape of all elements in the given batch. -// If the length of this attr is 0, different queue elements may have -// different ranks and shapes, but only one element may be dequeued at a time. -// If not specified, defaults to <> +// AvgPoolDataFormat sets the optional data_format attribute to value. // -// REQUIRES: len(value) >= 0 -func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolDataFormat(value string) AvgPoolAttr { return func(m optionalAttr) { - m["shapes"] = value + m["data_format"] = value } } -// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. +// Performs average pooling on the input. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// PaddingFIFOQueueV2Container sets the optional container attribute to value. +// Each entry in `output` is the mean of the corresponding size `ksize` +// window in `value`. // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value +// Arguments: +// value: 4-D with shape `[batch, height, width, channels]`. +// ksize: The size of the sliding window for each dimension of `value`. +// strides: The stride of the sliding window for each dimension of `value`. +// padding: The type of padding algorithm to use. +// +// Returns The average pooled output tensor. +func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPool", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. +// Merges summaries. // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value +// This op creates a +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// protocol buffer that contains the union of all the values in the input +// summaries. +// +// When the Op is run, it reports an `InvalidArgument` error if multiple values +// in the summaries to merge use the same tag. +// +// Arguments: +// inputs: Can be of any shape. Each must contain serialized `Summary` protocol +// buffers. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MergeSummary", + Input: []tf.Input{ + tf.OutputList(inputs), + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// A queue that produces elements in first-in first-out order. -// -// Variable-size shapes are allowed by setting the corresponding shape dimensions -// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum -// size of any given element in the minibatch. See below for details. +// Computes the gradient of morphological 2-D dilation with respect to the filter. // // Arguments: -// component_types: The type of each component in a value. +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. // -// Returns The handle to the queue. -func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { +// Returns 3-D with shape `[filter_height, filter_width, depth]`. +func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "PaddingFIFOQueueV2", - + Type: "Dilation2DBackpropFilter", + Input: []tf.Input{ + input, filter, out_backprop, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodePngAttr is an optional argument to DecodePng. -type DecodePngAttr func(optionalAttr) +// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. +type AddSparseToTensorsMapAttr func(optionalAttr) -// DecodePngChannels sets the optional channels attribute to value. +// AddSparseToTensorsMapContainer sets the optional container attribute to value. // -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodePngChannels(value int64) DecodePngAttr { +// value: The container name for the `SparseTensorsMap` created by this op. +// If not specified, defaults to "" +func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { return func(m optionalAttr) { - m["channels"] = value + m["container"] = value } } -// DecodePngDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_UINT8 -func DecodePngDtype(value tf.DataType) DecodePngAttr { +// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. +// +// value: The shared name for the `SparseTensorsMap` created by this op. +// If blank, the new Operation's unique name is used. +// If not specified, defaults to "" +func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { return func(m optionalAttr) { - m["dtype"] = value + m["shared_name"] = value } } -// Decode a PNG-encoded image to a uint8 or uint16 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: +// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. // -// * 0: Use the number of channels in the PNG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// * 4: output an RGBA image. +// A `SparseTensor` is represented by three tensors: `sparse_indices`, +// `sparse_values`, and `sparse_shape`. // -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. +// This operator takes the given `SparseTensor` and adds it to a container +// object (a `SparseTensorsMap`). A unique key within this container is generated +// in the form of an `int64`, and this is the value that is returned. // -// This op also supports decoding JPEGs and non-animated GIFs since the interface -// is the same, though it is cleaner to use `tf.image.decode_image`. +// The `SparseTensor` can then be read out as part of a minibatch by passing +// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure +// the correct `SparseTensorsMap` is accessed, ensure that the same +// `container` and `shared_name` are passed to that Op. If no `shared_name` +// is provided here, instead use the *name* of the Operation created by calling +// `AddSparseToTensorsMap` as the `shared_name` passed to +// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. // // Arguments: -// contents: 0-D. The PNG-encoded image. +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. // -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { +// Returns 0-D. The handle of the `SparseTensor` now stored in the +// `SparseTensorsMap`. +func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { if scope.Err() != nil { return } @@ -19171,9 +18827,9 @@ func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (ima a(attrs) } opspec := tf.OpSpec{ - Type: "DecodePng", + Type: "AddSparseToTensorsMap", Input: []tf.Input{ - contents, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } @@ -19181,310 +18837,390 @@ func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (ima return op.Output(0) } -// Decode the first frame of a GIF-encoded image to a uint8 tensor. +// Computes the matrix exponential of one or more square matrices: // -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: +// exp(A) = \sum_{n=0}^\infty A^n/n! // -// convert $src.gif -coalesce $dst.gif +// The exponential is computed using a combination of the scaling and squaring +// method and the Pade approximation. Details can be founds in: +// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential +// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. // -// This op also supports decoding JPEGs and PNGs, though it is cleaner to use -// `tf.image.decode_image`. +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the exponential for all input submatrices `[..., :, :]`. // // Arguments: -// contents: 0-D. The GIF-encoded image. +// input: Shape is `[..., M, M]`. // -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { +// Returns Shape is `[..., M, M]`. +// +// @compatibility(scipy) +// Equivalent to scipy.linalg.expm +// @end_compatibility +func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DecodeGif", + Type: "MatrixExponential", Input: []tf.Input{ - contents, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceApplyCenteredRMSPropAttr is an optional argument to ResourceApplyCenteredRMSProp. -type ResourceApplyCenteredRMSPropAttr func(optionalAttr) +// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. +type QueueDequeueUpToV2Attr func(optionalAttr) -// ResourceApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. // -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMSPropAttr { +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { return func(m optionalAttr) { - m["use_locking"] = value + m["timeout_ms"] = value } } -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. +// Dequeues `n` tuples of one or more tensors from the given queue. // -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. +// This operation is not supported by all queues. If a queue does not support +// DequeueUpTo, then an Unimplemented error is returned. // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient +// If the queue is closed and there are more than 0 but less than `n` +// elements remaining, then instead of returning an OutOfRange error like +// QueueDequeueMany, less than `n` elements are returned immediately. If +// the queue is closed and there are 0 elements left in the queue, then +// an OutOfRange error is returned just like in QueueDequeueMany. +// Otherwise the behavior is identical to QueueDequeueMany: // -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size n in the 0th dimension. // -// mg <- rho * mg_{t-1} + (1-rho) * grad -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) -// var <- var - mom +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. // // Arguments: -// var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. // -// Returns the created operation. -func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyCenteredRMSPropAttr) (o *tf.Operation) { +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyCenteredRMSProp", + Type: "QueueDequeueUpToV2", Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, + handle, n, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueUpToV2", err) + return + } + return components } -// Returns a list of tensors with the same shapes and contents as the input +// Computes the Cholesky decomposition of one or more square matrices. // -// tensors. +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. // -// This op can be used to override the gradient for complicated functions. For -// example, suppose y = f(x) and we wish to apply a custom function g for backprop -// such that dx = g(dy). In Python, +// The input has to be symmetric and positive definite. Only the lower-triangular +// part of the input will be used for this operation. The upper-triangular part +// will not be read. // -// ```python -// with tf.get_default_graph().gradient_override_map( -// {'IdentityN': 'OverrideGradientWithG'}): -// y, _ = identity_n([f(x), x]) +// The output is a tensor of the same shape as the input +// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. // -// @tf.RegisterGradient('OverrideGradientWithG') -// def ApplyG(op, dy, _): -// return [None, g(dy)] # Do not backprop to f(x). -// ``` -func IdentityN(scope *Scope, input []tf.Output) (output []tf.Output) { +// **Note**: The gradient computation on GPU is faster for large matrices but +// not for large batch dimensions when the submatrices are small. In this +// case it might be faster to use the CPU. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M, M]`. +func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IdentityN", + Type: "Cholesky", Input: []tf.Input{ - tf.OutputList(input), + input, }, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Writes contents to the file at input filename. Creates file and recursively +// +// creates directory if not existing. +// +// Arguments: +// filename: scalar. The name of the file to which we write the contents. +// contents: scalar. The content to be written to the output file. +// +// Returns the created operation. +func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("IdentityN", err) + opspec := tf.OpSpec{ + Type: "WriteFile", + Input: []tf.Input{ + filename, contents, + }, + } + return scope.AddOperation(opspec) +} + +// AllAttr is an optional argument to All. +type AllAttr func(optionalAttr) + +// AllKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AllKeepDims(value bool) AllAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical and" of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { + if scope.Err() != nil { return } - return output + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "All", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the gradient of the sigmoid of `x` wrt its input. +// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. // -// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and -// `dy` is the corresponding input gradient. -func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices, with the same constraints as the single matrix +// SelfAdjointEig. +// +// The result is a [..., M+1, M] matrix with [..., 0,:] containing the +// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M+1, M]`. +func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SigmoidGrad", + Type: "SelfAdjointEig", Input: []tf.Input{ - y, dy, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Convert one or more images from HSV to RGB. -// -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. -// -// See `rgb_to_hsv` for a description of the HSV encoding. +// Computes softplus gradients for a softplus operation. // // Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// gradients: The backpropagated gradients to the corresponding softplus operation. +// features: The features passed as input to the corresponding softplus operation. // -// Returns `images` converted to RGB. -func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { +// Returns The gradients: `gradients / (1 + exp(-features))`. +func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "HSVToRGB", + Type: "SoftplusGrad", Input: []tf.Input{ - images, + gradients, features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SampleDistortedBoundingBoxV2Attr is an optional argument to SampleDistortedBoundingBoxV2. -type SampleDistortedBoundingBoxV2Attr func(optionalAttr) +// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. +type SelfAdjointEigV2Attr func(optionalAttr) -// SampleDistortedBoundingBoxV2Seed sets the optional seed attribute to value. +// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. // -// value: If either `seed` or `seed2` are set to non-zero, the random number -// generator is seeded by the given `seed`. Otherwise, it is seeded by a random -// seed. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxV2Seed(value int64) SampleDistortedBoundingBoxV2Attr { +// value: If `True` then eigenvectors will be computed and returned in `v`. +// Otherwise, only the eigenvalues will be computed. +// If not specified, defaults to true +func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { return func(m optionalAttr) { - m["seed"] = value + m["compute_v"] = value } } -// SampleDistortedBoundingBoxV2Seed2 sets the optional seed2 attribute to value. +// Computes the eigen decomposition of one or more square self-adjoint matrices. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// SampleDistortedBoundingBoxV2AspectRatioRange sets the optional aspect_ratio_range attribute to value. +// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in +// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. // -// value: The cropped area of the image must have an aspect ratio = -// width / height within this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["aspect_ratio_range"] = value +// ```python +// # a is a tensor. +// # e is a tensor of eigenvalues. +// # v is a tensor of eigenvectors. +// e, v = self_adjoint_eig(a) +// e = self_adjoint_eig(a, compute_v=False) +// ``` +// +// Arguments: +// input: `Tensor` input of shape `[N, N]`. +// +// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. +func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SelfAdjointEigV2", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value. +// Adjust the saturation of one or more images. // -// value: The cropped area of the image must contain a fraction of the -// supplied image within in this range. -// If not specified, defaults to -func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { - return func(m optionalAttr) { - m["area_range"] = value +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A scale is then applied all the saturation +// values, and then remapped back to RGB colorspace. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// scale: A float scale to add to the saturation. +// +// Returns The hue-adjusted image or images. +func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustSaturation", + Input: []tf.Input{ + images, scale, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// SampleDistortedBoundingBoxV2MaxAttempts sets the optional max_attempts attribute to value. +// SvdAttr is an optional argument to Svd. +type SvdAttr func(optionalAttr) + +// SvdComputeUv sets the optional compute_uv attribute to value. // -// value: Number of attempts at generating a cropped region of the image -// of the specified constraints. After `max_attempts` failures, return the entire -// image. -// If not specified, defaults to 100 -func SampleDistortedBoundingBoxV2MaxAttempts(value int64) SampleDistortedBoundingBoxV2Attr { +// value: If true, left and right singular vectors will be +// computed and returned in `u` and `v`, respectively. +// If false, `u` and `v` are not set and should never referenced. +// If not specified, defaults to true +func SvdComputeUv(value bool) SvdAttr { return func(m optionalAttr) { - m["max_attempts"] = value + m["compute_uv"] = value } } -// SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// SvdFullMatrices sets the optional full_matrices attribute to value. // -// value: Controls behavior if no bounding boxes supplied. -// If true, assume an implicit bounding box covering the whole input. If false, -// raise an error. +// value: If true, compute full-sized `u` and `v`. If false +// (the default), compute only the leading `P` singular vectors. +// Ignored if `compute_uv` is `False`. // If not specified, defaults to false -func SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxV2Attr { +func SvdFullMatrices(value bool) SvdAttr { return func(m optionalAttr) { - m["use_image_if_no_bounding_boxes"] = value + m["full_matrices"] = value } } -// Generate a single randomly distorted bounding box for an image. -// -// Bounding box annotations are often supplied in addition to ground-truth labels -// in image recognition or object localization tasks. A common technique for -// training such a system is to randomly distort an image while preserving -// its content, i.e. *data augmentation*. This Op outputs a randomly distorted -// localization of an object, i.e. bounding box, given an `image_size`, -// `bounding_boxes` and a series of constraints. -// -// The output of this Op is a single bounding box that may be used to crop the -// original image. The output is returned as 3 tensors: `begin`, `size` and -// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the -// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize -// what the bounding box looks like. -// -// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. +// Computes the singular value decompositions of one or more matrices. // -// For example, +// Computes the SVD of each inner matrix in `input` such that +// `input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])` // // ```python -// # Generate a single distorted bounding box. -// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( -// tf.shape(image), -// bounding_boxes=bounding_boxes) -// -// # Draw the bounding box in an image summary. -// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), -// bbox_for_draw) -// tf.summary.image('images_with_box', image_with_box) -// -// # Employ the bounding box to distort the image. -// distorted_image = tf.slice(image, begin, size) +// # a is a tensor containing a batch of matrices. +// # s is a tensor of singular values for each matrix. +// # u is the tensor containing of left singular vectors for each matrix. +// # v is the tensor containing of right singular vectors for each matrix. +// s, u, v = svd(a) +// s, _, _ = svd(a, compute_uv=False) // ``` // -// Note that if no bounding box information is available, setting -// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit -// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is -// false and no bounding boxes are supplied, an error is raised. -// // Arguments: -// image_size: 1-D, containing `[height, width, channels]`. -// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes -// associated with the image. -// min_object_covered: The cropped area of the image must contain at least this -// fraction of any bounding box supplied. The value of this parameter should be -// non-negative. In the case of 0, the cropped area does not need to overlap -// any of the bounding boxes supplied. +// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. // -// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to -// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to -// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. -// Provide as input to `tf.image.draw_bounding_boxes`. -func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, min_object_covered tf.Output, optional ...SampleDistortedBoundingBoxV2Attr) (begin tf.Output, size tf.Output, bboxes tf.Output) { +// Returns Singular values. Shape is `[..., P]`.Left singular vectors. If `full_matrices` is `False` then shape is +// `[..., M, P]`; if `full_matrices` is `True` then shape is +// `[..., M, M]`. Undefined if `compute_uv` is `False`.Left singular vectors. If `full_matrices` is `False` then shape is +// `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`. +// Undefined if `compute_uv` is false. +func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.Output, v tf.Output) { if scope.Err() != nil { return } @@ -19493,88 +19229,50 @@ func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_b a(attrs) } opspec := tf.OpSpec{ - Type: "SampleDistortedBoundingBoxV2", + Type: "Svd", Input: []tf.Input{ - image_size, bounding_boxes, min_object_covered, + input, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. -type ExtractGlimpseAttr func(optionalAttr) - -// ExtractGlimpseCentered sets the optional centered attribute to value. -// -// value: indicates if the offset coordinates are centered relative to -// the image, in which case the (0, 0) offset is relative to the center -// of the input images. If false, the (0,0) offset corresponds to the -// upper left corner of the input images. -// If not specified, defaults to true -func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["centered"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// ExtractGlimpseNormalized sets the optional normalized attribute to value. -// -// value: indicates if the offset coordinates are normalized. -// If not specified, defaults to true -func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["normalized"] = value - } -} +// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. +type QueueEnqueueManyV2Attr func(optionalAttr) -// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. +// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. // -// value: indicates if the noise should be generated using a -// uniform distribution or a Gaussian distribution. -// If not specified, defaults to true -func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { +// value: If the queue is too full, this operation will block for up +// to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { return func(m optionalAttr) { - m["uniform_noise"] = value + m["timeout_ms"] = value } } -// Extracts a glimpse from the input tensor. -// -// Returns a set of windows called glimpses extracted at location -// `offsets` from the input tensor. If the windows only partially -// overlaps the inputs, the non overlapping areas will be filled with -// random noise. +// Enqueues zero or more tuples of one or more tensors in the given queue. // -// The result is a 4-D tensor of shape `[batch_size, glimpse_height, -// glimpse_width, channels]`. The channels and batch dimensions are the -// same as that of the input tensor. The height and width of the output -// windows are specified in the `size` parameter. +// This operation slices each component tensor along the 0th dimension to +// make multiple queue elements. All of the tuple components must have the +// same size in the 0th dimension. // -// The argument `normalized` and `centered` controls how the windows are built: +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. // -// * If the coordinates are normalized but not centered, 0.0 and 1.0 -// correspond to the minimum and maximum of each height and width -// dimension. -// * If the coordinates are both normalized and centered, they range from -// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper -// left corner, the lower right corner is located at (1.0, 1.0) and the -// center is at (0, 0). -// * If the coordinates are not normalized they are interpreted as -// numbers of pixels. +// N.B. If the queue is full, this operation will block until the given +// elements have been enqueued (or 'timeout_ms' elapses, if specified). // // Arguments: -// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. -// size: A 1-D tensor of 2 elements containing the size of the glimpses -// to extract. The glimpse height must be specified first, following -// by the glimpse width. -// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing -// the y, x locations of the center of each window. +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should +// be taken. // -// Returns A tensor representing the glimpses `[batch_size, -// glimpse_height, glimpse_width, channels]`. -func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { +// Returns the created operation. +func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -19583,146 +19281,123 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou a(attrs) } opspec := tf.OpSpec{ - Type: "ExtractGlimpse", + Type: "QueueEnqueueManyV2", Input: []tf.Input{ - input, size, offsets, + handle, tf.OutputList(components), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// A container for an iterator resource. +// Computes the product along segments of a tensor. // -// Returns A handle to the iterator that can be passed to a "MakeIterator" -// or "IteratorGetNext" op. -func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +// segments. +// +// Computes a tensor such that +// \\(output_i = \prod_j data_j\\) where the product is over `j` such +// that `segment_ids[j] == i`. +// +// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// +//
+// +//
+// +// Arguments: +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Iterator", - - Attrs: attrs, + Type: "SegmentProd", + Input: []tf.Input{ + data, segment_ids, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ShuffleDatasetAttr is an optional argument to ShuffleDataset. -type ShuffleDatasetAttr func(optionalAttr) - -// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. -// -// value: If true, each iterator over this dataset will be given -// a different pseudorandomly generated seed, based on a sequence seeded by the -// `seed` and `seed2` inputs. If false, each iterator will be given the same -// seed, and repeated iteration over this dataset will yield the exact same -// sequence of results. -// If not specified, defaults to true -func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { - return func(m optionalAttr) { - m["reshuffle_each_iteration"] = value - } -} - -// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. +// Converts one or more images from RGB to HSV. // -// Arguments: +// Outputs a tensor of the same shape as the `images` tensor, containing the HSV +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. +// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. // +// Arguments: +// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. // -func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { +// Returns `images` converted to HSV. +func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ShuffleDataset", + Type: "RGBToHSV", Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, + images, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// 3D fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 -// dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their 3D Fourier transform. +// Does nothing. Only useful as a placeholder for control edges. // -// @compatibility(numpy) -// Equivalent to np.fft.fftn with 3 dimensions. -// @end_compatibility -func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { +// Returns the created operation. +func NoOp(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FFT3D", - Input: []tf.Input{ - input, - }, + Type: "NoOp", } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. -type CropAndResizeGradBoxesAttr func(optionalAttr) +// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. +type MergeV2CheckpointsAttr func(optionalAttr) -// CropAndResizeGradBoxesMethod sets the optional method attribute to value. +// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. // -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { +// value: see above. +// If not specified, defaults to true +func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { return func(m optionalAttr) { - m["method"] = value + m["delete_old_dirs"] = value } } -// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. +// V2 format specific: merges the metadata files of sharded checkpoints. The +// +// result is one logical checkpoint, with one physical metadata file and renamed +// data files. +// +// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. +// +// If delete_old_dirs is true, attempts to delete recursively the dirname of each +// path in the input checkpoint_prefixes. This is useful when those paths are non +// user-facing temporary locations. // // Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// checkpoint_prefixes: prefixes of V2 checkpoints to merge. +// destination_prefix: scalar. The desired final prefix. Allowed to be the same +// as one of the checkpoint_prefixes. // -// Returns A 2-D tensor of shape `[num_boxes, 4]`. -func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { +// Returns the created operation. +func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -19731,291 +19406,524 @@ func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxe a(attrs) } opspec := tf.OpSpec{ - Type: "CropAndResizeGradBoxes", + Type: "MergeV2Checkpoints", Input: []tf.Input{ - grads, image, boxes, box_ind, + checkpoint_prefixes, destination_prefix, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Saves tensors in V2 checkpoint format. +// Saves input tensors slices to disk. +// +// This is like `Save` except that tensors can be listed in the saved file as being +// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the +// larger tensor and the slice that this tensor covers. `shapes_and_slices` must +// have as many elements as `tensor_names`. +// +// Elements of the `shapes_and_slices` input must either be: +// +// * The empty string, in which case the corresponding tensor is +// saved normally. +// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the +// `dimI` are the dimensions of the larger tensor and `slice-spec` +// specifies what part is covered by the tensor to save. +// +// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` +// where each `sliceI` is either: +// +// * The string `-` meaning that the slice covers all indices of this dimension +// * `start,length` where `start` and `length` are integers. In that +// case the slice covers `length` indices starting at `start`. // -// By default, saves the named tensors in full. If the caller wishes to save -// specific slices of full tensors, "shape_and_slices" should be non-empty strings -// and correspondingly well-formed. +// See also `Save`. // // Arguments: -// prefix: Must have a single element. The prefix of the V2 checkpoint to which we -// write the tensors. -// tensor_names: shape {N}. The names of the tensors to be saved. -// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. -// Empty strings indicate that they are non-partitioned tensors. -// tensors: `N` tensors to save. +// filename: Must have a single element. The name of the file to which we write the +// tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when +// saving the tensors. +// data: `N` tensors to save. // // Returns the created operation. -func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { +func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SaveV2", + Type: "SaveSlices", Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), + filename, tensor_names, shapes_and_slices, tf.OutputList(data), }, } return scope.AddOperation(opspec) } -// StatsAggregatorHandleAttr is an optional argument to StatsAggregatorHandle. -type StatsAggregatorHandleAttr func(optionalAttr) - -// StatsAggregatorHandleContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StatsAggregatorHandleContainer(value string) StatsAggregatorHandleAttr { - return func(m optionalAttr) { - m["container"] = value - } -} +// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. +type DenseToDenseSetOperationAttr func(optionalAttr) -// StatsAggregatorHandleSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StatsAggregatorHandleSharedName(value string) StatsAggregatorHandleAttr { +// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["validate_indices"] = value } } -// Creates a statistics manager resource. -func StatsAggregatorHandle(scope *Scope, optional ...StatsAggregatorHandleAttr) (handle tf.Output) { +// Applies set operation along last dimension of 2 `Tensor` inputs. +// +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"set_operation": set_operation} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StatsAggregatorHandle", - + Type: "DenseToDenseSetOperation", + Input: []tf.Input{ + set1, set2, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. +// Generate a sharded filename. The filename is printf formatted as // -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { +// %s-%05d-of-%05d, basename, shard, num_shards. +func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV2", + Type: "ShardedFilename", Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, + basename, shard, num_shards, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reshapes a tensor. +// BatchToSpace for N-D tensors of type T. // -// Given `tensor`, this operation returns a tensor that has the same values -// as `tensor` with shape `shape`. +// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape +// `block_shape + [batch]`, interleaves these blocks back into the grid defined by +// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as +// the input. The spatial dimensions of this intermediate result are then +// optionally cropped according to `crops` to produce the output. This is the +// reverse of SpaceToBatch. See below for a precise description. // -// If one component of `shape` is the special value -1, the size of that dimension -// is computed so that the total size remains constant. In particular, a `shape` -// of `[-1]` flattens into 1-D. At most one component of `shape` can be -1. +// Arguments: +// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, +// where spatial_shape has M dimensions. +// block_shape: 1-D with shape `[M]`, all values must be >= 1. +// crops: 2-D with shape `[M, 2]`, all values must be >= 0. +// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input +// dimension `i + 1`, which corresponds to spatial dimension `i`. It is +// required that +// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. // -// If `shape` is 1-D or higher, then the operation returns a tensor with shape -// `shape` filled with the values of `tensor`. In this case, the number of elements -// implied by `shape` must be the same as the number of elements in `tensor`. +// This operation is equivalent to the following steps: // -// For example: +// 1. Reshape `input` to `reshaped` of shape: +// [block_shape[0], ..., block_shape[M-1], +// batch / prod(block_shape), +// input_shape[1], ..., input_shape[N-1]] +// +// 2. Permute dimensions of `reshaped` to produce `permuted` of shape +// [batch / prod(block_shape), +// +// input_shape[1], block_shape[0], +// ..., +// input_shape[M], block_shape[M-1], +// +// input_shape[M+1], ..., input_shape[N-1]] +// +// 3. Reshape `permuted` to produce `reshaped_permuted` of shape +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0], +// ..., +// input_shape[M] * block_shape[M-1], +// +// input_shape[M+1], +// ..., +// input_shape[N-1]] +// +// 4. Crop the start and end of dimensions `[1, ..., M]` of +// `reshaped_permuted` according to `crops` to produce the output of shape: +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], +// ..., +// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], +// +// input_shape[M+1], ..., input_shape[N-1]] +// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: // // ``` -// # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] -// # tensor 't' has shape [9] -// reshape(t, [3, 3]) ==> [[1, 2, 3], -// [4, 5, 6], -// [7, 8, 9]] +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]] +// ``` +// +// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [2, 0]]`: +// +// ``` +// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], +// [[[0], [2], [4]]], [[[0], [10], [12]]], +// [[[0], [5], [7]]], [[[0], [13], [15]]], +// [[[0], [6], [8]]], [[[0], [14], [16]]]] +// ``` +// +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]]], +// [[[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BatchToSpaceND", + Input: []tf.Input{ + input, block_shape, crops, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UnpackAttr is an optional argument to Unpack. +type UnpackAttr func(optionalAttr) + +// UnpackAxis sets the optional axis attribute to value. // -// # tensor 't' is [[[1, 1], [2, 2]], -// # [[3, 3], [4, 4]]] -// # tensor 't' has shape [2, 2, 2] -// reshape(t, [2, 4]) ==> [[1, 1, 2, 2], -// [3, 3, 4, 4]] +// value: Dimension along which to unpack. Negative values wrap around, so the +// valid range is `[-R, R)`. +// If not specified, defaults to 0 +func UnpackAxis(value int64) UnpackAttr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. // -// # tensor 't' is [[[1, 1, 1], -// # [2, 2, 2]], -// # [[3, 3, 3], -// # [4, 4, 4]], -// # [[5, 5, 5], -// # [6, 6, 6]]] -// # tensor 't' has shape [3, 2, 3] -// # pass '[-1]' to flatten 't' -// reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] +// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. +// For example, given a tensor of shape `(A, B, C, D)`; // -// # -1 can also be used to infer the shape +// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` +// and each tensor in `output` will have shape `(B, C, D)`. (Note that the +// dimension unpacked along is gone, unlike `split`). // -// # -1 is inferred to be 9: -// reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], -// [4, 4, 4, 5, 5, 5, 6, 6, 6]] -// # -1 is inferred to be 2: -// reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], -// [4, 4, 4, 5, 5, 5, 6, 6, 6]] -// # -1 is inferred to be 3: -// reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1], -// [2, 2, 2], -// [3, 3, 3]], -// [[4, 4, 4], -// [5, 5, 5], -// [6, 6, 6]]] +// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` +// and each tensor in `output` will have shape `(A, C, D)`. +// Etc. // -// # tensor 't' is [7] -// # shape `[]` reshapes to a scalar -// reshape(t, []) ==> 7 -// ``` +// This is the opposite of `pack`. // // Arguments: +// value: 1-D or higher, with `axis` dimension size equal to `num`. // -// shape: Defines the shape of the output tensor. -func Reshape(scope *Scope, tensor tf.Output, shape tf.Output) (output tf.Output) { +// +// Returns The list of tensors unpacked from `value`. +func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num": num} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Reshape", + Type: "Unpack", Input: []tf.Input{ - tensor, shape, + value, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Unpack", err) + return + } + return output } -// Creates a dataset that splits a SparseTensor into elements row-wise. -func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { +// Increments variable pointed to by 'resource' until it reaches 'limit'. +// +// Arguments: +// resource: Should be from a scalar `Variable` node. +// limit: If incrementing ref would bring it above limit, instead generates an +// 'OutOfRange' error. +// +// +// Returns A copy of the input before increment. If nothing else modifies the +// input, the values produced will all be distinct. +func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"limit": limit, "T": T} opspec := tf.OpSpec{ - Type: "SparseTensorSliceDataset", + Type: "ResourceCountUpTo", Input: []tf.Input{ - indices, values, dense_shape, + resource, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x / y element-wise for real types. +// Delete the stack from its resource container. // -// If `x` and `y` are reals, this will return the floating-point division. +// Arguments: +// handle: The handle to a stack. // -// *NOTE*: `Div` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns the created operation. +func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RealDiv", + Type: "StackCloseV2", Input: []tf.Input{ - x, y, + handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Creates a dataset that concatenates `input_dataset` with `another_dataset`. -func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Generate a glob pattern matching all sharded file names. +func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ConcatenateDataset", + Type: "ShardedFilespec", Input: []tf.Input{ - input_dataset, another_dataset, + basename, num_shards, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Adds a value to the current value of a variable. +// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. +type TextLineReaderV2Attr func(optionalAttr) + +// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. // -// Any ReadVariableOp which depends directly or indirectly on this assign is -// guaranteed to see the incremented value or a subsequent newer one. +// value: Number of lines to skip from the beginning of every file. +// If not specified, defaults to 0 +func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["skip_header_lines"] = value + } +} + +// TextLineReaderV2Container sets the optional container attribute to value. // -// Outputs the incremented value, which can be used to totally order the -// increments to this variable. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func TextLineReaderV2Container(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TextLineReaderV2SharedName sets the optional shared_name attribute to value. // -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the lines of a file delimited by '\n'. // -// Returns the created operation. -func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { +// Returns The handle to reference the Reader. +func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "AssignAddVariableOp", - Input: []tf.Input{ - resource, value, - }, + Type: "TextLineReaderV2", + + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Records the latency of producing `input_dataset` elements in a StatsAggregator. -func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. +type LoadAndRemapMatrixAttr func(optionalAttr) + +// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. +// +// value: The maximum number of rows to load from the checkpoint at +// once. If less than or equal to 0, the entire matrix will be loaded into +// memory. Setting this arg trades increased disk reads for lower memory usage. +// If not specified, defaults to -1 +func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { + return func(m optionalAttr) { + m["max_rows_in_memory"] = value + } +} + +// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint +// +// at `ckpt_path` and potentially reorders its rows and columns using the +// specified remappings. +// +// Most users should use one of the wrapper initializers (such as +// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this +// function directly. +// +// The remappings are 1-D tensors with the following properties: +// +// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output +// matrix will be initialized from the row corresponding to index +// `row_remapping[i]` in the old `Tensor` from the checkpoint. +// * `col_remapping` must have either 0 entries (indicating that no column +// reordering is needed) or `num_cols` entries. If specified, column `j` of the +// output matrix will be initialized from the column corresponding to index +// `col_remapping[j]` in the old `Tensor` from the checkpoint. +// * A value of -1 in either of the remappings signifies a "missing" entry. In that +// case, values from the `initializing_values` tensor will be used to fill that +// missing row or column. If `row_remapping` has `r` missing entries and +// `col_remapping` has `c` missing entries, then the following condition must be +// true: +// +// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` +// +// The remapping tensors can be generated using the GenerateVocabRemapping op. +// +// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], +// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing +// the value from row i, column j of the old tensor in the checkpoint, the output +// matrix will look like the following: +// +// [[w(1, 0), w(1, 2), 0.5], +// [w(0, 0), w(0, 2), -0.5], +// [0.25, -0.25, 42]] +// +// Arguments: +// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from +// which the old matrix `Tensor` will be loaded. +// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. +// row_remapping: An int `Tensor` of row remappings (generally created by +// `generate_vocab_remapping`). Even if no row remapping is needed, this must +// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted +// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). +// col_remapping: An int `Tensor` of column remappings (generally created by +// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping +// is to be done (e.g. column ordering is the same). +// initializing_values: A float `Tensor` containing values to fill in for cells +// in the output matrix that are not loaded from the checkpoint. Length must be +// exactly the same as the number of missing / new cells. +// num_rows: Number of rows (length of the 1st dimension) in the output matrix. +// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +// +// Returns Output matrix containing existing values loaded from the +// checkpoint, and with any missing values filled in from initializing_values. +func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LatencyStatsDataset", + Type: "LoadAndRemapMatrix", Input: []tf.Input{ - input_dataset, tag, + ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, }, Attrs: attrs, } @@ -20023,80 +19931,94 @@ func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, o return op.Output(0) } -// Convert JSON-encoded Example records to binary protocol buffer strings. +// TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. +type TFRecordReaderV2Attr func(optionalAttr) + +// TFRecordReaderV2Container sets the optional container attribute to value. // -// This op translates a tensor containing Example records, encoded using -// the [standard JSON -// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), -// into a tensor containing the same records encoded as binary protocol -// buffers. The resulting tensor can then be fed to any of the other -// Example-parsing ops. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func TFRecordReaderV2Container(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TFRecordReaderV2SharedName sets the optional shared_name attribute to value. // -// Arguments: -// json_examples: Each string is a JSON object serialized according to the JSON -// mapping of the Example proto. +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func TFRecordReaderV2SharedName(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// TFRecordReaderV2CompressionType sets the optional compression_type attribute to value. +// If not specified, defaults to "" +func TFRecordReaderV2CompressionType(value string) TFRecordReaderV2Attr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// A Reader that outputs the records from a TensorFlow Records file. // -// Returns Each string is a binary Example protocol buffer corresponding -// to the respective element of `json_examples`. -func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { +// Returns The handle to reference the Reader. +func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DecodeJSONExample", - Input: []tf.Input{ - json_examples, - }, + Type: "TFRecordReaderV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. -// -// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the -// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each -// input channel is processed independently of the others with its own structuring -// function. The `output` tensor has shape -// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output -// tensor depend on the `padding` algorithm. We currently only support the default -// "NHWC" `data_format`. -// -// In detail, the grayscale morphological 2-D dilation is the max-sum correlation -// (for consistency with `conv2d`, we use unmirrored filters): -// -// output[b, y, x, c] = -// max_{dy, dx} input[b, -// strides[1] * y + rates[1] * dy, -// strides[2] * x + rates[2] * dx, -// c] + -// filter[dy, dx, c] -// -// Max-pooling is a special case when the filter has size equal to the pooling -// kernel size and contains all zeros. -// -// Note on duality: The dilation of `input` by the `filter` is equal to the -// negation of the erosion of `-input` by the reflected `filter`. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// strides: The stride of the sliding window for each dimension of the input -// tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: The input stride for atrous morphological dilation. Must be: -// `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. +type QuantizeAndDequantizeV3Attr func(optionalAttr) + +// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// Quantizes then dequantizes a tensor. // -// Returns 4-D with shape `[batch, out_height, out_width, depth]`. -func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { +// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a +// tensor, so its value can change during training. +func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Dilation2D", + Type: "QuantizeAndDequantizeV3", Input: []tf.Input{ - input, filter, + input, input_min, input_max, num_bits, }, Attrs: attrs, } @@ -20104,161 +20026,211 @@ func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64 return op.Output(0) } -// Converts the given variant tensor to an iterator and stores it in the given resource. +// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. +type IdentityReaderV2Attr func(optionalAttr) + +// IdentityReaderV2Container sets the optional container attribute to value. // -// Arguments: -// resource_handle: A handle to an iterator resource. -// serialized: A variant tensor storing the state of the iterator contained in the -// resource. +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func IdentityReaderV2Container(value string) IdentityReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// IdentityReaderV2SharedName sets the optional shared_name attribute to value. // -// Returns the created operation. -func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the queued work as both the key and value. +// +// To use, enqueue strings in a Queue. ReaderRead will take the front +// work string and output (work, work). +// +// Returns The handle to reference the Reader. +func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DeserializeIterator", - Input: []tf.Input{ - resource_handle, serialized, - }, + Type: "IdentityReaderV2", + + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2. -type TensorArrayConcatV2Attr func(optionalAttr) +// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. +type ResourceApplyGradientDescentAttr func(optionalAttr) -// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. -// If not specified, defaults to -func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr { +// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { return func(m optionalAttr) { - m["element_shape_except0"] = value + m["use_locking"] = value } } -// Deprecated. Use TensorArrayConcatV3 -func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) { +// Update '*var' by subtracting 'alpha' * 'delta' from it. +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// delta: The change. +// +// Returns the created operation. +func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayConcatV2", + Type: "ResourceApplyGradientDescent", Input: []tf.Input{ - handle, flow_in, + var_, alpha, delta, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// Creates a dataset that batches and pads `batch_size` elements from the input. +// Returns the next record (key, value pair) produced by a Reader. // -// Arguments: +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). // -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// padded_shapes: A list of int64 tensors representing the desired padded shapes -// of the corresponding output components. These shapes may be partially -// specified, using `-1` to indicate that a particular dimension should be -// padded to the maximum size of all batch elements. -// padding_values: A list of scalars containing the padding value to use for -// each of the outputs. +// Arguments: +// reader_handle: Handle to a Reader. +// queue_handle: Handle to a Queue, with string work items. // -func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { +// Returns A scalar.A scalar. +func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "PaddedBatchDataset", + Type: "ReaderReadV2", Input: []tf.Input{ - input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), + reader_handle, queue_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Creates a dataset that batches input elements into a SparseTensor. +// Returns up to `num_records` (key, value) pairs produced by a Reader. // -// Arguments: -// input_dataset: A handle to an input dataset. Must have a single component. -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// row_shape: A vector representing the dense shape of each row in the produced -// SparseTensor. The shape may be partially specified, using `-1` to indicate -// that a particular dimension should use the maximum size of all batch elements. +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// It may return less than `num_records` even before the last batch. // +// Arguments: +// reader_handle: Handle to a `Reader`. +// queue_handle: Handle to a `Queue`, with string work items. +// num_records: number of records to read from `Reader`. // -func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns A 1-D tensor.A 1-D tensor. +func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output, num_records tf.Output) (keys tf.Output, values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "DenseToSparseBatchDataset", + Type: "ReaderReadUpToV2", Input: []tf.Input{ - input_dataset, batch_size, row_shape, + reader_handle, queue_handle, num_records, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Deprecated. Use TensorArrayGradV3 +// Restore a Reader to its initial clean state. // -// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 -func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { +// Arguments: +// reader_handle: Handle to a Reader. +// +// Returns the created operation. +func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} opspec := tf.OpSpec{ - Type: "TensorArrayGradV2", + Type: "ReaderResetV2", Input: []tf.Input{ - handle, flow_in, + reader_handle, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. -type ResourceSparseApplyAdadeltaAttr func(optionalAttr) +// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. +type ResourceApplyAdamAttr func(optionalAttr) -// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. // If not specified, defaults to false -func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { +func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// var: Should be from a Variable(). +// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. // -// Arguments: +// value: If `True`, uses the nesterov update. +// If not specified, defaults to false +func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the Adam algorithm. // -// accum: Should be from a Variable(). -// accum_update: : Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. +// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t +// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t +// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. // // Returns the created operation. -func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { +func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -20267,181 +20239,116 @@ func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdadelta", + Type: "ResourceApplyAdam", Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, indices, + var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Identity op for gradient debugging. +// Store the input tensor in the state of the current session. // -// This op is hidden from public in Python. It is used by TensorFlow Debugger to -// register gradient tensors for gradient debugging. -// This op operates on non-reference-type tensors. -func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { +// Arguments: +// value: The tensor to be stored. +// +// Returns The handle for the tensor stored in the session state, represented +// as a ResourceHandle object. +func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DebugGradientIdentity", + Type: "GetSessionHandleV2", Input: []tf.Input{ - input, + value, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Return substrings from `Tensor` of strings. -// -// For each string in the input `Tensor`, creates a substring starting at index -// `pos` with a total length of `len`. -// -// If `len` defines a substring that would extend beyond the length of the input -// string, then as many characters as possible are used. -// -// If `pos` is negative or specifies a character index larger than any of the input -// strings, then an `InvalidArgumentError` is thrown. -// -// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on -// Op creation. -// -// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about -// broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -// -// --- -// -// Examples -// -// Using scalar `pos` and `len`: -// -// ```python -// input = [b'Hello', b'World'] -// position = 1 -// length = 3 -// -// output = [b'ell', b'orl'] -// ``` -// -// Using `pos` and `len` with same shape as `input`: -// -// ```python -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen']] -// position = [[1, 2, 3], -// [1, 2, 3], -// [1, 2, 3]] -// length = [[2, 3, 4], -// [4, 3, 2], -// [5, 5, 5]] -// -// output = [[b'en', b'eve', b'lve'], -// [b'hirt', b'urt', b'te'], -// [b'ixtee', b'vente', b'hteen']] -// ``` -// -// Broadcasting `pos` and `len` onto `input`: -// -// ``` -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen'], -// [b'nineteen', b'twenty', b'twentyone']] -// position = [1, 2, 3] -// length = [1, 2, 3] -// -// output = [[b'e', b'ev', b'lve'], -// [b'h', b'ur', b'tee'], -// [b'i', b've', b'hte'], -// [b'i', b'en', b'nty']] -// ``` -// -// Broadcasting `input` onto `pos` and `len`: -// -// ``` -// input = b'thirteen' -// position = [1, 5, 7] -// length = [3, 2, 1] +// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. +type ResizeBicubicGradAttr func(optionalAttr) + +// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. // -// output = [b'hir', b'ee', b'n'] -// ``` +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of bicubic interpolation. // // Arguments: -// input: Tensor of strings -// pos: Scalar defining the position of first character in each substring -// len: Scalar defining the number of characters to include in each substring +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. // -// Returns Tensor of substrings -func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output) (output tf.Output) { +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Substr", + Type: "ResizeBicubicGrad", Input: []tf.Input{ - input, pos, len, + grads, original_image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a Dataset that returns pseudorandom numbers. -// -// Arguments: -// seed: A scalar seed for the random number generator. If either seed or -// seed2 is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// +// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. +type ResizeNearestNeighborAttr func(optionalAttr) + +// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. // -func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "RandomDataset", - Input: []tf.Input{ - seed, seed2, - }, - Attrs: attrs, +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["align_corners"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Creates a dataset that shuffles and repeats elements from `input_dataset` -// -// pseudorandomly. +// Resize `images` to `size` using nearest neighbor interpolation. // // Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// count: A scalar representing the number of times the underlying dataset -// should be repeated. The default is `-1`, which results in infinite repetition. -// -// -func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ShuffleAndRepeatDataset", + Type: "ResizeNearestNeighbor", Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, count, + images, size, }, Attrs: attrs, } @@ -20449,28 +20356,42 @@ func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size return op.Output(0) } -// Creates a dataset that caches elements from `input_dataset`. +// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. +type ResizeNearestNeighborGradAttr func(optionalAttr) + +// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. // -// A CacheDataset will iterate over the input_dataset, and store tensors. If the -// cache already exists, the cache will be used. If the cache is inappropriate -// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error -// will the returned when used. +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of nearest neighbor interpolation. // // Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The +// original input size. // -// filename: A path on the filesystem where we should cache the dataset. Note: this -// will be a directory. -// -// -func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients +// with respect to the input image. +func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "CacheDataset", + Type: "ResizeNearestNeighborGrad", Input: []tf.Input{ - input_dataset, filename, + grads, size, }, Attrs: attrs, } @@ -20478,552 +20399,718 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out return op.Output(0) } -// PlaceholderAttr is an optional argument to Placeholder. -type PlaceholderAttr func(optionalAttr) +// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. +type ExtractJpegShapeAttr func(optionalAttr) -// PlaceholderShape sets the optional shape attribute to value. +// ExtractJpegShapeOutputType sets the optional output_type attribute to value. // -// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the -// shape is unconstrained. -// If not specified, defaults to -func PlaceholderShape(value tf.Shape) PlaceholderAttr { +// value: (Optional) The output type of the operation (int32 or int64). +// Defaults to int32. +// If not specified, defaults to DT_INT32 +func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { return func(m optionalAttr) { - m["shape"] = value + m["output_type"] = value } } -// A placeholder op for a value that will be fed into the computation. +// Extract the shape information of a JPEG-encoded image. // -// N.B. This operation will fail with an error if it is executed. It is -// intended as a way to represent a value that will always be fed, and to -// provide attrs that enable the fed value to be checked at runtime. +// This op only parses the image header, so it is much faster than DecodeJpeg. // // Arguments: -// dtype: The type of elements in the tensor. +// contents: 0-D. The JPEG-encoded image. // -// Returns A placeholder tensor that must be replaced using the feed mechanism. -func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { +// Returns 1-D. The image shape with format [height, width, channels]. +func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Placeholder", - + Type: "ExtractJpegShape", + Input: []tf.Input{ + contents, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that executes a SQL query and emits rows of the result set. +// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. +type PaddingFIFOQueueV2Attr func(optionalAttr) + +// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. // -// Arguments: -// driver_name: The database type. Currently, the only supported type is 'sqlite'. -// data_source_name: A connection string to connect to the database. -// query: A SQL query to execute. +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. +// Shapes of fixed rank but variable size are allowed by setting +// any shape dimension to -1. In this case, the inputs' shape may vary along +// the given dimension, and DequeueMany will pad the given dimension with +// zeros up to the maximum shape of all elements in the given batch. +// If the length of this attr is 0, different queue elements may have +// different ranks and shapes, but only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. // +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// PaddingFIFOQueueV2Container sets the optional container attribute to value. // -func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SqlDataset", - Input: []tf.Input{ - driver_name, data_source_name, query, - }, - Attrs: attrs, +} + +// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Creates a dataset that emits the records from one or more binary files. +// A queue that produces elements in first-in first-out order. +// +// Variable-size shapes are allowed by setting the corresponding shape dimensions +// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum +// size of any given element in the minibatch. See below for details. // // Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// header_bytes: A scalar representing the number of bytes to skip at the -// beginning of a file. -// record_bytes: A scalar representing the number of bytes in each record. -// footer_bytes: A scalar representing the number of bytes to skip at the end -// of a file. -// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. -func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "FixedLengthRecordDataset", - Input: []tf.Input{ - filenames, header_bytes, record_bytes, footer_bytes, buffer_size, - }, + Type: "PaddingFIFOQueueV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Slice a `SparseTensor` based on the `start` and `size`. +// DecodePngAttr is an optional argument to DecodePng. +type DecodePngAttr func(optionalAttr) + +// DecodePngChannels sets the optional channels attribute to value. // -// For example, if the input is +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodePngChannels(value int64) DecodePngAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodePngDtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_UINT8 +func DecodePngDtype(value tf.DataType) DecodePngAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Decode a PNG-encoded image to a uint8 or uint16 tensor. // -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] +// The attr `channels` indicates the desired number of color channels for the +// decoded image. // -// Graphically the output tensors are: +// Accepted values are: // -// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] -// [ a ] -// [b c ] +// * 0: Use the number of channels in the PNG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// * 4: output an RGBA image. // -// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] -// [ d e ] -// [ ] +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. // -// Arguments: -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// start: 1-D. tensor represents the start of the slice. -// size: 1-D. tensor represents the size of the slice. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. +// This op also supports decoding JPEGs and non-animated GIFs since the interface +// is the same, though it is cleaner to use `tf.image.decode_image`. // -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Arguments: +// contents: 0-D. The PNG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSlice", + Type: "DecodePng", Input: []tf.Input{ - indices, values, shape, start, size, + contents, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Concatenates quantized tensors along one dimension. +// Decode the first frame of a GIF-encoded image to a uint8 tensor. +// +// GIF with frame or transparency compression are not supported +// convert animated GIF from compressed to uncompressed by: +// +// convert $src.gif -coalesce $dst.gif +// +// This op also supports decoding JPEGs and PNGs, though it is cleaner to use +// `tf.image.decode_image`. // // Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// input_mins: The minimum scalar values for each of the input tensors. -// input_maxes: The maximum scalar values for each of the input tensors. +// contents: 0-D. The GIF-encoded image. // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QuantizedConcat", + Type: "DecodeGif", Input: []tf.Input{ - concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + contents, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Gradients for batch normalization. +// ResourceApplyCenteredRMSPropAttr is an optional argument to ResourceApplyCenteredRMSProp. +type ResourceApplyCenteredRMSPropAttr func(optionalAttr) + +// ResourceApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the centered RMSProp algorithm. // -// This op is deprecated. See `tf.nn.batch_normalization`. +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. +// +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// mean_grad = decay * mean_grad + (1-decay) * gradient +// +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// +// mg <- rho * mg_{t-1} + (1-rho) * grad +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) +// var <- var - mom // // Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. +// var_: Should be from a Variable(). +// mg: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. // -// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyCenteredRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", + Type: "ResourceApplyCenteredRMSProp", Input: []tf.Input{ - t, m, v, gamma, backprop, + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return scope.AddOperation(opspec) } -// Creates a dataset that emits the records from one or more TFRecord files. +// Returns a list of tensors with the same shapes and contents as the input // -// Arguments: -// filenames: A scalar or vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar representing the number of bytes to buffer. A value of -// 0 means no buffering will be performed. -func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { +// tensors. +// +// This op can be used to override the gradient for complicated functions. For +// example, suppose y = f(x) and we wish to apply a custom function g for backprop +// such that dx = g(dy). In Python, +// +// ```python +// with tf.get_default_graph().gradient_override_map( +// {'IdentityN': 'OverrideGradientWithG'}): +// y, _ = identity_n([f(x), x]) +// +// @tf.RegisterGradient('OverrideGradientWithG') +// def ApplyG(op, dy, _): +// return [None, g(dy)] # Do not backprop to f(x). +// ``` +func IdentityN(scope *Scope, input []tf.Output) (output []tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TFRecordDataset", + Type: "IdentityN", Input: []tf.Input{ - filenames, compression_type, buffer_size, + tf.OutputList(input), }, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// BatchToSpace for 4-D tensors of type T. -// -// This is a legacy version of the more general BatchToSpaceND. -// -// Rearranges (permutes) data from batch into blocks of spatial data, followed by -// cropping. This is the reverse transformation of SpaceToBatch. More specifically, -// this op outputs a copy of the input tensor where values from the `batch` -// dimension are moved in spatial blocks to the `height` and `width` dimensions, -// followed by cropping along the `height` and `width` dimensions. -// -// Arguments: -// input: 4-D tensor with shape -// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, -// depth]`. Note that the batch size of the input tensor must be divisible by -// `block_size * block_size`. -// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies -// how many elements to crop from the intermediate result across the spatial -// dimensions as follows: -// -// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] -// -// -// Returns 4-D with shape `[batch, height, width, depth]`, where: -// -// height = height_pad - crop_top - crop_bottom -// width = width_pad - crop_left - crop_right -// -// The attr `block_size` must be greater than one. It indicates the block size. -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 3]` and value: -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` -// -// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// The output tensor has shape `[1, 4, 4, 1]` and value: -// -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` -// -// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], -// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] -// ``` -// -// The output tensor has shape `[2, 2, 4, 1]` and value: -// -// ``` -// x = [[[[1], [3]], [[5], [7]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("IdentityN", err) + return + } + return output +} + +// Computes the gradient of the sigmoid of `x` wrt its input. +// +// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +// `dy` is the corresponding input gradient. +func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } opspec := tf.OpSpec{ - Type: "BatchToSpace", + Type: "SigmoidGrad", Input: []tf.Input{ - input, crops, + y, dy, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Makes a new iterator from the given `dataset` and stores it in `iterator`. +// Convert one or more images from HSV to RGB. // -// This operation may be executed multiple times. Each execution will reset the -// iterator in `iterator` to the first element of `dataset`. +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. // -// Returns the created operation. -func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { +// See `rgb_to_hsv` for a description of the HSV encoding. +// +// Arguments: +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// +// Returns `images` converted to RGB. +func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MakeIterator", + Type: "HSVToRGB", Input: []tf.Input{ - dataset, iterator, + images, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Adjust the contrast of one or more images. +// SampleDistortedBoundingBoxV2Attr is an optional argument to SampleDistortedBoundingBoxV2. +type SampleDistortedBoundingBoxV2Attr func(optionalAttr) + +// SampleDistortedBoundingBoxV2Seed sets the optional seed attribute to value. // -// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are -// interpreted as `[height, width, channels]`. The other dimensions only -// represent a collection of images, such as `[batch, height, width, channels].` +// value: If either `seed` or `seed2` are set to non-zero, the random number +// generator is seeded by the given `seed`. Otherwise, it is seeded by a random +// seed. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxV2Seed(value int64) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// SampleDistortedBoundingBoxV2Seed2 sets the optional seed2 attribute to value. // -// Contrast is adjusted independently for each channel of each image. +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// SampleDistortedBoundingBoxV2AspectRatioRange sets the optional aspect_ratio_range attribute to value. // -// For each channel, the Op first computes the mean of the image pixels in the -// channel and then adjusts each component of each pixel to -// `(x - mean) * contrast_factor + mean`. +// value: The cropped area of the image must have an aspect ratio = +// width / height within this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["aspect_ratio_range"] = value + } +} + +// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value. // -// Arguments: -// images: Images to adjust. At least 3-D. -// contrast_factor: A float multiplier for adjusting contrast. +// value: The cropped area of the image must contain a fraction of the +// supplied image within in this range. +// If not specified, defaults to +func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["area_range"] = value + } +} + +// SampleDistortedBoundingBoxV2MaxAttempts sets the optional max_attempts attribute to value. // -// Returns The contrast-adjusted image or images. -func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// value: Number of attempts at generating a cropped region of the image +// of the specified constraints. After `max_attempts` failures, return the entire +// image. +// If not specified, defaults to 100 +func SampleDistortedBoundingBoxV2MaxAttempts(value int64) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["max_attempts"] = value } - opspec := tf.OpSpec{ - Type: "AdjustContrastv2", - Input: []tf.Input{ - images, contrast_factor, - }, +} + +// SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value. +// +// value: Controls behavior if no bounding boxes supplied. +// If true, assume an implicit bounding box covering the whole input. If false, +// raise an error. +// If not specified, defaults to false +func SampleDistortedBoundingBoxV2UseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxV2Attr { + return func(m optionalAttr) { + m["use_image_if_no_bounding_boxes"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Gets the next output from the given iterator. -func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { +// Generate a single randomly distorted bounding box for an image. +// +// Bounding box annotations are often supplied in addition to ground-truth labels +// in image recognition or object localization tasks. A common technique for +// training such a system is to randomly distort an image while preserving +// its content, i.e. *data augmentation*. This Op outputs a randomly distorted +// localization of an object, i.e. bounding box, given an `image_size`, +// `bounding_boxes` and a series of constraints. +// +// The output of this Op is a single bounding box that may be used to crop the +// original image. The output is returned as 3 tensors: `begin`, `size` and +// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the +// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize +// what the bounding box looks like. +// +// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, +// +// ```python +// # Generate a single distorted bounding box. +// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( +// tf.shape(image), +// bounding_boxes=bounding_boxes) +// +// # Draw the bounding box in an image summary. +// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), +// bbox_for_draw) +// tf.summary.image('images_with_box', image_with_box) +// +// # Employ the bounding box to distort the image. +// distorted_image = tf.slice(image, begin, size) +// ``` +// +// Note that if no bounding box information is available, setting +// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit +// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is +// false and no bounding boxes are supplied, an error is raised. +// +// Arguments: +// image_size: 1-D, containing `[height, width, channels]`. +// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes +// associated with the image. +// min_object_covered: The cropped area of the image must contain at least this +// fraction of any bounding box supplied. The value of this parameter should be +// non-negative. In the case of 0, the cropped area does not need to overlap +// any of the bounding boxes supplied. +// +// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to +// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to +// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box. +// Provide as input to `tf.image.draw_bounding_boxes`. +func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, min_object_covered tf.Output, optional ...SampleDistortedBoundingBoxV2Attr) (begin tf.Output, size tf.Output, bboxes tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IteratorGetNext", + Type: "SampleDistortedBoundingBoxV2", Input: []tf.Input{ - iterator, + image_size, bounding_boxes, min_object_covered, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return + return op.Output(0), op.Output(1), op.Output(2) +} + +// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. +type ExtractGlimpseAttr func(optionalAttr) + +// ExtractGlimpseCentered sets the optional centered attribute to value. +// +// value: indicates if the offset coordinates are centered relative to +// the image, in which case the (0, 0) offset is relative to the center +// of the input images. If false, the (0,0) offset corresponds to the +// upper left corner of the input images. +// If not specified, defaults to true +func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["centered"] = value } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNext", err) - return +} + +// ExtractGlimpseNormalized sets the optional normalized attribute to value. +// +// value: indicates if the offset coordinates are normalized. +// If not specified, defaults to true +func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["normalized"] = value } - return components } -// Outputs the single element from the given dataset. +// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. // -// Arguments: -// dataset: A handle to a dataset that contains a single element. +// value: indicates if the noise should be generated using a +// uniform distribution or a Gaussian distribution. +// If not specified, defaults to true +func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["uniform_noise"] = value + } +} + +// Extracts a glimpse from the input tensor. +// +// Returns a set of windows called glimpses extracted at location +// `offsets` from the input tensor. If the windows only partially +// overlaps the inputs, the non overlapping areas will be filled with +// random noise. +// +// The result is a 4-D tensor of shape `[batch_size, glimpse_height, +// glimpse_width, channels]`. The channels and batch dimensions are the +// same as that of the input tensor. The height and width of the output +// windows are specified in the `size` parameter. // +// The argument `normalized` and `centered` controls how the windows are built: +// +// * If the coordinates are normalized but not centered, 0.0 and 1.0 +// correspond to the minimum and maximum of each height and width +// dimension. +// * If the coordinates are both normalized and centered, they range from +// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper +// left corner, the lower right corner is located at (1.0, 1.0) and the +// center is at (0, 0). +// * If the coordinates are not normalized they are interpreted as +// numbers of pixels. // +// Arguments: +// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. +// size: A 1-D tensor of 2 elements containing the size of the glimpses +// to extract. The glimpse height must be specified first, following +// by the glimpse width. +// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing +// the y, x locations of the center of each window. // -// Returns The components of the single element of `input`. -func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { +// Returns A tensor representing the glimpses `[batch_size, +// glimpse_height, glimpse_width, channels]`. +func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DatasetToSingleElement", + Type: "ExtractGlimpse", Input: []tf.Input{ - dataset, + input, size, offsets, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("DatasetToSingleElement", err) - return - } - return components + return op.Output(0) } -// Converts the given `resource_handle` representing an iterator to a string. -// -// Arguments: -// resource_handle: A handle to an iterator resource. +// A container for an iterator resource. // -// Returns A string representation of the given handle. -func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { +// Returns A handle to the iterator that can be passed to a "MakeIterator" +// or "IteratorGetNext" op. +func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IteratorToStringHandle", - Input: []tf.Input{ - resource_handle, - }, + Type: "Iterator", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ShapeNAttr is an optional argument to ShapeN. -type ShapeNAttr func(optionalAttr) +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) -// ShapeNOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeNOutType(value tf.DataType) ShapeNAttr { +// CropAndResizeGradImageMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { return func(m optionalAttr) { - m["out_type"] = value + m["method"] = value } } -// Returns shape of tensors. +// Computes the gradient of the crop_and_resize op wrt the input image tensor. // -// This operation returns N 1-D integer tensors representing shape of `input[i]s`. -func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. +// +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"T": T} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ShapeN", + Type: "CropAndResizeGradImage", Input: []tf.Input{ - tf.OutputList(input), + grads, boxes, box_ind, image_size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("ShapeN", err) - return - } - return output + return op.Output(0) } -// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. -type IteratorFromStringHandleAttr func(optionalAttr) +// ShuffleDatasetAttr is an optional argument to ShuffleDataset. +type ShuffleDatasetAttr func(optionalAttr) -// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. -// -// value: If specified, defines the type of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> +// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. // -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { +// value: If true, each iterator over this dataset will be given +// a different pseudorandomly generated seed, based on a sequence seeded by the +// `seed` and `seed2` inputs. If false, each iterator will be given the same +// seed, and repeated iteration over this dataset will yield the exact same +// sequence of results. +// If not specified, defaults to true +func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { return func(m optionalAttr) { - m["output_types"] = value + m["reshuffle_each_iteration"] = value } } -// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. +// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. // -// value: If specified, defines the shape of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> +// Arguments: // -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_shapes"] = value - } -} - -// Converts the given string representing a handle to an iterator to a resource. +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. // -// Arguments: -// string_handle: A string representation of the given handle. // -// Returns A handle to an iterator resource. -func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { +func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "IteratorFromStringHandle", + Type: "ShuffleDataset", Input: []tf.Input{ - string_handle, + input_dataset, buffer_size, seed, seed2, }, Attrs: attrs, } @@ -21031,177 +21118,145 @@ func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional .. return op.Output(0) } -// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. +// 3D fast Fourier transform. // -// This is the angle \( \theta \in [-\pi, \pi] \) such that -// \[ x = r \cos(\theta) \] -// and -// \[ y = r \sin(\theta) \] -// where \(r = \sqrt(x^2 + y^2) \). -func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { +// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 +// dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fftn with 3 dimensions. +// @end_compatibility +func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Atan2", + Type: "FFT3D", Input: []tf.Input{ - y, x, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Return a tensor with the same shape and contents as the input tensor or value. -func Identity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Identity", - Input: []tf.Input{ - input, - }, +// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. +type CropAndResizeGradBoxesAttr func(optionalAttr) + +// CropAndResizeGradBoxesMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { + return func(m optionalAttr) { + m["method"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Gather slices from `params` axis `axis` according to `indices`. -// -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `params.shape[:axis] + indices.shape + -// params.shape[axis + 1:]` where: -// -// ```python -// # Scalar indices (output is rank(params) - 1). -// output[a_0, ..., a_n, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices, b_0, ..., b_n] -// -// # Vector indices (output is rank(params)). -// output[a_0, ..., a_n, i, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] -// -// # Higher rank indices (output is rank(params) + rank(indices) - 1). -// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = -// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] -// ``` -// -//
-// -//
+// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. // // Arguments: -// params: The tensor from which to gather values. Must be at least rank -// `axis + 1`. -// indices: Index tensor. Must be in range `[0, params.shape[axis])`. -// axis: The axis in `params` to gather `indices` from. Defaults to the first -// dimension. Supports negative indexes. +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. // -// Returns Values from `params` gathered from indices given by `indices`, with -// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. -func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { +// Returns A 2-D tensor of shape `[num_boxes, 4]`. +func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "GatherV2", + Type: "CropAndResizeGradBoxes", Input: []tf.Input{ - params, indices, axis, + grads, image, boxes, box_ind, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Converts the given `resource_handle` representing an iterator to a variant tensor. +// Saves tensors in V2 checkpoint format. +// +// By default, saves the named tensors in full. If the caller wishes to save +// specific slices of full tensors, "shape_and_slices" should be non-empty strings +// and correspondingly well-formed. // // Arguments: -// resource_handle: A handle to an iterator resource. +// prefix: Must have a single element. The prefix of the V2 checkpoint to which we +// write the tensors. +// tensor_names: shape {N}. The names of the tensors to be saved. +// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. +// Empty strings indicate that they are non-partitioned tensors. +// tensors: `N` tensors to save. // -// Returns A variant tensor storing the state of the iterator contained in the -// resource. -func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { +// Returns the created operation. +func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SerializeIterator", + Type: "SaveV2", Input: []tf.Input{ - resource_handle, + prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), }, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. -type FIFOQueueV2Attr func(optionalAttr) - -// FIFOQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } + return scope.AddOperation(opspec) } -// FIFOQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} +// StatsAggregatorHandleAttr is an optional argument to StatsAggregatorHandle. +type StatsAggregatorHandleAttr func(optionalAttr) -// FIFOQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. +// StatsAggregatorHandleContainer sets the optional container attribute to value. // If not specified, defaults to "" -func FIFOQueueV2Container(value string) FIFOQueueV2Attr { +func StatsAggregatorHandleContainer(value string) StatsAggregatorHandleAttr { return func(m optionalAttr) { m["container"] = value } } -// FIFOQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. +// StatsAggregatorHandleSharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { +func StatsAggregatorHandleSharedName(value string) StatsAggregatorHandleAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// A queue that produces elements in first-in first-out order. -// -// Arguments: -// component_types: The type of each component in a value. -// -// Returns The handle to the queue. -func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { +// Creates a statistics manager resource. +func StatsAggregatorHandle(scope *Scope, optional ...StatsAggregatorHandleAttr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FIFOQueueV2", + Type: "StatsAggregatorHandle", Attrs: attrs, } @@ -21209,248 +21264,155 @@ func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQu return op.Output(0) } -// Produces a summary of any statistics recorded by the given statistics manager. -func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StatsAggregatorSummary", - Input: []tf.Input{ - iterator, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Compute the pairwise cross product. +// Greedily selects a subset of bounding boxes in descending order of score, // -// `a` and `b` must be the same shape; they can either be simple 3-element vectors, -// or any shape where the innermost dimension is 3. In the latter case, each pair -// of corresponding 3-element vectors is cross-multiplied independently. +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: -// a: A tensor containing 3-element vectors. -// b: Another tensor, of same type and shape as `a`. +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. // -// Returns Pairwise cross product of the vectors in `a` and `b`. -func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cross", + Type: "NonMaxSuppressionV2", Input: []tf.Input{ - a, b, + boxes, scores, max_output_size, iou_threshold, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Performs a padding as a preprocess during a convolution. +// Reshapes a tensor. // -// Similar to FusedResizeAndPadConv2d, this op allows for an optimized -// implementation where the spatial padding transformation stage is fused with the -// im2col lookup, but in this case without the bilinear filtering required for -// resizing. Fusing the padding prevents the need to write out the intermediate -// results as whole tensors, reducing memory pressure, and we can get some latency -// gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' -// order is used instead. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. +// Given `tensor`, this operation returns a tensor that has the same values +// as `tensor` with shape `shape`. // -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. +// If one component of `shape` is the special value -1, the size of that dimension +// is computed so that the total size remains constant. In particular, a `shape` +// of `[-1]` flattens into 1-D. At most one component of `shape` can be -1. // -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "FusedPadConv2D", - Input: []tf.Input{ - input, paddings, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. -type Conv2DBackpropInputAttr func(optionalAttr) - -// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} - -// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. +// If `shape` is 1-D or higher, then the operation returns a tensor with shape +// `shape` filled with the values of `tensor`. In this case, the number of elements +// implied by `shape` must be the same as the number of elements in `tensor`. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv2DBackpropInputDilations sets the optional dilations attribute to value. +// For example: +// +// ``` +// # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] +// # tensor 't' has shape [9] +// reshape(t, [3, 3]) ==> [[1, 2, 3], +// [4, 5, 6], +// [7, 8, 9]] +// +// # tensor 't' is [[[1, 1], [2, 2]], +// # [[3, 3], [4, 4]]] +// # tensor 't' has shape [2, 2, 2] +// reshape(t, [2, 4]) ==> [[1, 1, 2, 2], +// [3, 3, 4, 4]] +// +// # tensor 't' is [[[1, 1, 1], +// # [2, 2, 2]], +// # [[3, 3, 3], +// # [4, 4, 4]], +// # [[5, 5, 5], +// # [6, 6, 6]]] +// # tensor 't' has shape [3, 2, 3] +// # pass '[-1]' to flatten 't' +// reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] +// +// # -1 can also be used to infer the shape +// +// # -1 is inferred to be 9: +// reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], +// [4, 4, 4, 5, 5, 5, 6, 6, 6]] +// # -1 is inferred to be 2: +// reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], +// [4, 4, 4, 5, 5, 5, 6, 6, 6]] +// # -1 is inferred to be 3: +// reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1], +// [2, 2, 2], +// [3, 3, 3]], +// [[4, 4, 4], +// [5, 5, 5], +// [6, 6, 6]]] // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of convolution with respect to the input. +// # tensor 't' is [7] +// # shape `[]` reshapes to a scalar +// reshape(t, []) ==> 7 +// ``` // // Arguments: -// input_sizes: An integer vector representing the shape of `input`, -// where `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. -// padding: The type of padding algorithm to use. // -// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient -// w.r.t. the input of the convolution. -func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { +// shape: Defines the shape of the output tensor. +func Reshape(scope *Scope, tensor tf.Output, shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Conv2DBackpropInput", + Type: "Reshape", Input: []tf.Input{ - input_sizes, filter, out_backprop, + tensor, shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` -// -// For example, if each `indices[m]` is scalar or vector, we have -// -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] -// -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` -// -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is -// -// merged.shape = [max(indices)] + constant -// -// Values are merged in order, so if an index appears in both `indices[m][i]` and -// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the -// merged result. If you do not need this guarantee, ParallelDynamicStitch might -// perform better on some devices. -// -// For example: -// -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` -// -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: -// -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` -// -//
-// -//
-func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Creates a dataset that splits a SparseTensor into elements row-wise. +func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DynamicStitch", + Type: "SparseTensorSliceDataset", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + indices, values, dense_shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of (x == y) element-wise. +// Returns x / y element-wise for real types. // -// *NOTE*: `Equal` supports broadcasting. More about broadcasting +// If `x` and `y` are reals, this will return the floating-point division. +// +// *NOTE*: `Div` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Equal", + Type: "RealDiv", Input: []tf.Input{ x, y, }, @@ -21459,32 +21421,59 @@ func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. -type TensorArrayGatherV2Attr func(optionalAttr) - -// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { - return func(m optionalAttr) { - m["element_shape"] = value +// Creates a dataset that concatenates `input_dataset` with `another_dataset`. +func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ConcatenateDataset", + Input: []tf.Input{ + input_dataset, another_dataset, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Deprecated. Use TensorArrayGatherV3 +// Adds a value to the current value of a variable. // -// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 -func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { +// Any ReadVariableOp which depends directly or indirectly on this assign is +// guaranteed to see the incremented value or a subsequent newer one. +// +// Outputs the incremented value, which can be used to totally order the +// increments to this variable. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. +// +// Returns the created operation. +func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "AssignAddVariableOp", + Input: []tf.Input{ + resource, value, + }, + } + return scope.AddOperation(opspec) +} + +// Records the latency of producing `input_dataset` elements in a StatsAggregator. +func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArrayGatherV2", + Type: "LatencyStatsDataset", Input: []tf.Input{ - handle, indices, flow_in, + input_dataset, tag, }, Attrs: attrs, } @@ -21492,287 +21481,161 @@ func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow return op.Output(0) } -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` -// -// For example, if each `indices[m]` is scalar or vector, we have -// -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] +// Convert JSON-encoded Example records to binary protocol buffer strings. // -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` +// This op translates a tensor containing Example records, encoded using +// the [standard JSON +// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), +// into a tensor containing the same records encoded as binary protocol +// buffers. The resulting tensor can then be fed to any of the other +// Example-parsing ops. // -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is +// Arguments: +// json_examples: Each string is a JSON object serialized according to the JSON +// mapping of the Example proto. // -// merged.shape = [max(indices)] + constant +// Returns Each string is a binary Example protocol buffer corresponding +// to the respective element of `json_examples`. +func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeJSONExample", + Input: []tf.Input{ + json_examples, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. // -// Values may be merged in parallel, so if an index appears in both `indices[m][i]` -// and `indices[n][j]`, the result may be invalid. This differs from the normal -// DynamicStitch operator that defines the behavior in that case. +// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the +// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each +// input channel is processed independently of the others with its own structuring +// function. The `output` tensor has shape +// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output +// tensor depend on the `padding` algorithm. We currently only support the default +// "NHWC" `data_format`. // -// For example: +// In detail, the grayscale morphological 2-D dilation is the max-sum correlation +// (for consistency with `conv2d`, we use unmirrored filters): // -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` +// output[b, y, x, c] = +// max_{dy, dx} input[b, +// strides[1] * y + rates[1] * dy, +// strides[2] * x + rates[2] * dx, +// c] + +// filter[dy, dx, c] // -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: +// Max-pooling is a special case when the filter has size equal to the pooling +// kernel size and contains all zeros. // -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` +// Note on duality: The dilation of `input` by the `filter` is equal to the +// negation of the erosion of `-input` by the reflected `filter`. // -//
-// -//
-func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// strides: The stride of the sliding window for each dimension of the input +// tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: The input stride for atrous morphological dilation. Must be: +// `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape `[batch, out_height, out_width, depth]`. +func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "ParallelDynamicStitch", + Type: "Dilation2D", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + input, filter, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient for the inverse of `x` wrt its input. +// Converts the given variant tensor to an iterator and stores it in the given resource. // -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// Arguments: +// resource_handle: A handle to an iterator resource. +// serialized: A variant tensor storing the state of the iterator contained in the +// resource. +// +// Returns the created operation. +func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InvGrad", + Type: "DeserializeIterator", Input: []tf.Input{ - y, dy, + resource_handle, serialized, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// StridedSliceAttr is an optional argument to StridedSlice. -type StridedSliceAttr func(optionalAttr) - -// StridedSliceBeginMask sets the optional begin_mask attribute to value. -// -// value: a bitmask where a bit i being 1 means to ignore the begin -// value and instead use the largest interval possible. At runtime -// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or -// `[-1, n-1]` if `stride[i] < 0` -// If not specified, defaults to 0 -func StridedSliceBeginMask(value int64) StridedSliceAttr { - return func(m optionalAttr) { - m["begin_mask"] = value - } -} +// TensorArrayConcatV2Attr is an optional argument to TensorArrayConcatV2. +type TensorArrayConcatV2Attr func(optionalAttr) -// StridedSliceEndMask sets the optional end_mask attribute to value. -// -// value: analogous to `begin_mask` -// If not specified, defaults to 0 -func StridedSliceEndMask(value int64) StridedSliceAttr { +// TensorArrayConcatV2ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. +// If not specified, defaults to +func TensorArrayConcatV2ElementShapeExcept0(value tf.Shape) TensorArrayConcatV2Attr { return func(m optionalAttr) { - m["end_mask"] = value + m["element_shape_except0"] = value } } -// StridedSliceEllipsisMask sets the optional ellipsis_mask attribute to value. -// -// value: a bitmask where bit `i` being 1 means the `i`th -// position is actually an ellipsis. One bit at most can be 1. -// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)` -// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis -// implicitly creates as many range specifications as necessary to fully -// specify the sliced range for every dimension. For example for a 4-dimensional -// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`. -// If not specified, defaults to 0 -func StridedSliceEllipsisMask(value int64) StridedSliceAttr { - return func(m optionalAttr) { - m["ellipsis_mask"] = value +// Deprecated. Use TensorArrayConcatV3 +func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV2Attr) (value tf.Output, lengths tf.Output) { + if scope.Err() != nil { + return } -} - -// StridedSliceNewAxisMask sets the optional new_axis_mask attribute to value. -// -// value: a bitmask where bit `i` being 1 means the `i`th -// specification creates a new shape 1 dimension. For example -// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor. -// If not specified, defaults to 0 -func StridedSliceNewAxisMask(value int64) StridedSliceAttr { - return func(m optionalAttr) { - m["new_axis_mask"] = value + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) } -} - -// StridedSliceShrinkAxisMask sets the optional shrink_axis_mask attribute to value. -// -// value: a bitmask where bit `i` implies that the `i`th -// specification should shrink the dimensionality. begin and end -// must imply a slice of size 1 in the dimension. For example in -// python one might do `foo[:, 3, :]` which would result in -// `shrink_axis_mask` being 2. -// If not specified, defaults to 0 -func StridedSliceShrinkAxisMask(value int64) StridedSliceAttr { - return func(m optionalAttr) { - m["shrink_axis_mask"] = value + opspec := tf.OpSpec{ + Type: "TensorArrayConcatV2", + Input: []tf.Input{ + handle, flow_in, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Return a strided slice from `input`. -// -// Note, most python users will want to use the Python `Tensor.__getitem__` -// or `Variable.__getitem__` rather than this op directly. -// -// The goal of this op is to produce a new tensor with a subset of -// the elements from the `n` dimensional `input` tensor. The subset is chosen using -// a sequence of `m` sparse range specifications encoded into the arguments -// of this function. Note, in some cases -// `m` could be equal to `n`, but this need not be the case. Each -// range specification entry can be one of the following: -// -// - An ellipsis (...). Ellipses are used to imply zero or more -// dimensions of full-dimension selection and are produced using -// `ellipsis_mask`. For example, `foo[...]` is the identity slice. -// -// - A new axis. This is used to insert a new shape=1 dimension and is -// produced using `new_axis_mask`. For example, `foo[:, ...]` where -// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. -// -// -// - A range `begin:end:stride`. This is used to specify how much to choose from -// a given dimension. `stride` can be any integer but 0. `begin` is an integer -// which represents the index of the first value to select while `end` represents -// the index of the last value to select. The number of values selected in each -// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. -// `begin` and `end` can be negative where `-1` is the last element, `-2` is -// the second to last. `begin_mask` controls whether to replace the explicitly -// given `begin` with an implicit effective value of `0` if `stride > 0` and -// `-1` if `stride < 0`. `end_mask` is analogous but produces the number -// required to create the largest open interval. For example, given a shape -// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do -// not assume this is equivalent to `foo[0:-1]` which has an effective `begin` -// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the -// first dimension of a tensor while dropping the last two (in the original -// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. -// -// - A single index. This is used to keep only elements that have a given -// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a -// shape `(6,)` tensor. This is encoded in `begin` and `end` and -// `shrink_axis_mask`. -// -// Each conceptual range specification is encoded in the op's argument. This -// encoding is best understand by considering a non-trivial example. In -// particular, -// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as -// -// ``` -// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) -// end = [2, 4, x, x, -3, x] -// strides = [1, 1, x, x, -1, 1] -// begin_mask = 1<<4 | 1 << 5 = 48 -// end_mask = 1<<5 = 32 -// ellipsis_mask = 1<<3 = 8 -// new_axis_mask = 1<<2 4 -// shrink_axis_mask = 1<<0 -// ``` -// -// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of -// the slice becomes (2, 1, 5, 5, 2, 5). -// Let us walk step by step through each argument specification. -// -// 1. The first argument in the example slice is turned into `begin = 1` and -// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we -// also set the appropriate bit in `shrink_axis_mask`. -// -// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have -// zero bits contributed. -// -// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 -// dimension in the final shape. Dummy values are contributed to begin, -// end and stride, while the new_axis_mask bit is set. -// -// 4. `...` grab the full ranges from as many dimensions as needed to -// fully specify a slice for every dimension of the input shape. -// -// 5. `:-3:-1` shows the use of negative indices. A negative index `i` associated -// with a dimension that has shape `s` is converted to a positive index -// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion -// is done internally so begin, end and strides receive x, -3, and -1. -// The appropriate begin_mask bit is set to indicate the start range is the -// full range (ignoring the x). -// -// 6. `:` indicates that the entire contents of the corresponding dimension -// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides -// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and -// `end_mask` are also set. -// -// *Requirements*: -// `0 != strides[i] for i in [0, m)` -// `ellipsis_mask must be a power of two (only one ellipsis)` +// Creates a dataset that batches and pads `batch_size` elements from the input. // // Arguments: // -// begin: `begin[k]` specifies the offset into the `k`th range specification. -// The exact dimension this corresponds to will be determined by context. -// Out-of-bounds values will be silently clamped. If the `k`th bit of -// `begin_mask` then `begin[k]` is ignored and the full range of the -// appropriate dimension is used instead. Negative values causes indexing -// to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`. -// end: `end[i]` is like `begin` with the exception that `end_mask` is -// used to determine full ranges. -// strides: `strides[i]` specifies the increment in the `i`th specification -// after extracting a given element. Negative indices will reverse -// the original order. Out or range values are -// clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0` -func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, optional ...StridedSliceAttr) (output tf.Output) { +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. +// +func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "StridedSlice", + Type: "PaddedBatchDataset", Input: []tf.Input{ - input, begin, end, strides, + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), }, Attrs: attrs, } @@ -21780,187 +21643,263 @@ func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, return op.Output(0) } -// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. -type PriorityQueueV2Attr func(optionalAttr) - -// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. +// Creates a dataset that batches input elements into a SparseTensor. // -// value: The type of each component in a value. -// If not specified, defaults to <> +// Arguments: +// input_dataset: A handle to an input dataset. Must have a single component. +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// row_shape: A vector representing the dense shape of each row in the produced +// SparseTensor. The shape may be partially specified, using `-1` to indicate +// that a particular dimension should use the maximum size of all batch elements. // -// REQUIRES: len(value) >= 0 -func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["component_types"] = value - } -} - -// PriorityQueueV2Capacity sets the optional capacity attribute to value. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value +func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "DenseToSparseBatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, row_shape, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// PriorityQueueV2Container sets the optional container attribute to value. +// Deprecated. Use TensorArrayGradV3 // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func PriorityQueueV2Container(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value +// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 +func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"source": source} + opspec := tf.OpSpec{ + Type: "TensorArrayGradV2", + Input: []tf.Input{ + handle, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// PriorityQueueV2SharedName sets the optional shared_name attribute to value. +// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. +type ResourceSparseApplyAdadeltaAttr func(optionalAttr) + +// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["use_locking"] = value } } -// A queue that produces elements sorted by the first component value. -// -// Note that the PriorityQueue requires the first component of any element -// to be a scalar int64, in addition to the other elements declared by -// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue -// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra -// entry in their input (resp. output) lists. +// var: Should be from a Variable(). // // Arguments: -// shapes: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. // -// Returns The handle to the queue. -func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { +// accum: Should be from a Variable(). +// accum_update: : Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shapes": shapes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "PriorityQueueV2", - + Type: "ResourceSparseApplyAdadelta", + Input: []tf.Input{ + var_, accum, accum_update, lr, rho, epsilon, grad, indices, + }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// UnstageAttr is an optional argument to Unstage. -type UnstageAttr func(optionalAttr) - -// UnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Identity op for gradient debugging. // -// REQUIRES: value >= 0 -func UnstageCapacity(value int64) UnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value +// This op is hidden from public in Python. It is used by TensorFlow Debugger to +// register gradient tensors for gradient debugging. +// This op operates on non-reference-type tensors. +func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// UnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageMemoryLimit(value int64) UnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value + opspec := tf.OpSpec{ + Type: "DebugGradientIdentity", + Input: []tf.Input{ + input, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// UnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnstageContainer(value string) UnstageAttr { - return func(m optionalAttr) { - m["container"] = value +// Return substrings from `Tensor` of strings. +// +// For each string in the input `Tensor`, creates a substring starting at index +// `pos` with a total length of `len`. +// +// If `len` defines a substring that would extend beyond the length of the input +// string, then as many characters as possible are used. +// +// If `pos` is negative or specifies a character index larger than any of the input +// strings, then an `InvalidArgumentError` is thrown. +// +// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on +// Op creation. +// +// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about +// broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +// +// --- +// +// Examples +// +// Using scalar `pos` and `len`: +// +// ```python +// input = [b'Hello', b'World'] +// position = 1 +// length = 3 +// +// output = [b'ell', b'orl'] +// ``` +// +// Using `pos` and `len` with same shape as `input`: +// +// ```python +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen']] +// position = [[1, 2, 3], +// [1, 2, 3], +// [1, 2, 3]] +// length = [[2, 3, 4], +// [4, 3, 2], +// [5, 5, 5]] +// +// output = [[b'en', b'eve', b'lve'], +// [b'hirt', b'urt', b'te'], +// [b'ixtee', b'vente', b'hteen']] +// ``` +// +// Broadcasting `pos` and `len` onto `input`: +// +// ``` +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen'], +// [b'nineteen', b'twenty', b'twentyone']] +// position = [1, 2, 3] +// length = [1, 2, 3] +// +// output = [[b'e', b'ev', b'lve'], +// [b'h', b'ur', b'tee'], +// [b'i', b've', b'hte'], +// [b'i', b'en', b'nty']] +// ``` +// +// Broadcasting `input` onto `pos` and `len`: +// +// ``` +// input = b'thirteen' +// position = [1, 5, 7] +// length = [3, 2, 1] +// +// output = [b'hir', b'ee', b'n'] +// ``` +// +// Arguments: +// input: Tensor of strings +// pos: Scalar defining the position of first character in each substring +// len: Scalar defining the number of characters to include in each substring +// +// Returns Tensor of substrings +func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// UnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnstageSharedName(value string) UnstageAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "Substr", + Input: []tf.Input{ + input, pos, len, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Op is similar to a lightweight Dequeue. +// Creates a Dataset that returns pseudorandom numbers. // -// The basic functionality is similar to dequeue with many fewer -// capabilities and options. This Op is optimized for performance. -func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { +// Arguments: +// seed: A scalar seed for the random number generator. If either seed or +// seed2 is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// +// +func RandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Unstage", - + Type: "RandomDataset", + Input: []tf.Input{ + seed, seed2, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("Unstage", err) - return - } - return values -} - -// ArgMaxAttr is an optional argument to ArgMax. -type ArgMaxAttr func(optionalAttr) - -// ArgMaxOutputType sets the optional output_type attribute to value. -// If not specified, defaults to DT_INT64 -func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { - return func(m optionalAttr) { - m["output_type"] = value - } + return op.Output(0) } -// Returns the index with the largest value across dimensions of a tensor. +// Creates a dataset that shuffles and repeats elements from `input_dataset` // -// Note that in case of ties the identity of the return value is not guaranteed. +// pseudorandomly. // // Arguments: // -// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. -// Describes which dimension of the input Tensor to reduce across. For vectors, -// use dimension = 0. -func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// count: A scalar representing the number of times the underlying dataset +// should be repeated. The default is `-1`, which results in infinite repetition. +// +// +func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ArgMax", + Type: "ShuffleAndRepeatDataset", Input: []tf.Input{ - input, dimension, + input_dataset, buffer_size, seed, seed2, count, }, Attrs: attrs, } @@ -21968,226 +21907,244 @@ func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM return op.Output(0) } -// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign. -type ResourceStridedSliceAssignAttr func(optionalAttr) - -// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["begin_mask"] = value - } -} - -// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["end_mask"] = value - } -} - -// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["ellipsis_mask"] = value - } -} - -// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["new_axis_mask"] = value - } -} - -// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["shrink_axis_mask"] = value - } -} - -// Assign `value` to the sliced l-value reference of `ref`. +// Creates a dataset that caches elements from `input_dataset`. // -// The values of `value` are assigned to the positions in the variable -// `ref` that are selected by the slice parameters. The slice parameters -// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. +// A CacheDataset will iterate over the input_dataset, and store tensors. If the +// cache already exists, the cache will be used. If the cache is inappropriate +// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error +// will the returned when used. // -// NOTE this op currently does not support broadcasting and so `value`'s -// shape must be exactly the shape produced by the slice of `ref`. +// Arguments: // -// Returns the created operation. -func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) { +// filename: A path on the filesystem where we should cache the dataset. Note: this +// will be a directory. +// +// +func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ResourceStridedSliceAssign", + Type: "CacheDataset", Input: []tf.Input{ - ref, begin, end, strides, value, + input_dataset, filename, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. -type QueueEnqueueV2Attr func(optionalAttr) - -// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue is full, this operation will block for up to -// timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Enqueues a tuple of one or more tensors in the given queue. -// -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. -// -// N.B. If the queue is full, this operation will block until the given -// element has been enqueued (or 'timeout_ms' elapses, if specified). +// Creates a dataset that executes a SQL query and emits rows of the result set. // // Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should be taken. +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. // -// Returns the created operation. -func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { +// +func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "QueueEnqueueV2", + Type: "SqlDataset", Input: []tf.Input{ - handle, tf.OutputList(components), + driver_name, data_source_name, query, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. -type QueueDequeueManyV2Attr func(optionalAttr) - -// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// Creates a dataset that emits the records from one or more binary files. // -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value +// Arguments: +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// header_bytes: A scalar representing the number of bytes to skip at the +// beginning of a file. +// record_bytes: A scalar representing the number of bytes in each record. +// footer_bytes: A scalar representing the number of bytes to skip at the end +// of a file. +// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. +func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FixedLengthRecordDataset", + Input: []tf.Input{ + filenames, header_bytes, record_bytes, footer_bytes, buffer_size, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// If the queue is closed and there are fewer than `n` elements, then an -// OutOfRange error is returned. -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size `n` in the 0th dimension. +// Gradients for batch normalization. // -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// N.B. If the queue is empty, this operation will block until `n` elements -// have been dequeued (or 'timeout_ms' elapses, if specified). +// This op is deprecated. See `tf.nn.batch_normalization`. // // Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { +// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. +func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "QueueDequeueManyV2", + Type: "BatchNormWithGlobalNormalizationGrad", Input: []tf.Input{ - handle, n, + t, m, v, gamma, backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueManyV2", err) - return - } - return components + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// EncodeBase64Attr is an optional argument to EncodeBase64. -type EncodeBase64Attr func(optionalAttr) - -// EncodeBase64Pad sets the optional pad attribute to value. +// Creates a dataset that emits the records from one or more TFRecord files. // -// value: Bool whether padding is applied at the ends. -// If not specified, defaults to false -func EncodeBase64Pad(value bool) EncodeBase64Attr { - return func(m optionalAttr) { - m["pad"] = value +// Arguments: +// filenames: A scalar or vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// buffer_size: A scalar representing the number of bytes to buffer. A value of +// 0 means no buffering will be performed. +func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TFRecordDataset", + Input: []tf.Input{ + filenames, compression_type, buffer_size, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Encode strings into web-safe base64 format. +// BatchToSpace for 4-D tensors of type T. // -// Refer to the following article for more information on base64 format: -// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the -// end so that the encoded has length multiple of 4. See Padding section of the -// link above. +// This is a legacy version of the more general BatchToSpaceND. // -// Web-safe means that the encoder uses - and _ instead of + and /. +// Rearranges (permutes) data from batch into blocks of spatial data, followed by +// cropping. This is the reverse transformation of SpaceToBatch. More specifically, +// this op outputs a copy of the input tensor where values from the `batch` +// dimension are moved in spatial blocks to the `height` and `width` dimensions, +// followed by cropping along the `height` and `width` dimensions. // // Arguments: -// input: Strings to be encoded. +// input: 4-D tensor with shape +// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, +// depth]`. Note that the batch size of the input tensor must be divisible by +// `block_size * block_size`. +// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies +// how many elements to crop from the intermediate result across the spatial +// dimensions as follows: // -// Returns Input strings encoded in base64. -func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { +// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] +// +// +// Returns 4-D with shape `[batch, height, width, depth]`, where: +// +// height = height_pad - crop_top - crop_bottom +// width = width_pad - crop_left - crop_right +// +// The attr `block_size` must be greater than one. It indicates the block size. +// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: +// +// ``` +// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]] +// ``` +// +// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], +// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] +// ``` +// +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[5], [7]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"block_size": block_size} opspec := tf.OpSpec{ - Type: "EncodeBase64", + Type: "BatchToSpace", Input: []tf.Input{ - input, + input, crops, }, Attrs: attrs, } @@ -22195,172 +22152,171 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) ( return op.Output(0) } -// Deprecated. Use TensorArrayCloseV3 +// Makes a new iterator from the given `dataset` and stores it in `iterator`. // -// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// This operation may be executed multiple times. Each execution will reset the +// iterator in `iterator` to the first element of `dataset`. // // Returns the created operation. -func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { +func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayCloseV2", + Type: "MakeIterator", Input: []tf.Input{ - handle, + dataset, iterator, }, } return scope.AddOperation(opspec) } -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) - -// CropAndResizeGradImageMethod sets the optional method attribute to value. +// Adjust the contrast of one or more images. // -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are +// interpreted as `[height, width, channels]`. The other dimensions only +// represent a collection of images, such as `[batch, height, width, channels].` // -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. +// Contrast is adjusted independently for each channel of each image. // +// For each channel, the Op first computes the mean of the image pixels in the +// channel and then adjusts each component of each pixel to +// `(x - mean) * contrast_factor + mean`. // -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { +// Arguments: +// images: Images to adjust. At least 3-D. +// contrast_factor: A float multiplier for adjusting contrast. +// +// Returns The contrast-adjusted image or images. +func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", + Type: "AdjustContrastv2", Input: []tf.Input{ - grads, boxes, box_ind, image_size, + images, contrast_factor, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { +// Gets the next output from the given iterator. +func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ReadFile", + Type: "IteratorGetNext", Input: []tf.Input{ - filename, + iterator, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNext", err) + return + } + return components } -// Concatenates tensors along one dimension. +// Outputs the single element from the given dataset. // -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). +// Arguments: +// dataset: A handle to a dataset that contains a single element. // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { +// +// +// Returns The components of the single element of `input`. +func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ConcatV2", + Type: "DatasetToSingleElement", Input: []tf.Input{ - tf.OutputList(values), axis, + dataset, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("DatasetToSingleElement", err) + return + } + return components } -// Forwards the value of an available tensor from `inputs` to `output`. -// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. -// -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. +// Converts the given `resource_handle` representing an iterator to a string. // // Arguments: -// inputs: The input tensors, exactly one of which will become available. +// resource_handle: A handle to an iterator resource. // -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { +// Returns A string representation of the given handle. +func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Merge", + Type: "IteratorToStringHandle", Input: []tf.Input{ - tf.OutputList(inputs), + resource_handle, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) +// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. +type IteratorFromStringHandleAttr func(optionalAttr) -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. +// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. // -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. -// If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { +// value: If specified, defines the type of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value + m["output_types"] = value } } -// Closes the given queue. +// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. // -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. -// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. +// value: If specified, defines the shape of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { + return func(m optionalAttr) { + m["output_shapes"] = value + } +} + +// Converts the given string representing a handle to an iterator to a resource. // // Arguments: -// handle: The handle to a queue. +// string_handle: A string representation of the given handle. // -// Returns the created operation. -func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { +// Returns A handle to an iterator resource. +func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { if scope.Err() != nil { return } @@ -22369,743 +22325,747 @@ func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) a(attrs) } opspec := tf.OpSpec{ - Type: "QueueCloseV2", + Type: "IteratorFromStringHandle", Input: []tf.Input{ - handle, + string_handle, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes inverse hyperbolic tangent of x element-wise. -func Atanh(scope *Scope, x tf.Output) (y tf.Output) { +// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. +// +// This is the angle \( \theta \in [-\pi, \pi] \) such that +// \[ x = r \cos(\theta) \] +// and +// \[ y = r \sin(\theta) \] +// where \(r = \sqrt(x^2 + y^2) \). +func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Atanh", + Type: "Atan2", Input: []tf.Input{ - x, + y, x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns true if queue is closed. -// -// This operation returns true if the queue is closed and false if the queue -// is open. -// -// Arguments: -// handle: The handle to a queue. -func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { +// Return a tensor with the same shape and contents as the input tensor or value. +func Identity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QueueIsClosedV2", + Type: "Identity", Input: []tf.Input{ - handle, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the batched diagonal part of a batched tensor. -// -// This operation returns a tensor with the `diagonal` part -// of the batched `input`. The `diagonal` part is computed as follows: -// -// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: +// Gather slices from `params` axis `axis` according to `indices`. // -// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `params.shape[:axis] + indices.shape + +// params.shape[axis + 1:]` where: // -// The input must be at least a matrix. +// ```python +// # Scalar indices (output is rank(params) - 1). +// output[a_0, ..., a_n, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices, b_0, ..., b_n] // -// For example: +// # Vector indices (output is rank(params)). +// output[a_0, ..., a_n, i, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] // +// # Higher rank indices (output is rank(params) + rank(indices) - 1). +// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = +// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] // ``` -// # 'input' is [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// and input.shape = (2, 4, 4) -// -// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] // -// which has shape (2, 4) -// ``` +//
+// +//
// // Arguments: -// input: Rank `k` tensor where `k >= 2`. -// -// Returns The extracted diagonal(s) having shape -// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`. -func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiagPart", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the absolute value of a tensor. +// params: The tensor from which to gather values. Must be at least rank +// `axis + 1`. +// indices: Index tensor. Must be in range `[0, params.shape[axis])`. +// axis: The axis in `params` to gather `indices` from. Defaults to the first +// dimension. Supports negative indexes. // -// Given a tensor `x`, this operation returns a tensor containing the absolute -// value of each element in `x`. For example, if x is an input element and y is -// an output element, this operation computes \\(y = |x|\\). -func Abs(scope *Scope, x tf.Output) (y tf.Output) { +// Returns Values from `params` gathered from indices given by `indices`, with +// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. +func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Abs", + Type: "GatherV2", Input: []tf.Input{ - x, + params, indices, axis, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Flushes and closes the summary writer. -// -// Also removes it from the resource manager. To reopen, use another -// CreateSummaryFileWriter op. -// -// Arguments: -// writer: A handle to the summary writer resource. -// -// Returns the created operation. -func CloseSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CloseSummaryWriter", - Input: []tf.Input{ - writer, - }, - } - return scope.AddOperation(opspec) -} - -// StackV2Attr is an optional argument to StackV2. -type StackV2Attr func(optionalAttr) - -// StackV2StackName sets the optional stack_name attribute to value. -// -// value: Overrides the name used for the temporary stack resource. Default -// value is the name of the 'Stack' op (which is guaranteed unique). -// If not specified, defaults to "" -func StackV2StackName(value string) StackV2Attr { - return func(m optionalAttr) { - m["stack_name"] = value - } -} - -// A stack that produces elements in first-in last-out order. +// Converts the given `resource_handle` representing an iterator to a variant tensor. // // Arguments: -// max_size: The maximum size of the stack if non-negative. If negative, the stack -// size is unlimited. -// elem_type: The type of the elements on the stack. +// resource_handle: A handle to an iterator resource. // -// Returns The handle to the stack. -func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { +// Returns A variant tensor storing the state of the iterator contained in the +// resource. +func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StackV2", + Type: "SerializeIterator", Input: []tf.Input{ - max_size, + resource_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// OrderedMapStageAttr is an optional argument to OrderedMapStage. -type OrderedMapStageAttr func(optionalAttr) +// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. +type FIFOQueueV2Attr func(optionalAttr) -// OrderedMapStageCapacity sets the optional capacity attribute to value. +// FIFOQueueV2Shapes sets the optional shapes attribute to value. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> // -// REQUIRES: value >= 0 -func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { +// REQUIRES: len(value) >= 0 +func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["shapes"] = value } } -// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// FIFOQueueV2Capacity sets the optional capacity attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { return func(m optionalAttr) { - m["memory_limit"] = value + m["capacity"] = value } } -// OrderedMapStageContainer sets the optional container attribute to value. +// FIFOQueueV2Container sets the optional container attribute to value. // -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func OrderedMapStageContainer(value string) OrderedMapStageAttr { +func FIFOQueueV2Container(value string) FIFOQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// OrderedMapStageSharedName sets the optional shared_name attribute to value. +// FIFOQueueV2SharedName sets the optional shared_name attribute to value. // -// value: It is necessary to match this name to the matching Unstage Op. +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. // If not specified, defaults to "" -func OrderedMapStageSharedName(value string) OrderedMapStageAttr { +func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// Stage (key, values) in the underlying container which behaves like a ordered -// -// associative container. Elements are ordered by key. +// A queue that produces elements in first-in first-out order. // // Arguments: -// key: int64 -// -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. -// +// component_types: The type of each component in a value. // -// Returns the created operation. -func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { +// Returns The handle to the queue. +func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapStage", - Input: []tf.Input{ - key, indices, tf.OutputList(values), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StackPushV2Attr is an optional argument to StackPushV2. -type StackPushV2Attr func(optionalAttr) + Type: "FIFOQueueV2", -// StackPushV2SwapMemory sets the optional swap_memory attribute to value. -// -// value: Swap `elem` to CPU. Default to false. -// If not specified, defaults to false -func StackPushV2SwapMemory(value bool) StackPushV2Attr { - return func(m optionalAttr) { - m["swap_memory"] = value + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Push an element onto the stack. -// -// Arguments: -// handle: The handle to a stack. -// elem: The tensor to be pushed onto the stack. -// -// Returns The same tensor as the input 'elem'. -func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { +// Produces a summary of any statistics recorded by the given statistics manager. +func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StackPushV2", + Type: "StatsAggregatorSummary", Input: []tf.Input{ - handle, elem, + iterator, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. -type FusedBatchNormGradV2Attr func(optionalAttr) - -// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. +// Compute the pairwise cross product. // -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. +// `a` and `b` must be the same shape; they can either be simple 3-element vectors, +// or any shape where the innermost dimension is 3. In the latter case, each pair +// of corresponding 3-element vectors is cross-multiplied independently. // -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. +// Arguments: +// a: A tensor containing 3-element vectors. +// b: Another tensor, of same type and shape as `a`. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["is_training"] = value +// Returns Pairwise cross product of the vectors in `a` and `b`. +func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cross", + Input: []tf.Input{ + a, b, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Gradient for batch normalization. +// Performs a padding as a preprocess during a convolution. // -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// Similar to FusedResizeAndPadConv2d, this op allows for an optimized +// implementation where the spatial padding transformation stage is fused with the +// im2col lookup, but in this case without the bilinear filtering required for +// resizing. Fusing the padding prevents the need to write out the intermediate +// results as whole tensors, reducing memory pressure, and we can get some latency +// gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' +// order is used instead. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. // // Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. // -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. +func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "FusedBatchNormGradV2", + Type: "FusedPadConv2D", Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, + input, paddings, filter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } -// Creates a TensorArray for storing the gradients of values in the given handle. -// -// If the given TensorArray gradient already exists, returns a reference to it. -// -// Locks the size of the original TensorArray by disabling its dynamic size flag. -// -// **A note about the input flow_in:** -// -// The handle flow_in forces the execution of the gradient lookup to occur -// only after certain other operations have occurred. For example, when -// the forward TensorArray is dynamically sized, writes to this TensorArray -// may resize the object. The gradient TensorArray is statically sized based -// on the size of the forward TensorArray when this operation executes. -// Furthermore, the size of the forward TensorArray is frozen by this call. -// As a result, the flow is used to ensure that the call to generate the gradient -// TensorArray only happens after all writes are executed. -// -// In the case of dynamically sized TensorArrays, gradient computation should -// only be performed on read operations that have themselves been chained via -// flow to occur only after all writes have executed. That way the final size -// of the forward TensorArray is known when this operation is called. -// -// **A note about the source attribute:** -// -// TensorArray gradient calls use an accumulator TensorArray object. If -// multiple gradients are calculated and run in the same session, the multiple -// gradient nodes may accidentally flow through the same accumulator TensorArray. -// This double counts and generally breaks the TensorArray gradient flow. +// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. +type Conv2DBackpropInputAttr func(optionalAttr) + +// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. // -// The solution is to identify which gradient call this particular -// TensorArray gradient is being called in. This is performed by identifying -// a unique string (e.g. "gradients", "gradients_1", ...) from the input -// gradient Tensor's name. This string is used as a suffix when creating -// the TensorArray gradient object here (the attribute `source`). +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropInputDilations sets the optional dilations attribute to value. // -// The attribute `source` is added as a suffix to the forward TensorArray's -// name when performing the creation / lookup, so that each separate gradient -// calculation gets its own TensorArray accumulator. +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the input. // // Arguments: -// handle: The handle to the forward TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// source: The gradient source string, used to decide which gradient TensorArray -// to return. -func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { +// input_sizes: An integer vector representing the shape of `input`, +// where `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient +// w.r.t. the input of the convolution. +func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"source": source} + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayGradV3", + Type: "Conv2DBackpropInput", Input: []tf.Input{ - handle, flow_in, + input_sizes, filter, out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. +// Interleave the values from the `data` tensors into a single tensor. // -// Each comparison returns a boolean `true` (if `input_value > threshold`) -// or and `false` otherwise. +// Builds a merged tensor such that // -// This operation is useful for Locality-Sensitive-Hashing (LSH) and other -// algorithms that use hashing approximations of cosine and `L2` distances; -// codes can be generated from an input via: +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` +// +// For example, if each `indices[m]` is scalar or vector, we have // // ```python -// codebook_size = 50 -// codebook_bits = codebook_size * 32 -// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], -// dtype=x.dtype, -// initializer=tf.orthogonal_initializer()) -// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) -// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 -// # now codes has shape x.shape[:-1] + [codebook_size] +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] +// +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] // ``` // -// **NOTE**: Currently, the innermost dimension of the tensor must be divisible -// by 8. +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is // -// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is -// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +// merged.shape = [max(indices)] + constant // -// Arguments: -// input: Values to compare against `threshold` and bitpack. -// threshold: Threshold to compare against. +// Values are merged in order, so if an index appears in both `indices[m][i]` and +// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the +// merged result. If you do not need this guarantee, ParallelDynamicStitch might +// perform better on some devices. // -// Returns The bitpacked comparisons. -func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` +// +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// +//
+// +//
+func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "CompareAndBitpack", + Type: "DynamicStitch", Input: []tf.Input{ - input, threshold, + tf.OutputList(indices), tf.OutputList(data), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Push an element onto the tensor_array. -// -// Arguments: -// handle: The handle to a TensorArray. -// index: The position to write to inside the TensorArray. -// value: The tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// Returns the truth value of (x == y) element-wise. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// *NOTE*: `Equal` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayWriteV3", + Type: "Equal", Input: []tf.Input{ - handle, index, value, flow_in, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Scatter the data from the input value into specific TensorArray elements. -// -// `indices` must be a vector, its length must match the first dim of `value`. -// -// Arguments: -// handle: The handle to a TensorArray. -// indices: The locations at which to write the tensor elements. -// value: The concatenated tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. +type TensorArrayGatherV2Attr func(optionalAttr) + +// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Deprecated. Use TensorArrayGatherV3 // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 +func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayScatterV3", + Type: "TensorArrayGatherV2", Input: []tf.Input{ - handle, indices, value, flow_in, + handle, indices, flow_in, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. -type TensorArrayConcatV3Attr func(optionalAttr) - -// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. +// Interleave the values from the `data` tensors into a single tensor. // -// value: The expected shape of an element, if known, -// excluding the first dimension. Used to validate the shapes of -// TensorArray elements. If this shape is not fully specified, concatenating -// zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { - return func(m optionalAttr) { - m["element_shape_except0"] = value - } -} - -// Concat the elements from the TensorArray into value `value`. +// Builds a merged tensor such that // -// Takes `T` elements of shapes +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` // -// ``` -// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) -// ``` +// For example, if each `indices[m]` is scalar or vector, we have // -// and concatenates them into a Tensor of shape: +// ```python +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] // -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` // -// All elements must have the same shape (excepting the first dimension). +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is // -// Arguments: -// handle: The handle to a TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// merged.shape = [max(indices)] + constant +// +// Values may be merged in parallel, so if an index appears in both `indices[m][i]` +// and `indices[n][j]`, the result may be invalid. This differs from the normal +// DynamicStitch operator that defines the behavior in that case. +// +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` +// +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` // -// Returns All of the elements in the TensorArray, concatenated along the first -// axis.A vector of the row sizes of the original T elements in the -// value output. In the example above, this would be the values: -// `(n1, n2, ..., n(T-1))`. -func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { +//
+// +//
+func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "ParallelDynamicStitch", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(data), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient for the inverse of `x` wrt its input. +// +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return } opspec := tf.OpSpec{ - Type: "TensorArrayConcatV3", + Type: "InvGrad", Input: []tf.Input{ - handle, flow_in, + y, dy, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. -type ParameterizedTruncatedNormalAttr func(optionalAttr) +// StridedSliceAttr is an optional argument to StridedSlice. +type StridedSliceAttr func(optionalAttr) -// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. +// StridedSliceBeginMask sets the optional begin_mask attribute to value. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// value: a bitmask where a bit i being 1 means to ignore the begin +// value and instead use the largest interval possible. At runtime +// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or +// `[-1, n-1]` if `stride[i] < 0` // If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { +func StridedSliceBeginMask(value int64) StridedSliceAttr { return func(m optionalAttr) { - m["seed"] = value + m["begin_mask"] = value } } -// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. +// StridedSliceEndMask sets the optional end_mask attribute to value. // -// value: A second seed to avoid seed collision. +// value: analogous to `begin_mask` // If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { +func StridedSliceEndMask(value int64) StridedSliceAttr { return func(m optionalAttr) { - m["seed2"] = value + m["end_mask"] = value } } -// Outputs random values from a normal distribution. The parameters may each be a -// -// scalar which applies to the entire output, or a vector of length shape[0] which -// stores the parameters for each batch. -// -// Arguments: -// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. -// means: The mean parameter of each batch. -// stdevs: The standard deviation parameter of each batch. Must be greater than 0. -// minvals: The minimum cutoff. May be -infinity. -// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval -// for each batch. +// StridedSliceEllipsisMask sets the optional ellipsis_mask attribute to value. // -// Returns A matrix of shape num_batches x samples_per_batch, filled with random -// truncated normal values using the parameters for each row. -func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ParameterizedTruncatedNormal", - Input: []tf.Input{ - shape, means, stdevs, minvals, maxvals, - }, - Attrs: attrs, +// value: a bitmask where bit `i` being 1 means the `i`th +// position is actually an ellipsis. One bit at most can be 1. +// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)` +// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis +// implicitly creates as many range specifications as necessary to fully +// specify the sliced range for every dimension. For example for a 4-dimensional +// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`. +// If not specified, defaults to 0 +func StridedSliceEllipsisMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns a diagonal tensor with a given diagonal values. +// StridedSliceNewAxisMask sets the optional new_axis_mask attribute to value. // -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: +// value: a bitmask where bit `i` being 1 means the `i`th +// specification creates a new shape 1 dimension. For example +// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor. +// If not specified, defaults to 0 +func StridedSliceNewAxisMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["new_axis_mask"] = value + } +} + +// StridedSliceShrinkAxisMask sets the optional shrink_axis_mask attribute to value. // -// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of -// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: +// value: a bitmask where bit `i` implies that the `i`th +// specification should shrink the dimensionality. begin and end +// must imply a slice of size 1 in the dimension. For example in +// python one might do `foo[:, 3, :]` which would result in +// `shrink_axis_mask` being 2. +// If not specified, defaults to 0 +func StridedSliceShrinkAxisMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Return a strided slice from `input`. // -// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. +// Note, most python users will want to use the Python `Tensor.__getitem__` +// or `Variable.__getitem__` rather than this op directly. // -// For example: +// The goal of this op is to produce a new tensor with a subset of +// the elements from the `n` dimensional `input` tensor. The subset is chosen using +// a sequence of `m` sparse range specifications encoded into the arguments +// of this function. Note, in some cases +// `m` could be equal to `n`, but this need not be the case. Each +// range specification entry can be one of the following: // -// ``` -// # 'diagonal' is [1, 2, 3, 4] -// tf.diag(diagonal) ==> [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] -// ``` +// - An ellipsis (...). Ellipses are used to imply zero or more +// dimensions of full-dimension selection and are produced using +// `ellipsis_mask`. For example, `foo[...]` is the identity slice. // -// Arguments: -// diagonal: Rank k tensor where k is at most 1. -func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Diag", - Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Split the data from the input value into TensorArray elements. +// - A new axis. This is used to insert a new shape=1 dimension and is +// produced using `new_axis_mask`. For example, `foo[:, ...]` where +// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. // -// Assuming that `lengths` takes on values // -// ```(n0, n1, ..., n(T-1))``` +// - A range `begin:end:stride`. This is used to specify how much to choose from +// a given dimension. `stride` can be any integer but 0. `begin` is an integer +// which represents the index of the first value to select while `end` represents +// the index of the last value to select. The number of values selected in each +// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. +// `begin` and `end` can be negative where `-1` is the last element, `-2` is +// the second to last. `begin_mask` controls whether to replace the explicitly +// given `begin` with an implicit effective value of `0` if `stride > 0` and +// `-1` if `stride < 0`. `end_mask` is analogous but produces the number +// required to create the largest open interval. For example, given a shape +// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do +// not assume this is equivalent to `foo[0:-1]` which has an effective `begin` +// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the +// first dimension of a tensor while dropping the last two (in the original +// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. // -// and that `value` has shape +// - A single index. This is used to keep only elements that have a given +// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a +// shape `(6,)` tensor. This is encoded in `begin` and `end` and +// `shrink_axis_mask`. // -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// Each conceptual range specification is encoded in the op's argument. This +// encoding is best understand by considering a non-trivial example. In +// particular, +// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as // -// this splits values into a TensorArray with T tensors. +// ``` +// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) +// end = [2, 4, x, x, -3, x] +// strides = [1, 1, x, x, -1, 1] +// begin_mask = 1<<4 | 1 << 5 = 48 +// end_mask = 1<<5 = 32 +// ellipsis_mask = 1<<3 = 8 +// new_axis_mask = 1<<2 4 +// shrink_axis_mask = 1<<0 +// ``` // -// TensorArray index t will be the subtensor of values with starting position +// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of +// the slice becomes (2, 1, 5, 5, 2, 5). +// Let us walk step by step through each argument specification. // -// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` +// 1. The first argument in the example slice is turned into `begin = 1` and +// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we +// also set the appropriate bit in `shrink_axis_mask`. // -// and having size +// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have +// zero bits contributed. // -// ```nt x d0 x d1 x ...``` +// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 +// dimension in the final shape. Dummy values are contributed to begin, +// end and stride, while the new_axis_mask bit is set. // -// Arguments: -// handle: The handle to a TensorArray. -// value: The concatenated tensor to write to the TensorArray. -// lengths: The vector of lengths, how to split the rows of value into the -// TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// 4. `...` grab the full ranges from as many dimensions as needed to +// fully specify a slice for every dimension of the input shape. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArraySplitV3", - Input: []tf.Input{ - handle, value, lengths, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SerializeSparseAttr is an optional argument to SerializeSparse. -type SerializeSparseAttr func(optionalAttr) - -// SerializeSparseOutType sets the optional out_type attribute to value. +// 5. `:-3:-1` shows the use of negative indices. A negative index `i` associated +// with a dimension that has shape `s` is converted to a positive index +// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion +// is done internally so begin, end and strides receive x, -3, and -1. +// The appropriate begin_mask bit is set to indicate the start range is the +// full range (ignoring the x). // -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize a `SparseTensor` into a `[3]` `Tensor` object. +// 6. `:` indicates that the entire contents of the corresponding dimension +// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides +// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and +// `end_mask` are also set. +// +// *Requirements*: +// `0 != strides[i] for i in [0, m)` +// `ellipsis_mask must be a power of two (only one ellipsis)` // // Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. -func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { +// +// begin: `begin[k]` specifies the offset into the `k`th range specification. +// The exact dimension this corresponds to will be determined by context. +// Out-of-bounds values will be silently clamped. If the `k`th bit of +// `begin_mask` then `begin[k]` is ignored and the full range of the +// appropriate dimension is used instead. Negative values causes indexing +// to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`. +// end: `end[i]` is like `begin` with the exception that `end_mask` is +// used to determine full ranges. +// strides: `strides[i]` specifies the increment in the `i`th specification +// after extracting a given element. Negative indices will reverse +// the original order. Out or range values are +// clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0` +func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, optional ...StridedSliceAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -23114,9 +23074,9 @@ func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Ou a(attrs) } opspec := tf.OpSpec{ - Type: "SerializeSparse", + Type: "StridedSlice", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + input, begin, end, strides, }, Attrs: attrs, } @@ -23124,318 +23084,403 @@ func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Ou return op.Output(0) } -// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. -type RandomShuffleQueueV2Attr func(optionalAttr) +// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. +type PriorityQueueV2Attr func(optionalAttr) -// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. +// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. // -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. +// value: The type of each component in a value. // If not specified, defaults to <> // // REQUIRES: len(value) >= 0 -func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { +func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { return func(m optionalAttr) { - m["shapes"] = value + m["component_types"] = value } } -// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// PriorityQueueV2Capacity sets the optional capacity attribute to value. // // value: The upper bound on the number of elements in this queue. // Negative numbers mean no limit. // If not specified, defaults to -1 -func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { +func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { return func(m optionalAttr) { m["capacity"] = value } } -// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// PriorityQueueV2Container sets the optional container attribute to value. // -// value: Dequeue will block unless there would be this -// many elements after the dequeue or the queue is closed. This -// ensures a minimum level of mixing of elements. -// If not specified, defaults to 0 -func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PriorityQueueV2Container(value string) PriorityQueueV2Attr { return func(m optionalAttr) { - m["min_after_dequeue"] = value + m["container"] = value } } -// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// PriorityQueueV2SharedName sets the optional shared_name attribute to value. // -// value: If either seed or seed2 is set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { return func(m optionalAttr) { - m["seed"] = value + m["shared_name"] = value } } -// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. +// A queue that produces elements sorted by the first component value. // -// value: A second seed to avoid seed collision. +// Note that the PriorityQueue requires the first component of any element +// to be a scalar int64, in addition to the other elements declared by +// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue +// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra +// entry in their input (resp. output) lists. +// +// Arguments: +// shapes: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// +// Returns The handle to the queue. +func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shapes": shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PriorityQueueV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UnstageAttr is an optional argument to Unstage. +type UnstageAttr func(optionalAttr) + +// UnstageCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 -func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { +// +// REQUIRES: value >= 0 +func UnstageCapacity(value int64) UnstageAttr { return func(m optionalAttr) { - m["seed2"] = value + m["capacity"] = value } } -// RandomShuffleQueueV2Container sets the optional container attribute to value. +// UnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. +// REQUIRES: value >= 0 +func UnstageMemoryLimit(value int64) UnstageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// UnstageContainer sets the optional container attribute to value. // If not specified, defaults to "" -func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { +func UnstageContainer(value string) UnstageAttr { return func(m optionalAttr) { m["container"] = value } } -// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. +// UnstageSharedName sets the optional shared_name attribute to value. // If not specified, defaults to "" -func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { +func UnstageSharedName(value string) UnstageAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// A queue that randomizes the order of elements. -// -// Arguments: -// component_types: The type of each component in a value. +// Op is similar to a lightweight Dequeue. // -// Returns The handle to the queue. -func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { +// The basic functionality is similar to dequeue with many fewer +// capabilities and options. This Op is optimized for performance. +func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomShuffleQueueV2", + Type: "Unstage", Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("Unstage", err) + return + } + return values } -// Draw bounding boxes on a batch of images. -// -// Outputs a copy of `images` but draws on top of the pixels zero or more bounding -// boxes specified by the locations in `boxes`. The coordinates of the each -// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, if an image is 100 x 200 pixels (height x width) and the bounding -// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of -// the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates). +// ArgMaxAttr is an optional argument to ArgMax. +type ArgMaxAttr func(optionalAttr) + +// ArgMaxOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + +// Returns the index with the largest value across dimensions of a tensor. // -// Parts of the bounding box may fall outside the image. +// Note that in case of ties the identity of the return value is not guaranteed. // // Arguments: -// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. -// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding -// boxes. // -// Returns 4-D with the same shape as `images`. The batch of input images with -// bounding boxes drawn on the images. -func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. +func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DrawBoundingBoxes", + Type: "ArgMax", Input: []tf.Input{ - images, boxes, + input, dimension, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. -type LearnedUnigramCandidateSamplerAttr func(optionalAttr) +// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign. +type ResourceStridedSliceAssignAttr func(optionalAttr) -// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value. // If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { +func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["begin_mask"] = value + } +} + +// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["end_mask"] = value + } +} + +// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} + +// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr { return func(m optionalAttr) { - m["seed"] = value + m["new_axis_mask"] = value } } -// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. +// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value. // If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { +func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr { return func(m optionalAttr) { - m["seed2"] = value + m["shrink_axis_mask"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. +// Assign `value` to the sliced l-value reference of `ref`. // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// The values of `value` are assigned to the positions in the variable +// `ref` that are selected by the slice parameters. The slice parameters +// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. // -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// NOTE this op currently does not support broadcasting and so `value`'s +// shape must be exactly the shape produced by the slice of `ref`. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns the created operation. +func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LearnedUnigramCandidateSampler", + Type: "ResourceStridedSliceAssign", Input: []tf.Input{ - true_classes, + ref, begin, end, strides, value, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Computes gradients for the scaled exponential linear (Selu) operation. +// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. +type QueueEnqueueV2Attr func(optionalAttr) + +// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue is full, this operation will block for up to +// timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Enqueues a tuple of one or more tensors in the given queue. +// +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. +// +// N.B. If the queue is full, this operation will block until the given +// element has been enqueued (or 'timeout_ms' elapses, if specified). // // Arguments: -// gradients: The backpropagated gradients to the corresponding Selu operation. -// outputs: The outputs of the corresponding Selu operation. +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should be taken. // -// Returns The gradients: `gradients * (outputs + scale * alpha)` -// if outputs < 0, `scale * gradients` otherwise. -func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { +// Returns the created operation. +func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SeluGrad", + Type: "QueueEnqueueV2", Input: []tf.Input{ - gradients, outputs, + handle, tf.OutputList(components), }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Get the current size of the TensorArray. +// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. +type QueueDequeueManyV2Attr func(optionalAttr) + +// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Dequeues `n` tuples of one or more tensors from the given queue. +// +// If the queue is closed and there are fewer than `n` elements, then an +// OutOfRange error is returned. +// +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size `n` in the 0th dimension. +// +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. +// +// N.B. If the queue is empty, this operation will block until `n` elements +// have been dequeued (or 'timeout_ms' elapses, if specified). // // Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). -// flow_in: A float scalar that enforces proper chaining of operations. +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. // -// Returns The current size of the TensorArray. -func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArraySizeV3", + Type: "QueueDequeueManyV2", Input: []tf.Input{ - handle, flow_in, + handle, n, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deprecated. Use TensorArrayGradV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 -func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "TensorArrayWriteV2", - Input: []tf.Input{ - handle, index, value, flow_in, - }, + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueManyV2", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return components } -// SparseReduceMaxAttr is an optional argument to SparseReduceMax. -type SparseReduceMaxAttr func(optionalAttr) +// EncodeBase64Attr is an optional argument to EncodeBase64. +type EncodeBase64Attr func(optionalAttr) -// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. +// EncodeBase64Pad sets the optional pad attribute to value. // -// value: If true, retain reduced dimensions with length 1. +// value: Bool whether padding is applied at the ends. // If not specified, defaults to false -func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { +func EncodeBase64Pad(value bool) EncodeBase64Attr { return func(m optionalAttr) { - m["keep_dims"] = value + m["pad"] = value } } -// Computes the max of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. +// Encode strings into web-safe base64 format. // -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// Refer to the following article for more information on base64 format: +// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the +// end so that the encoded has length multiple of 4. See Padding section of the +// link above. // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// Web-safe means that the encoder uses - and _ instead of + and /. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// input: Strings to be encoded. // -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { +// Returns Input strings encoded in base64. +func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { if scope.Err() != nil { return } @@ -23444,9 +23489,9 @@ func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "SparseReduceMax", + Type: "EncodeBase64", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + input, }, Attrs: attrs, } @@ -23454,68 +23499,77 @@ func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Outp return op.Output(0) } -// AsStringAttr is an optional argument to AsString. -type AsStringAttr func(optionalAttr) +// Deprecated. Use TensorArrayCloseV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// +// Returns the created operation. +func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayCloseV2", + Input: []tf.Input{ + handle, + }, + } + return scope.AddOperation(opspec) +} -// AsStringPrecision sets the optional precision attribute to value. +// Forwards the value of an available tensor from `inputs` to `output`. // -// value: The post-decimal precision to use for floating point numbers. -// Only used if precision > -1. -// If not specified, defaults to -1 -func AsStringPrecision(value int64) AsStringAttr { - return func(m optionalAttr) { - m["precision"] = value +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. +// +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. +// +// Arguments: +// inputs: The input tensors, exactly one of which will become available. +// +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. +func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Merge", + Input: []tf.Input{ + tf.OutputList(inputs), + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// AsStringScientific sets the optional scientific attribute to value. -// -// value: Use scientific notation for floating point numbers. -// If not specified, defaults to false -func AsStringScientific(value bool) AsStringAttr { - return func(m optionalAttr) { - m["scientific"] = value - } -} +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) -// AsStringShortest sets the optional shortest attribute to value. +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. // -// value: Use shortest representation (either scientific or standard) for -// floating point numbers. +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. // If not specified, defaults to false -func AsStringShortest(value bool) AsStringAttr { +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { return func(m optionalAttr) { - m["shortest"] = value + m["cancel_pending_enqueues"] = value } } -// AsStringWidth sets the optional width attribute to value. +// Closes the given queue. // -// value: Pad pre-decimal numbers to this width. -// Applies to both floating point and integer numbers. -// Only used if width > -1. -// If not specified, defaults to -1 -func AsStringWidth(value int64) AsStringAttr { - return func(m optionalAttr) { - m["width"] = value - } -} - -// AsStringFill sets the optional fill attribute to value. +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. // -// value: The value to pad if width > -1. If empty, pads with spaces. -// Another typical value is '0'. String cannot be longer than 1 character. -// If not specified, defaults to "" -func AsStringFill(value string) AsStringAttr { - return func(m optionalAttr) { - m["fill"] = value - } -} - -// Converts each entry in the given tensor to strings. Supports many numeric +// Arguments: +// handle: The handle to a queue. // -// types and boolean. -func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { +// Returns the created operation. +func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -23524,378 +23578,381 @@ func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output t a(attrs) } opspec := tf.OpSpec{ - Type: "AsString", + Type: "QueueCloseV2", Input: []tf.Input{ - input, + handle, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Computes inverse hyperbolic tangent of x element-wise. +func Atanh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atanh", + Input: []tf.Input{ + x, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } -// Deprecated. Use TensorArrayScatterV3 +// Returns true if queue is closed. // -// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 -func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// This operation returns true if the queue is closed and false if the queue +// is open. +// +// Arguments: +// handle: The handle to a queue. +func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayScatterV2", + Type: "QueueIsClosedV2", Input: []tf.Input{ - handle, indices, value, flow_in, + handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Applies sparse addition to `input` using individual values or slices +// Returns the batched diagonal part of a batched tensor. // -// from `updates` according to indices `indices`. The updates are non-aliasing: -// `input` is only modified in-place if no other operations will use it. -// Otherwise, a copy of `input` is made. This operation has a gradient with -// respect to both `input` and `updates`. +// This operation returns a tensor with the `diagonal` part +// of the batched `input`. The `diagonal` part is computed as follows: // -// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: // -// `indices` must be integer tensor, containing indices into `input`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. // -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or `(P-K)`-dimensional slices -// (if `K < P`) along the `K`th dimension of `input`. +// The input must be at least a matrix. // -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// For example: // // ``` -// [d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]]. -// ``` -// -// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 -// elements. In Python, that addition would look like this: -// -// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) -// with tf.Session() as sess: -// print(sess.run(output)) +// # 'input' is [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] // -// The resulting value `output` would look like this: +// and input.shape = (2, 4, 4) // -// [1, 13, 3, 14, 14, 6, 7, 20] +// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] // -// See @{tf.scatter_nd} for more details about how to make updates to slices. +// which has shape (2, 4) +// ``` // // Arguments: -// input: A Tensor. -// indices: A Tensor. Must be one of the following types: `int32`, `int64`. -// A tensor of indices into `input`. -// updates: A Tensor. Must have the same type as ref. A tensor of updated values -// to add to `input`. +// input: Rank `k` tensor where `k >= 2`. // -// Returns A `Tensor` with the same shape as `input`, containing values of `input` -// updated with `updates`. -func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { +// Returns The extracted diagonal(s) having shape +// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`. +func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ScatterNdNonAliasingAdd", + Type: "MatrixDiagPart", Input: []tf.Input{ - input, indices, updates, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. -type FractionalMaxPoolAttr func(optionalAttr) +// Computes the absolute value of a tensor. +// +// Given a tensor `x`, this operation returns a tensor containing the absolute +// value of each element in `x`. For example, if x is an input element and y is +// an output element, this operation computes \\(y = |x|\\). +func Abs(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Abs", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. +// StackV2Attr is an optional argument to StackV2. +type StackV2Attr func(optionalAttr) + +// StackV2StackName sets the optional stack_name attribute to value. // -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. -// If not specified, defaults to false -func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { +// value: Overrides the name used for the temporary stack resource. Default +// value is the name of the 'Stack' op (which is guaranteed unique). +// If not specified, defaults to "" +func StackV2StackName(value string) StackV2Attr { return func(m optionalAttr) { - m["pseudo_random"] = value + m["stack_name"] = value } } -// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` +// A stack that produces elements in first-in last-out order. // -// `value 20 5 16 3 7` +// Arguments: +// max_size: The maximum size of the stack if non-negative. If negative, the stack +// size is unlimited. +// elem_type: The type of the elements on the stack. // -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. -// If not specified, defaults to false -func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["overlapping"] = value +// Returns The handle to the stack. +func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"elem_type": elem_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StackV2", + Input: []tf.Input{ + max_size, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. +// OrderedMapStageAttr is an optional argument to OrderedMapStage. +type OrderedMapStageAttr func(optionalAttr) + +// OrderedMapStageCapacity sets the optional capacity attribute to value. // -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalMaxPool node in the computation graph. Mainly used -// in unit test to make FractionalMaxPool deterministic. -// If not specified, defaults to false -func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { return func(m optionalAttr) { - m["deterministic"] = value + m["capacity"] = value } } -// FractionalMaxPoolSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 -func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { +// +// REQUIRES: value >= 0 +func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { return func(m optionalAttr) { - m["seed"] = value + m["memory_limit"] = value } } -// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// OrderedMapStageContainer sets the optional container attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func OrderedMapStageContainer(value string) OrderedMapStageAttr { return func(m optionalAttr) { - m["seed2"] = value + m["container"] = value } } -// Performs fractional max pooling on the input. -// -// Fractional max pooling is slightly different than regular max pooling. In -// regular max pooling, you downsize an input set by taking the maximum value of -// smaller N x N subsections of the set (often 2x2), and try to reduce the set by -// a factor of N, where N is an integer. Fractional max pooling, as you might -// expect from the word "fractional", means that the overall reduction ratio N -// does not have to be an integer. -// -// The sizes of the pooling regions are generated randomly but are fairly uniform. -// For example, let's look at the height dimension, and the constraints on the -// list of rows that will be pool boundaries. -// -// First we define the following: +// OrderedMapStageSharedName sets the optional shared_name attribute to value. // -// 1. input_row_length : the number of rows from the input set -// 2. output_row_length : which will be smaller than the input -// 3. alpha = input_row_length / output_row_length : our reduction ratio -// 4. K = floor(alpha) -// 5. row_pooling_sequence : this is the result list of pool boundary rows +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func OrderedMapStageSharedName(value string) OrderedMapStageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Stage (key, values) in the underlying container which behaves like a ordered // -// Then, row_pooling_sequence should satisfy: +// associative container. Elements are ordered by key. // -// 1. a[0] = 0 : the first value of the sequence is 0 -// 2. a[end] = input_row_length : the last value of the sequence is the size -// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size -// 4. length(row_pooling_sequence) = output_row_length+1 +// Arguments: +// key: int64 // -// For more details on fractional max pooling, see this paper: -// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. // -// Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. // -// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { +// Returns the created operation. +func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FractionalMaxPool", + Type: "OrderedMapStage", Input: []tf.Input{ - value, + key, indices, tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Deprecated. Use TensorArraySizeV3 +// StackPushV2Attr is an optional argument to StackPushV2. +type StackPushV2Attr func(optionalAttr) + +// StackPushV2SwapMemory sets the optional swap_memory attribute to value. // -// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 -func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// value: Swap `elem` to CPU. Default to false. +// If not specified, defaults to false +func StackPushV2SwapMemory(value bool) StackPushV2Attr { + return func(m optionalAttr) { + m["swap_memory"] = value + } +} + +// Push an element onto the stack. +// +// Arguments: +// handle: The handle to a stack. +// elem: The tensor to be pushed onto the stack. +// +// Returns The same tensor as the input 'elem'. +func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArraySizeV2", + Type: "StackPushV2", Input: []tf.Input{ - handle, flow_in, + handle, elem, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv2DAttr is an optional argument to Conv2D. -type Conv2DAttr func(optionalAttr) +// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. +type FusedBatchNormGradV2Attr func(optionalAttr) -// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { +// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value + m["epsilon"] = value } } -// Conv2DDataFormat sets the optional data_format attribute to value. +// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". // If not specified, defaults to "NHWC" -func Conv2DDataFormat(value string) Conv2DAttr { +func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { return func(m optionalAttr) { m["data_format"] = value } } -// Conv2DDilations sets the optional dilations attribute to value. +// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv2DDilations(value []int64) Conv2DAttr { +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { return func(m optionalAttr) { - m["dilations"] = value + m["is_training"] = value } } -// Computes a 2-D convolution given 4-D `input` and `filter` tensors. -// -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]`, this op -// performs the following: -// -// 1. Flattens the filter to a 2-D matrix with shape -// `[filter_height * filter_width * in_channels, output_channels]`. -// 2. Extracts image patches from the input tensor to form a *virtual* -// tensor of shape `[batch, out_height, out_width, -// filter_height * filter_width * in_channels]`. -// 3. For each patch, right-multiplies the filter matrix and the image patch -// vector. -// -// In detail, with the default NHWC format, -// -// output[b, i, j, k] = -// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * -// filter[di, dj, q, k] +// Gradient for batch normalization. // -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. // // Arguments: -// input: A 4-D tensor. The dimension order is interpreted according to the value -// of `data_format`, see below for details. -// filter: A 4-D tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]` -// strides: 1-D tensor of length 4. The stride of the sliding window for each -// dimension of `input`. The dimension order is determined by the value of -// `data_format`, see below for details. -// padding: The type of padding algorithm to use. +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. // -// Returns A 4-D tensor. The dimension order is determined by the value of -// `data_format`, see below for details. -func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2D", + Type: "FusedBatchNormGradV2", Input: []tf.Input{ - input, filter, + y_backprop, x, scale, reserve_space_1, reserve_space_2, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs. -type FakeQuantWithMinMaxArgsAttr func(optionalAttr) - -// FakeQuantWithMinMaxArgsMin sets the optional min attribute to value. -// If not specified, defaults to -6 -func FakeQuantWithMinMaxArgsMin(value float32) FakeQuantWithMinMaxArgsAttr { - return func(m optionalAttr) { - m["min"] = value - } -} - -// FakeQuantWithMinMaxArgsMax sets the optional max attribute to value. -// If not specified, defaults to 6 -func FakeQuantWithMinMaxArgsMax(value float32) FakeQuantWithMinMaxArgsAttr { - return func(m optionalAttr) { - m["max"] = value - } + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// FakeQuantWithMinMaxArgsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxArgsNumBits(value int64) FakeQuantWithMinMaxArgsAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) -// FakeQuantWithMinMaxArgsNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr { +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. +// +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// If not specified, defaults to "" +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { return func(m optionalAttr) { - m["narrow_range"] = value + m["compression_type"] = value } } -// Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. +// Decompress strings. // -// Attributes `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. +// +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. // -// Quantization is called fake since the output is still in floating point. -func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) { +// Arguments: +// bytes: A Tensor of string which is compressed. +// +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -23904,9 +23961,9 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua a(attrs) } opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxArgs", + Type: "DecodeCompressed", Input: []tf.Input{ - inputs, + bytes, }, Attrs: attrs, } @@ -23914,212 +23971,269 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua return op.Output(0) } -// StageAttr is an optional argument to Stage. -type StageAttr func(optionalAttr) - -// StageCapacity sets the optional capacity attribute to value. +// Creates a TensorArray for storing the gradients of values in the given handle. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// If the given TensorArray gradient already exists, returns a reference to it. // -// REQUIRES: value >= 0 -func StageCapacity(value int64) StageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// StageMemoryLimit sets the optional memory_limit attribute to value. +// Locks the size of the original TensorArray by disabling its dynamic size flag. // -// value: The maximum number of bytes allowed for Tensors in the Staging Area. -// If > 0, inserts will block until sufficient space is available. -// If not specified, defaults to 0 +// **A note about the input flow_in:** // -// REQUIRES: value >= 0 -func StageMemoryLimit(value int64) StageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageContainer sets the optional container attribute to value. +// The handle flow_in forces the execution of the gradient lookup to occur +// only after certain other operations have occurred. For example, when +// the forward TensorArray is dynamically sized, writes to this TensorArray +// may resize the object. The gradient TensorArray is statically sized based +// on the size of the forward TensorArray when this operation executes. +// Furthermore, the size of the forward TensorArray is frozen by this call. +// As a result, the flow is used to ensure that the call to generate the gradient +// TensorArray only happens after all writes are executed. // -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func StageContainer(value string) StageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageSharedName sets the optional shared_name attribute to value. +// In the case of dynamically sized TensorArrays, gradient computation should +// only be performed on read operations that have themselves been chained via +// flow to occur only after all writes have executed. That way the final size +// of the forward TensorArray is known when this operation is called. // -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func StageSharedName(value string) StageAttr { - return func(m optionalAttr) { - m["shared_name"] = value +// **A note about the source attribute:** +// +// TensorArray gradient calls use an accumulator TensorArray object. If +// multiple gradients are calculated and run in the same session, the multiple +// gradient nodes may accidentally flow through the same accumulator TensorArray. +// This double counts and generally breaks the TensorArray gradient flow. +// +// The solution is to identify which gradient call this particular +// TensorArray gradient is being called in. This is performed by identifying +// a unique string (e.g. "gradients", "gradients_1", ...) from the input +// gradient Tensor's name. This string is used as a suffix when creating +// the TensorArray gradient object here (the attribute `source`). +// +// The attribute `source` is added as a suffix to the forward TensorArray's +// name when performing the creation / lookup, so that each separate gradient +// calculation gets its own TensorArray accumulator. +// +// Arguments: +// handle: The handle to the forward TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// source: The gradient source string, used to decide which gradient TensorArray +// to return. +func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"source": source} + opspec := tf.OpSpec{ + Type: "TensorArrayGradV3", + Input: []tf.Input{ + handle, flow_in, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Stage values similar to a lightweight Enqueue. +// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. // -// The basic functionality of this Op is similar to a queue with many -// fewer capabilities and options. This Op is optimized for performance. +// Each comparison returns a boolean `true` (if `input_value > threshold`) +// or and `false` otherwise. +// +// This operation is useful for Locality-Sensitive-Hashing (LSH) and other +// algorithms that use hashing approximations of cosine and `L2` distances; +// codes can be generated from an input via: +// +// ```python +// codebook_size = 50 +// codebook_bits = codebook_size * 32 +// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], +// dtype=x.dtype, +// initializer=tf.orthogonal_initializer()) +// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +// # now codes has shape x.shape[:-1] + [codebook_size] +// ``` +// +// **NOTE**: Currently, the innermost dimension of the tensor must be divisible +// by 8. +// +// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. // // Arguments: -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. +// input: Values to compare against `threshold` and bitpack. +// threshold: Threshold to compare against. // -// Returns the created operation. -func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Operation) { +// Returns The bitpacked comparisons. +func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Stage", + Type: "CompareAndBitpack", Input: []tf.Input{ - tf.OutputList(values), + input, threshold, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// StagePeekAttr is an optional argument to StagePeek. -type StagePeekAttr func(optionalAttr) - -// StagePeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Push an element onto the tensor_array. // -// REQUIRES: value >= 0 -func StagePeekCapacity(value int64) StagePeekAttr { - return func(m optionalAttr) { - m["capacity"] = value +// Arguments: +// handle: The handle to a TensorArray. +// index: The position to write to inside the TensorArray. +// value: The tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayWriteV3", + Input: []tf.Input{ + handle, index, value, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// StagePeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Scatter the data from the input value into specific TensorArray elements. // -// REQUIRES: value >= 0 -func StagePeekMemoryLimit(value int64) StagePeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// `indices` must be a vector, its length must match the first dim of `value`. +// +// Arguments: +// handle: The handle to a TensorArray. +// indices: The locations at which to write the tensor elements. +// value: The concatenated tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return } -} - -// StagePeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StagePeekContainer(value string) StagePeekAttr { - return func(m optionalAttr) { - m["container"] = value + opspec := tf.OpSpec{ + Type: "TensorArrayScatterV3", + Input: []tf.Input{ + handle, indices, value, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// StagePeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StagePeekSharedName(value string) StagePeekAttr { +// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. +type TensorArrayConcatV3Attr func(optionalAttr) + +// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. +// +// value: The expected shape of an element, if known, +// excluding the first dimension. Used to validate the shapes of +// TensorArray elements. If this shape is not fully specified, concatenating +// zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { return func(m optionalAttr) { - m["shared_name"] = value + m["element_shape_except0"] = value } } -// Op peeks at the values at the specified index. If the +// Concat the elements from the TensorArray into value `value`. // -// underlying container does not contain sufficient elements -// this op will block until it does. This Op is optimized for -// performance. -func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ...StagePeekAttr) (values []tf.Output) { +// Takes `T` elements of shapes +// +// ``` +// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) +// ``` +// +// and concatenates them into a Tensor of shape: +// +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` +// +// All elements must have the same shape (excepting the first dimension). +// +// Arguments: +// handle: The handle to a TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns All of the elements in the TensorArray, concatenated along the first +// axis.A vector of the row sizes of the original T elements in the +// value output. In the example above, this would be the values: +// `(n1, n2, ..., n(T-1))`. +func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StagePeek", + Type: "TensorArrayConcatV3", Input: []tf.Input{ - index, + handle, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("StagePeek", err) - return - } - return values + return op.Output(0), op.Output(1) } -// Conv3DBackpropInputV2Attr is an optional argument to Conv3DBackpropInputV2. -type Conv3DBackpropInputV2Attr func(optionalAttr) +// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. +type ParameterizedTruncatedNormalAttr func(optionalAttr) -// Conv3DBackpropInputV2DataFormat sets the optional data_format attribute to value. +// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { - m["data_format"] = value + m["seed"] = value } } -// Conv3DBackpropInputV2Dilations sets the optional dilations attribute to value. +// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. // -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { - m["dilations"] = value + m["seed2"] = value } } -// Computes the gradients of 3-D convolution with respect to the input. +// Outputs random values from a normal distribution. The parameters may each be a +// +// scalar which applies to the entire output, or a vector of length shape[0] which +// stores the parameters for each batch. // // Arguments: -// input_sizes: An integer vector representing the tensor shape of `input`, -// where `input` is a 5-D -// `[batch, depth, rows, cols, in_channels]` tensor. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputV2Attr) (output tf.Output) { +// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. +// means: The mean parameter of each batch. +// stdevs: The standard deviation parameter of each batch. Must be greater than 0. +// minvals: The minimum cutoff. May be -infinity. +// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval +// for each batch. +// +// Returns A matrix of shape num_batches x samples_per_batch, filled with random +// truncated normal values using the parameters for each row. +func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropInputV2", + Type: "ParameterizedTruncatedNormal", Input: []tf.Input{ - input_sizes, filter, out_backprop, + shape, means, stdevs, minvals, maxvals, }, Attrs: attrs, } @@ -24127,384 +24241,223 @@ func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output return op.Output(0) } -// DepthToSpaceAttr is an optional argument to DepthToSpace. -type DepthToSpaceAttr func(optionalAttr) - -// DepthToSpaceDataFormat sets the optional data_format attribute to value. -// If not specified, defaults to "NHWC" -func DepthToSpaceDataFormat(value string) DepthToSpaceAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthToSpace for tensors of type T. -// -// Rearranges data from depth into blocks of spatial data. -// This is the reverse transformation of SpaceToDepth. More specifically, -// this op outputs a copy of the input tensor where values from the `depth` -// dimension are moved in spatial blocks to the `height` and `width` dimensions. -// The attr `block_size` indicates the input block size and how the data is moved. -// -// * Chunks of data of size `block_size * block_size` from depth are rearranged -// into non-overlapping blocks of size `block_size x block_size` -// * The width the output tensor is `input_depth * block_size`, whereas the -// height is `input_height * block_size`. -// * The Y, X coordinates within each block of the output image are determined -// by the high order component of the input channel index. -// * The depth of the input tensor must be divisible by -// `block_size * block_size`. -// -// The `data_format` attr specifies the layout of the input and output tensors -// with the following options: -// "NHWC": `[ batch, height, width, channels ]` -// "NCHW": `[ batch, channels, height, width ]` -// "NCHW_VECT_C": -// `qint8 [ batch, channels / 4, height, width, 4 ]` -// -// It is useful to consider the operation as transforming a 6-D Tensor. -// e.g. for data_format = NHWC, -// Each element in the input tensor can be specified via 6 coordinates, -// ordered by decreasing memory layout significance as: -// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates -// within the input image, bX, bY means coordinates -// within the output block, oC means output channels). -// The output would be the input transposed to the following layout: -// n,iY,bY,iX,bX,oC -// -// This operation is useful for resizing the activations between convolutions -// (but keeping all data), e.g. instead of pooling. It is also useful for training -// purely convolutional models. -// -// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and -// block_size = 2: -// -// ``` -// x = [[[[1, 2, 3, 4]]]] -// -// ``` -// -// This operation will output a tensor of shape `[1, 2, 2, 1]`: -// -// ``` -// [[[[1], [2]], -// [[3], [4]]]] -// ``` -// -// Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, -// the corresponding output will have 2x2 elements and will have a depth of -// 1 channel (1 = `4 / (block_size * block_size)`). -// The output element shape is `[2, 2, 1]`. -// -// For an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g. -// -// ``` -// x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] -// ``` -// -// This operation, for block size of 2, will return the following tensor of shape -// `[1, 2, 2, 3]` -// -// ``` -// [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] +// Returns a diagonal tensor with a given diagonal values. // -// ``` +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: // -// Similarly, for the following input of shape `[1 2 2 4]`, and a block size of 2: +// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: // -// ``` -// x = [[[[1, 2, 3, 4], -// [5, 6, 7, 8]], -// [[9, 10, 11, 12], -// [13, 14, 15, 16]]]] -// ``` +// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. // -// the operator will return the following tensor of shape `[1 4 4 1]`: +// For example: // // ``` -// x = [[[ [1], [2], [5], [6]], -// [ [3], [4], [7], [8]], -// [ [9], [10], [13], [14]], -// [ [11], [12], [15], [16]]]] -// +// # 'diagonal' is [1, 2, 3, 4] +// tf.diag(diagonal) ==> [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] // ``` // // Arguments: -// -// block_size: The size of the spatial block, same as in Space2Depth. -func DepthToSpace(scope *Scope, input tf.Output, block_size int64, optional ...DepthToSpaceAttr) (output tf.Output) { +// diagonal: Rank k tensor where k is at most 1. +func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DepthToSpace", + Type: "Diag", Input: []tf.Input{ - input, + diagonal, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MapStageAttr is an optional argument to MapStage. -type MapStageAttr func(optionalAttr) - -// MapStageCapacity sets the optional capacity attribute to value. +// Split the data from the input value into TensorArray elements. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// Assuming that `lengths` takes on values // -// REQUIRES: value >= 0 -func MapStageCapacity(value int64) MapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// ```(n0, n1, ..., n(T-1))``` // -// REQUIRES: value >= 0 -func MapStageMemoryLimit(value int64) MapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapStageContainer sets the optional container attribute to value. +// and that `value` has shape // -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func MapStageContainer(value string) MapStageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapStageSharedName sets the optional shared_name attribute to value. +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, // -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func MapStageSharedName(value string) MapStageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Stage (key, values) in the underlying container which behaves like a hashtable. +// this splits values into a TensorArray with T tensors. +// +// TensorArray index t will be the subtensor of values with starting position +// +// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` // -// Arguments: -// key: int64 +// and having size // -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. +// ```nt x d0 x d1 x ...``` // +// Arguments: +// handle: The handle to a TensorArray. +// value: The concatenated tensor to write to the TensorArray. +// lengths: The vector of lengths, how to split the rows of value into the +// TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. // -// Returns the created operation. -func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { +// Returns A float scalar that enforces proper chaining of operations. +func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MapStage", + Type: "TensorArraySplitV3", Input: []tf.Input{ - key, indices, tf.OutputList(values), + handle, value, lengths, flow_in, }, - Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapUnstageAttr is an optional argument to MapUnstage. -type MapUnstageAttr func(optionalAttr) - -// MapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapUnstageCapacity(value int64) MapUnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} +// SerializeSparseAttr is an optional argument to SerializeSparse. +type SerializeSparseAttr func(optionalAttr) -// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// SerializeSparseOutType sets the optional out_type attribute to value. // -// REQUIRES: value >= 0 -func MapUnstageMemoryLimit(value int64) MapUnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapUnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapUnstageContainer(value string) MapUnstageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapUnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapUnstageSharedName(value string) MapUnstageAttr { +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["out_type"] = value } } -// Op removes and returns the values associated with the key +// Serialize a `SparseTensor` into a `[3]` `Tensor` object. // -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { +// Arguments: +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. +func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapUnstage", + Type: "SerializeSparse", Input: []tf.Input{ - key, indices, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstage", err) - return - } - return values + return op.Output(0) } -// MapSizeAttr is an optional argument to MapSize. -type MapSizeAttr func(optionalAttr) +// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. +type RandomShuffleQueueV2Attr func(optionalAttr) -// MapSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. // -// REQUIRES: value >= 0 -func MapSizeCapacity(value int64) MapSizeAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> // -// REQUIRES: value >= 0 -func MapSizeMemoryLimit(value int64) MapSizeAttr { +// REQUIRES: len(value) >= 0 +func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["memory_limit"] = value + m["shapes"] = value } } -// MapSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapSizeContainer(value string) MapSizeAttr { +// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["container"] = value + m["capacity"] = value } } -// MapSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapSizeSharedName(value string) MapSizeAttr { +// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// +// value: Dequeue will block unless there would be this +// many elements after the dequeue or the queue is closed. This +// ensures a minimum level of mixing of elements. +// If not specified, defaults to 0 +func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op returns the number of elements in the underlying container. -func MapSize(scope *Scope, dtypes []tf.DataType, optional ...MapSizeAttr) (size tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapSize", - - Attrs: attrs, + m["min_after_dequeue"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. -type MapIncompleteSizeAttr func(optionalAttr) - -// MapIncompleteSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// RandomShuffleQueueV2Seed sets the optional seed attribute to value. // -// REQUIRES: value >= 0 -func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["seed"] = value } } -// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. // -// REQUIRES: value >= 0 -func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["memory_limit"] = value + m["seed2"] = value } } -// MapIncompleteSizeContainer sets the optional container attribute to value. +// RandomShuffleQueueV2Container sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { +func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. // If not specified, defaults to "" -func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { +func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op returns the number of incomplete elements in the underlying container. -func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { +// A queue that randomizes the order of elements. +// +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapIncompleteSize", + Type: "RandomShuffleQueueV2", Attrs: attrs, } @@ -24512,188 +24465,211 @@ func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncomp return op.Output(0) } -// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. -type OrderedMapUnstageAttr func(optionalAttr) - -// OrderedMapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Draw bounding boxes on a batch of images. // -// REQUIRES: value >= 0 -func OrderedMapUnstageCapacity(value int64) OrderedMapUnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// REQUIRES: value >= 0 -func OrderedMapUnstageMemoryLimit(value int64) OrderedMapUnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// For example, if an image is 100 x 200 pixels (height x width) and the bounding +// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of +// the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates). +// +// Parts of the bounding box may fall outside the image. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. +// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. +// +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. +func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } + opspec := tf.OpSpec{ + Type: "DrawBoundingBoxes", + Input: []tf.Input{ + images, boxes, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// OrderedMapUnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageContainer(value string) OrderedMapUnstageAttr { +// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. +type LearnedUnigramCandidateSamplerAttr func(optionalAttr) + +// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["container"] = value + m["seed"] = value } } -// OrderedMapUnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageSharedName(value string) OrderedMapUnstageAttr { +// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed2"] = value } } -// Op removes and returns the values associated with the key +// Generates labels for candidate sampling with a learned unigram distribution. // -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageAttr) (values []tf.Output) { +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapUnstage", + Type: "LearnedUnigramCandidateSampler", Input: []tf.Input{ - key, indices, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes gradients for the scaled exponential linear (Selu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Selu operation. +// outputs: The outputs of the corresponding Selu operation. +// +// Returns The gradients: `gradients * (outputs + scale * alpha)` +// if outputs < 0, `scale * gradients` otherwise. +func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapUnstage", err) - return + opspec := tf.OpSpec{ + Type: "SeluGrad", + Input: []tf.Input{ + gradients, outputs, + }, } - return values + op := scope.AddOperation(opspec) + return op.Output(0) } -// OrderedMapSizeAttr is an optional argument to OrderedMapSize. -type OrderedMapSizeAttr func(optionalAttr) - -// OrderedMapSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Get the current size of the TensorArray. // -// REQUIRES: value >= 0 -func OrderedMapSizeCapacity(value int64) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Arguments: +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// flow_in: A float scalar that enforces proper chaining of operations. // -// REQUIRES: value >= 0 -func OrderedMapSizeMemoryLimit(value int64) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapSizeContainer(value string) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["container"] = value +// Returns The current size of the TensorArray. +func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { + if scope.Err() != nil { + return } -} - -// OrderedMapSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapSizeSharedName(value string) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "TensorArraySizeV3", + Input: []tf.Input{ + handle, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Op returns the number of elements in the underlying container. -func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSizeAttr) (size tf.Output) { +// Deprecated. Use TensorArrayGradV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 +func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "OrderedMapSize", - - Attrs: attrs, + Type: "TensorArrayWriteV2", + Input: []tf.Input{ + handle, index, value, flow_in, + }, } op := scope.AddOperation(opspec) return op.Output(0) } -// CTCLossAttr is an optional argument to CTCLoss. -type CTCLossAttr func(optionalAttr) +// SparseReduceMaxAttr is an optional argument to SparseReduceMax. +type SparseReduceMaxAttr func(optionalAttr) -// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. +// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. // -// value: Scalar, if true then repeated labels are -// collapsed prior to the CTC calculation. +// value: If true, retain reduced dimensions with length 1. // If not specified, defaults to false -func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { +func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { return func(m optionalAttr) { - m["preprocess_collapse_repeated"] = value + m["keep_dims"] = value } } -// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. +// Computes the max of elements across dimensions of a SparseTensor. // -// value: Scalar. If set to false, *during* CTC calculation -// repeated non-blank labels will not be merged and are interpreted as -// individual labels. This is a simplified version of CTC. -// If not specified, defaults to true -func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { - return func(m optionalAttr) { - m["ctc_merge_repeated"] = value - } -} - -// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. // -// value: Scalar. If set to true, during CTC -// calculation, items that have longer output sequences than input sequences -// are skipped: they don't contribute to the loss term and have zero-gradient. -// If not specified, defaults to false -func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { - return func(m optionalAttr) { - m["ignore_longer_outputs_than_inputs"] = value - } -} - -// Calculates the CTC Loss (log probability) for each batch entry. Also calculates +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. // -// the gradient. This class performs the softmax operation for you, so inputs -// should be e.g. linear projections of outputs by an LSTM. +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. // // Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// labels_indices: The indices of a `SparseTensor`. -// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for -// `(batch b, time t)`. -// labels_values: The values (labels) associated with the given batch and time. -// sequence_length: A vector containing sequence lengths (batch). +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. // -// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: -// `(max_time x batch_size x num_classes)`. -func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -24702,51 +24678,78 @@ func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_va a(attrs) } opspec := tf.OpSpec{ - Type: "CTCLoss", + Type: "SparseReduceMax", Input: []tf.Input{ - inputs, labels_indices, labels_values, sequence_length, + input_indices, input_values, input_shape, reduction_axes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) +} + +// AsStringAttr is an optional argument to AsString. +type AsStringAttr func(optionalAttr) + +// AsStringPrecision sets the optional precision attribute to value. +// +// value: The post-decimal precision to use for floating point numbers. +// Only used if precision > -1. +// If not specified, defaults to -1 +func AsStringPrecision(value int64) AsStringAttr { + return func(m optionalAttr) { + m["precision"] = value + } +} + +// AsStringScientific sets the optional scientific attribute to value. +// +// value: Use scientific notation for floating point numbers. +// If not specified, defaults to false +func AsStringScientific(value bool) AsStringAttr { + return func(m optionalAttr) { + m["scientific"] = value + } } -// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. -type CTCGreedyDecoderAttr func(optionalAttr) - -// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// AsStringShortest sets the optional shortest attribute to value. // -// value: If True, merge repeated classes in output. +// value: Use shortest representation (either scientific or standard) for +// floating point numbers. // If not specified, defaults to false -func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { +func AsStringShortest(value bool) AsStringAttr { return func(m optionalAttr) { - m["merge_repeated"] = value + m["shortest"] = value } } -// Performs greedy decoding on the logits given in inputs. -// -// A note about the attribute merge_repeated: if enabled, when -// consecutive logits' maximum indices are the same, only the first of -// these is emitted. Labeling the blank '*', the sequence "A B B * B B" -// becomes "A B B" if merge_repeated = True and "A B B B B" if -// merge_repeated = False. +// AsStringWidth sets the optional width attribute to value. // -// Regardless of the value of merge_repeated, if the maximum index of a given -// time and batch corresponds to the blank, index `(num_classes - 1)`, no new -// element is emitted. +// value: Pad pre-decimal numbers to this width. +// Applies to both floating point and integer numbers. +// Only used if width > -1. +// If not specified, defaults to -1 +func AsStringWidth(value int64) AsStringAttr { + return func(m optionalAttr) { + m["width"] = value + } +} + +// AsStringFill sets the optional fill attribute to value. // -// Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// value: The value to pad if width > -1. If empty, pads with spaces. +// Another typical value is '0'. String cannot be longer than 1 character. +// If not specified, defaults to "" +func AsStringFill(value string) AsStringAttr { + return func(m optionalAttr) { + m["fill"] = value + } +} + +// Converts each entry in the given tensor to strings. Supports many numeric // -// Returns Indices matrix, size `(total_decoded_outputs x 2)`, -// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, -// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. -// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence -// log-probabilities. -func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { +// types and boolean. +func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -24755,380 +24758,326 @@ func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "CTCGreedyDecoder", + Type: "AsString", Input: []tf.Input{ - inputs, sequence_length, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) + return op.Output(0) } -// Forwards `data` to the output port determined by `pred`. -// -// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, -// the data goes to `output_false`. -// -// See also `RefSwitch` and `Merge`. -// -// Arguments: -// data: The tensor to be forwarded to the appropriate output. -// pred: A scalar that specifies which output port will receive data. +// Deprecated. Use TensorArrayScatterV3 // -// Returns If `pred` is false, data will be forwarded to this output.If `pred` is true, data will be forwarded to this output. -func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 +func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Switch", + Type: "TensorArrayScatterV2", Input: []tf.Input{ - data, pred, + handle, indices, value, flow_in, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Add all input tensors element wise. +// Applies sparse addition to `input` using individual values or slices +// +// from `updates` according to indices `indices`. The updates are non-aliasing: +// `input` is only modified in-place if no other operations will use it. +// Otherwise, a copy of `input` is made. This operation has a gradient with +// respect to both `input` and `updates`. +// +// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `input`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or `(P-K)`-dimensional slices +// (if `K < P`) along the `K`th dimension of `input`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]]. +// ``` +// +// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +// elements. In Python, that addition would look like this: +// +// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) +// with tf.Session() as sess: +// print(sess.run(output)) +// +// The resulting value `output` would look like this: +// +// [1, 13, 3, 14, 14, 6, 7, 20] +// +// See @{tf.scatter_nd} for more details about how to make updates to slices. // // Arguments: -// inputs: Must all be the same size and shape. -func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { +// input: A Tensor. +// indices: A Tensor. Must be one of the following types: `int32`, `int64`. +// A tensor of indices into `input`. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to add to `input`. +// +// Returns A `Tensor` with the same shape as `input`, containing values of `input` +// updated with `updates`. +func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AddN", + Type: "ScatterNdNonAliasingAdd", Input: []tf.Input{ - tf.OutputList(inputs), + input, indices, updates, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// EnterAttr is an optional argument to Enter. -type EnterAttr func(optionalAttr) +// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. +type FractionalMaxPoolAttr func(optionalAttr) -// EnterIsConstant sets the optional is_constant attribute to value. +// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. // -// value: If true, the output is constant within the child frame. +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. // If not specified, defaults to false -func EnterIsConstant(value bool) EnterAttr { +func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { return func(m optionalAttr) { - m["is_constant"] = value + m["pseudo_random"] = value } } -// EnterParallelIterations sets the optional parallel_iterations attribute to value. +// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. // -// value: The number of iterations allowed to run in parallel. -// If not specified, defaults to 10 -func EnterParallelIterations(value int64) EnterAttr { +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [20, 16] for fractional max pooling. +// If not specified, defaults to false +func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { return func(m optionalAttr) { - m["parallel_iterations"] = value + m["overlapping"] = value } } -// Creates or finds a child frame, and makes `data` available to the child frame. -// -// This op is used together with `Exit` to create loops in the graph. -// The unique `frame_name` is used by the `Executor` to identify frames. If -// `is_constant` is true, `output` is a constant in the child frame; otherwise -// it may be changed in the child frame. At most `parallel_iterations` iterations -// are run in parallel in the child frame. -// -// Arguments: -// data: The tensor to be made available to the child frame. -// frame_name: The name of the child frame. +// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. // -// Returns The same tensor as `data`. -func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"frame_name": frame_name} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Enter", - Input: []tf.Input{ - data, - }, - Attrs: attrs, +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalMaxPool node in the computation graph. Mainly used +// in unit test to make FractionalMaxPool deterministic. +// If not specified, defaults to false +func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["deterministic"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Produce a string tensor that encodes the state of a Reader. -// -// Not all Readers support being serialized, so this can produce an -// Unimplemented error. +// FractionalMaxPoolSeed sets the optional seed attribute to value. // -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { - if scope.Err() != nil { - return +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["seed"] = value } - opspec := tf.OpSpec{ - Type: "ReaderSerializeStateV2", - Input: []tf.Input{ - reader_handle, - }, +} + +// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Exits the current frame to its parent frame. +// Performs fractional max pooling on the input. +// +// Fractional max pooling is slightly different than regular max pooling. In +// regular max pooling, you downsize an input set by taking the maximum value of +// smaller N x N subsections of the set (often 2x2), and try to reduce the set by +// a factor of N, where N is an integer. Fractional max pooling, as you might +// expect from the word "fractional", means that the overall reduction ratio N +// does not have to be an integer. +// +// The sizes of the pooling regions are generated randomly but are fairly uniform. +// For example, let's look at the height dimension, and the constraints on the +// list of rows that will be pool boundaries. +// +// First we define the following: +// +// 1. input_row_length : the number of rows from the input set +// 2. output_row_length : which will be smaller than the input +// 3. alpha = input_row_length / output_row_length : our reduction ratio +// 4. K = floor(alpha) +// 5. row_pooling_sequence : this is the result list of pool boundary rows // -// Exit makes its input `data` available to the parent frame. +// Then, row_pooling_sequence should satisfy: +// +// 1. a[0] = 0 : the first value of the sequence is 0 +// 2. a[end] = input_row_length : the last value of the sequence is the size +// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +// 4. length(row_pooling_sequence) = output_row_length+1 +// +// For more details on fractional max pooling, see this paper: +// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) // // Arguments: -// data: The tensor to be made available to the parent frame. +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions +// respectively. // -// Returns The same tensor as `data`. -func Exit(scope *Scope, data tf.Output) (output tf.Output) { +// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. +func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Exit", - Input: []tf.Input{ - data, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a copy of the input tensor. -func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "Snapshot", + Type: "FractionalMaxPool", Input: []tf.Input{ - input, + value, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Scatter `updates` into a new (initially zero) tensor according to `indices`. -// -// Creates a new tensor by applying sparse `updates` to individual -// values or slices within a zero tensor of the given `shape` according to -// indices. This operator is the inverse of the @{tf.gather_nd} operator which -// extracts values or slices from a given tensor. -// -// **WARNING**: The order in which updates are applied is nondeterministic, so the -// output will be nondeterministic if `indices` contains duplicates. -// -// `indices` is an integer tensor containing indices into a new tensor of shape -// `shape`. The last dimension of `indices` can be at most the rank of `shape`: -// -// indices.shape[-1] <= shape.rank -// -// The last dimension of `indices` corresponds to indices into elements -// (if `indices.shape[-1] = shape.rank`) or slices -// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -// `shape`. `updates` is a tensor with shape -// -// indices.shape[:-1] + shape[indices.shape[-1]:] -// -// The simplest form of scatter is to insert individual elements in a tensor by -// index. For example, say we want to insert 4 scattered elements in a rank-1 -// tensor with 8 elements. -// -//
-// -//
-// -// In Python, this scatter operation would look like this: -// -// ```python -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// shape = tf.constant([8]) -// scatter = tf.scatter_nd(indices, updates, shape) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [0, 11, 0, 10, 9, 0, 0, 12] -// -// We can also, insert entire slices of a higher rank tensor all at once. For -// example, if we wanted to insert two slices in the first dimension of a -// rank-3 tensor with two matrices of new values. -// -//
-// -//
-// -// In Python, this scatter operation would look like this: -// -// ```python -// indices = tf.constant([[0], [2]]) -// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]], -// [[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]]]) -// shape = tf.constant([4, 4, 4]) -// scatter = tf.scatter_nd(indices, updates, shape) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], -// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], -// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], -// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] -// -// Arguments: -// indices: Index tensor. -// updates: Updates to scatter into output. -// shape: 1-D. The shape of the resulting tensor. +// Deprecated. Use TensorArraySizeV3 // -// Returns A new tensor with the given shape and updates applied according -// to the indices. -func ScatterNd(scope *Scope, indices tf.Output, updates tf.Output, shape tf.Output) (output tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 +func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ScatterNd", + Type: "TensorArraySizeV2", Input: []tf.Input{ - indices, updates, shape, + handle, flow_in, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// SpaceToDepthAttr is an optional argument to SpaceToDepth. -type SpaceToDepthAttr func(optionalAttr) +// Conv2DAttr is an optional argument to Conv2D. +type Conv2DAttr func(optionalAttr) -// SpaceToDepthDataFormat sets the optional data_format attribute to value. +// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. // If not specified, defaults to "NHWC" -func SpaceToDepthDataFormat(value string) SpaceToDepthAttr { +func Conv2DDataFormat(value string) Conv2DAttr { return func(m optionalAttr) { m["data_format"] = value } } -// SpaceToDepth for tensors of type T. -// -// Rearranges blocks of spatial data, into depth. More specifically, -// this op outputs a copy of the input tensor where values from the `height` -// and `width` dimensions are moved to the `depth` dimension. -// The attr `block_size` indicates the input block size. -// -// * Non-overlapping blocks of size `block_size x block size` are rearranged -// into depth at each location. -// * The depth of the output tensor is `block_size * block_size * input_depth`. -// * The Y, X coordinates within each block of the input become the high order -// component of the output channel index. -// * The input tensor's height and width must be divisible by block_size. -// -// The `data_format` attr specifies the layout of the input and output tensors -// with the following options: -// "NHWC": `[ batch, height, width, channels ]` -// "NCHW": `[ batch, channels, height, width ]` -// "NCHW_VECT_C": -// `qint8 [ batch, channels / 4, height, width, 4 ]` -// -// It is useful to consider the operation as transforming a 6-D Tensor. -// e.g. for data_format = NHWC, -// Each element in the input tensor can be specified via 6 coordinates, -// ordered by decreasing memory layout significance as: -// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates -// within the output image, bX, bY means coordinates -// within the input block, iC means input channels). -// The output would be a transpose to the following layout: -// n,oY,oX,bY,bX,iC -// -// This operation is useful for resizing the activations between convolutions -// (but keeping all data), e.g. instead of pooling. It is also useful for training -// purely convolutional models. -// -// For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and -// block_size = 2: -// -// ``` -// x = [[[[1], [2]], -// [[3], [4]]]] -// ``` -// -// This operation will output a tensor of shape `[1, 1, 1, 4]`: -// -// ``` -// [[[[1, 2, 3, 4]]]] -// ``` -// -// Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, -// the corresponding output will have a single element (i.e. width and height are -// both 1) and will have a depth of 4 channels (1 * block_size * block_size). -// The output element shape is `[1, 1, 4]`. -// -// For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` +// Conv2DDilations sets the optional dilations attribute to value. // -// This operation, for block_size of 2, will return the following tensor of shape -// `[1, 1, 1, 12]` +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv2DDilations(value []int64) Conv2DAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 2-D convolution given 4-D `input` and `filter` tensors. // -// ``` -// [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] -// ``` +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]`, this op +// performs the following: // -// Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: +// 1. Flattens the filter to a 2-D matrix with shape +// `[filter_height * filter_width * in_channels, output_channels]`. +// 2. Extracts image patches from the input tensor to form a *virtual* +// tensor of shape `[batch, out_height, out_width, +// filter_height * filter_width * in_channels]`. +// 3. For each patch, right-multiplies the filter matrix and the image patch +// vector. // -// ``` -// x = [[[[1], [2], [5], [6]], -// [[3], [4], [7], [8]], -// [[9], [10], [13], [14]], -// [[11], [12], [15], [16]]]] -// ``` +// In detail, with the default NHWC format, // -// the operator will return the following tensor of shape `[1 2 2 4]`: +// output[b, i, j, k] = +// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * +// filter[di, dj, q, k] // -// ``` -// x = [[[[1, 2, 3, 4], -// [5, 6, 7, 8]], -// [[9, 10, 11, 12], -// [13, 14, 15, 16]]]] -// ``` +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. // // Arguments: +// input: A 4-D tensor. The dimension order is interpreted according to the value +// of `data_format`, see below for details. +// filter: A 4-D tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]` +// strides: 1-D tensor of length 4. The stride of the sliding window for each +// dimension of `input`. The dimension order is determined by the value of +// `data_format`, see below for details. +// padding: The type of padding algorithm to use. // -// block_size: The size of the spatial block. -func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...SpaceToDepthAttr) (output tf.Output) { +// Returns A 4-D tensor. The dimension order is determined by the value of +// `data_format`, see below for details. +func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SpaceToDepth", + Type: "Conv2D", Input: []tf.Input{ - input, + input, filter, }, Attrs: attrs, } @@ -25136,536 +25085,471 @@ func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...S return op.Output(0) } -// AbortAttr is an optional argument to Abort. -type AbortAttr func(optionalAttr) +// StageAttr is an optional argument to Stage. +type StageAttr func(optionalAttr) -// AbortErrorMsg sets the optional error_msg attribute to value. +// StageCapacity sets the optional capacity attribute to value. // -// value: A string which is the message associated with the exception. -// If not specified, defaults to "" -func AbortErrorMsg(value string) AbortAttr { - return func(m optionalAttr) { - m["error_msg"] = value - } -} - -// AbortExitWithoutError sets the optional exit_without_error attribute to value. -// If not specified, defaults to false -func AbortExitWithoutError(value bool) AbortAttr { +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageCapacity(value int64) StageAttr { return func(m optionalAttr) { - m["exit_without_error"] = value + m["capacity"] = value } } -// Raise a exception to abort the process when called. -// -// If exit_without_error is true, the process will exit normally, -// otherwise it will exit with a SIGABORT signal. +// StageMemoryLimit sets the optional memory_limit attribute to value. // -// Returns nothing but an exception. +// value: The maximum number of bytes allowed for Tensors in the Staging Area. +// If > 0, inserts will block until sufficient space is available. +// If not specified, defaults to 0 // -// Returns the created operation. -func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Abort", - - Attrs: attrs, +// REQUIRES: value >= 0 +func StageMemoryLimit(value int64) StageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - return scope.AddOperation(opspec) } -// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler. -type UniformCandidateSamplerAttr func(optionalAttr) - -// UniformCandidateSamplerSeed sets the optional seed attribute to value. +// StageContainer sets the optional container attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr { +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func StageContainer(value string) StageAttr { return func(m optionalAttr) { - m["seed"] = value + m["container"] = value } } -// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// StageSharedName sets the optional shared_name attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr { +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func StageSharedName(value string) StageAttr { return func(m optionalAttr) { - m["seed2"] = value + m["shared_name"] = value } } -// Generates labels for candidate sampling with a uniform distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. +// Stage values similar to a lightweight Enqueue. // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// The basic functionality of this Op is similar to a queue with many +// fewer capabilities and options. This Op is optimized for performance. // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns the created operation. +func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "UniformCandidateSampler", + Type: "Stage", Input: []tf.Input{ - true_classes, + tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. -type FixedUnigramCandidateSamplerAttr func(optionalAttr) +// StagePeekAttr is an optional argument to StagePeek. +type StagePeekAttr func(optionalAttr) -// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value. +// StagePeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: Each valid line in this file (which should have a CSV-like format) -// corresponds to a valid word ID. IDs are in sequential order, starting from -// num_reserved_ids. The last entry in each line is expected to be a value -// corresponding to the count or relative probability. Exactly one of vocab_file -// and unigrams needs to be passed to this op. -// If not specified, defaults to "" -func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr { +// REQUIRES: value >= 0 +func StagePeekCapacity(value int64) StagePeekAttr { return func(m optionalAttr) { - m["vocab_file"] = value + m["capacity"] = value } } -// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value. +// StagePeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: The distortion is used to skew the unigram probability distribution. -// Each weight is first raised to the distortion's power before adding to the -// internal unigram distribution. As a result, distortion = 1.0 gives regular -// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives -// a uniform distribution. -// If not specified, defaults to 1 -func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr { +// REQUIRES: value >= 0 +func StagePeekMemoryLimit(value int64) StagePeekAttr { return func(m optionalAttr) { - m["distortion"] = value + m["memory_limit"] = value } } -// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value. -// -// value: Optionally some reserved IDs can be added in the range [0, -// ..., num_reserved_ids) by the users. One use case is that a special unknown -// word token is used as ID 0. These IDs will have a sampling probability of 0. -// If not specified, defaults to 0 -func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr { +// StagePeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StagePeekContainer(value string) StagePeekAttr { return func(m optionalAttr) { - m["num_reserved_ids"] = value + m["container"] = value } } -// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value. -// -// value: A sampler can be used to sample from a subset of the original range -// in order to speed up the whole computation through parallelism. This parameter -// (together with 'shard') indicates the number of partitions that are being -// used in the overall computation. -// If not specified, defaults to 1 -// -// REQUIRES: value >= 1 -func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr { +// StagePeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StagePeekSharedName(value string) StagePeekAttr { return func(m optionalAttr) { - m["num_shards"] = value + m["shared_name"] = value } } -// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value. +// Op peeks at the values at the specified index. If the // -// value: A sampler can be used to sample from a subset of the original range -// in order to speed up the whole computation through parallelism. This parameter -// (together with 'num_shards') indicates the particular partition number of a -// sampler op, when partitioning is being used. +// underlying container does not contain sufficient elements +// this op will block until it does. This Op is optimized for +// performance. +func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ...StagePeekAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StagePeek", + Input: []tf.Input{ + index, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("StagePeek", err) + return + } + return values +} + +// MapStageAttr is an optional argument to MapStage. +type MapStageAttr func(optionalAttr) + +// MapStageCapacity sets the optional capacity attribute to value. +// +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr { +func MapStageCapacity(value int64) MapStageAttr { return func(m optionalAttr) { - m["shard"] = value + m["capacity"] = value } } -// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value. +// MapStageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: A list of unigram counts or probabilities, one per ID in sequential -// order. Exactly one of vocab_file and unigrams should be passed to this op. -// If not specified, defaults to <> -func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr { +// REQUIRES: value >= 0 +func MapStageMemoryLimit(value int64) MapStageAttr { return func(m optionalAttr) { - m["unigrams"] = value + m["memory_limit"] = value } } -// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// MapStageContainer sets the optional container attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr { +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func MapStageContainer(value string) MapStageAttr { return func(m optionalAttr) { - m["seed"] = value + m["container"] = value } } -// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// MapStageSharedName sets the optional shared_name attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr { +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func MapStageSharedName(value string) MapStageAttr { return func(m optionalAttr) { - m["seed2"] = value + m["shared_name"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// A unigram sampler could use a fixed unigram distribution read from a -// file or passed in as an in-memory array instead of building up the distribution -// from data on the fly. There is also an option to skew the distribution by -// applying a distortion power to the weights. -// -// The vocabulary file should be in CSV-like format, with the last field -// being the weight associated with the word. +// Stage (key, values) in the underlying container which behaves like a hashtable. // -// For each batch, this op picks a single set of sampled candidate labels. +// Arguments: +// key: int64 // -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. // -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns the created operation. +func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FixedUnigramCandidateSampler", + Type: "MapStage", Input: []tf.Input{ - true_classes, + key, indices, tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return scope.AddOperation(opspec) } -// Elementwise computes the bitwise AND of `x` and `y`. +// MapUnstageAttr is an optional argument to MapUnstage. +type MapUnstageAttr func(optionalAttr) + +// MapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// The result will have those bits set, that are set in both `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func MapUnstageCapacity(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["capacity"] = value } - opspec := tf.OpSpec{ - Type: "BitwiseAnd", - Input: []tf.Input{ - x, y, - }, +} + +// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapUnstageMemoryLimit(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Elementwise computes the bitwise left-shift of `x` and `y`. +// MapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapUnstageContainer(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapUnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapUnstageSharedName(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the values associated with the key // -// If `y` is negative, or greater than or equal to the width of `x` in bits the -// result is implementation defined. -func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LeftShift", + Type: "MapUnstage", Input: []tf.Input{ - x, y, + key, indices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstage", err) + return + } + return values } -// Elementwise computes the bitwise right-shift of `x` and `y`. +// MapSizeAttr is an optional argument to MapSize. +type MapSizeAttr func(optionalAttr) + +// MapSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Performs a logical shift for unsigned integer types, and an arithmetic shift -// for signed integer types. +// REQUIRES: value >= 0 +func MapSizeCapacity(value int64) MapSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// If `y` is negative, or greater than or equal to than the width of `x` in bits -// the result is implementation defined. -func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// REQUIRES: value >= 0 +func MapSizeMemoryLimit(value int64) MapSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapSizeContainer(value string) MapSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapSizeSharedName(value string) MapSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of elements in the underlying container. +func MapSize(scope *Scope, dtypes []tf.DataType, optional ...MapSizeAttr) (size tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "RightShift", - Input: []tf.Input{ - x, y, - }, + Type: "MapSize", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Adjust the hue of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. -// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A delta is then applied all the hue values, -// and then remapped back to RGB colorspace. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// delta: A float delta to add to the hue. +// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. +type MapIncompleteSizeAttr func(optionalAttr) + +// MapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Returns The hue-adjusted image or images. -func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value } - opspec := tf.OpSpec{ - Type: "AdjustHue", - Input: []tf.Input{ - images, delta, - }, +} + +// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) +// MapIncompleteSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { +// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { return func(m optionalAttr) { - m["data_format"] = value + m["shared_name"] = value } } -// Computes gradients of average pooling function. -// -// Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -// -// Returns The backprop for input. -func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { +// Op returns the number of incomplete elements in the underlying container. +func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, + Type: "MapIncompleteSize", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. -type ParseSingleSequenceExampleAttr func(optionalAttr) - -// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. -// -// value: A list of Ncontext_sparse types; the data types of data in -// each context Feature given in context_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["context_sparse_types"] = value - } -} +// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. +type OrderedMapUnstageAttr func(optionalAttr) -// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. -// If not specified, defaults to <> +// OrderedMapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageCapacity(value int64) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["feature_list_dense_types"] = value + m["capacity"] = value } } -// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. -// -// value: A list of Ncontext_dense shapes; the shapes of data in -// each context Feature given in context_dense_keys. -// The number of elements in the Feature corresponding to context_dense_key[j] -// must always equal context_dense_shapes[j].NumEntries(). -// The shape of context_dense_values[j] will match context_dense_shapes[j]. -// If not specified, defaults to <> +// OrderedMapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageMemoryLimit(value int64) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["context_dense_shapes"] = value + m["memory_limit"] = value } } -// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. -// -// value: A list of Nfeature_list_sparse types; the data types -// of data in each FeatureList given in feature_list_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// OrderedMapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageContainer(value string) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["feature_list_sparse_types"] = value + m["container"] = value } } -// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. -// -// value: A list of Nfeature_list_dense shapes; the shapes of -// data in each FeatureList given in feature_list_dense_keys. -// The shape of each Feature in the FeatureList corresponding to -// feature_list_dense_key[j] must always equal -// feature_list_dense_shapes[j].NumEntries(). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { +// OrderedMapUnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageSharedName(value string) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["feature_list_dense_shapes"] = value + m["shared_name"] = value } } -// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. +// Op removes and returns the values associated with the key // -// Arguments: -// serialized: A scalar containing a binary serialized SequenceExample proto. -// feature_list_dense_missing_assumed_empty: A vector listing the -// FeatureList keys which may be missing from the SequenceExample. If the -// associated FeatureList is missing, it is treated as empty. By default, -// any FeatureList not listed in this vector must exist in the SequenceExample. -// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). -// The keys expected in the Examples' features associated with context_sparse -// values. -// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' context features associated with -// dense values. -// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors -// (scalars). The keys expected in the FeatureLists associated with sparse -// values. -// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' feature_lists associated -// with lists of dense values. -// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). -// context_dense_defaults[j] provides default values -// when the SequenceExample's context map lacks context_dense_key[j]. -// If an empty Tensor is provided for context_dense_defaults[j], -// then the Feature context_dense_keys[j] is required. -// The input type is inferred from context_dense_defaults[j], even when it's -// empty. If context_dense_defaults[j] is not empty, its shape must match -// context_dense_shapes[j]. -// debug_name: A scalar containing the name of the serialized proto. -// May contain, for example, table key (descriptive) name for the -// corresponding serialized proto. This is purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty scalar if no name is available. -func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ParseSingleSequenceExample", + Type: "OrderedMapUnstage", Input: []tf.Input{ - serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, + key, indices, }, Attrs: attrs, } @@ -25675,138 +25559,85 @@ func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list } var idx int var err error - if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapUnstage", err) return } - return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values + return values } -// DecodeWavAttr is an optional argument to DecodeWav. -type DecodeWavAttr func(optionalAttr) +// OrderedMapSizeAttr is an optional argument to OrderedMapSize. +type OrderedMapSizeAttr func(optionalAttr) -// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// OrderedMapSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: Number of sample channels wanted. -// If not specified, defaults to -1 -func DecodeWavDesiredChannels(value int64) DecodeWavAttr { +// REQUIRES: value >= 0 +func OrderedMapSizeCapacity(value int64) OrderedMapSizeAttr { return func(m optionalAttr) { - m["desired_channels"] = value + m["capacity"] = value } } -// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// OrderedMapSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: Length of audio requested. -// If not specified, defaults to -1 -func DecodeWavDesiredSamples(value int64) DecodeWavAttr { +// REQUIRES: value >= 0 +func OrderedMapSizeMemoryLimit(value int64) OrderedMapSizeAttr { return func(m optionalAttr) { - m["desired_samples"] = value + m["memory_limit"] = value } } -// Decode a 16-bit PCM WAV file to a float tensor. -// -// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. -// -// When desired_channels is set, if the input contains fewer channels than this -// then the last channel will be duplicated to give the requested number, else if -// the input has more channels than requested then the additional channels will be -// ignored. -// -// If desired_samples is set, then the audio will be cropped or padded with zeroes -// to the requested length. -// -// The first output contains a Tensor with the content of the audio samples. The -// lowest dimension will be the number of channels, and the second will be the -// number of samples. For example, a ten-sample-long stereo WAV file should give an -// output shape of [10, 2]. -// -// Arguments: -// contents: The WAV-encoded audio, usually from a file. -// -// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. -func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { +// OrderedMapSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapSizeContainer(value string) OrderedMapSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapSizeSharedName(value string) OrderedMapSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of elements in the underlying container. +func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeWav", - Input: []tf.Input{ - contents, - }, + Type: "OrderedMapSize", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// UniqueAttr is an optional argument to Unique. -type UniqueAttr func(optionalAttr) +// ShapeNAttr is an optional argument to ShapeN. +type ShapeNAttr func(optionalAttr) -// UniqueOutIdx sets the optional out_idx attribute to value. +// ShapeNOutType sets the optional out_type attribute to value. // If not specified, defaults to DT_INT32 -func UniqueOutIdx(value tf.DataType) UniqueAttr { +func ShapeNOutType(value tf.DataType) ShapeNAttr { return func(m optionalAttr) { - m["out_idx"] = value + m["out_type"] = value } } -// Finds unique elements in a 1-D tensor. -// -// This operation returns a tensor `y` containing all of the unique elements of `x` -// sorted in the same order that they occur in `x`. This operation also returns a -// tensor `idx` the same size as `x` that contains the index of each value of `x` -// in the unique output `y`. In other words: -// -// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` -// -// For example: -// -// ``` -// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] -// y, idx = unique(x) -// y ==> [1, 2, 4, 7, 8] -// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] -// ``` -// -// Arguments: -// x: 1-D. +// Returns shape of tensors. // -// Returns 1-D.1-D. -func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) { +// This operation returns N 1-D integer tensors representing shape of `input[i]s`. +func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { if scope.Err() != nil { return } @@ -25815,393 +25646,315 @@ func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx a(attrs) } opspec := tf.OpSpec{ - Type: "Unique", + Type: "ShapeN", Input: []tf.Input{ - x, + tf.OutputList(input), }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Concatenates a list of `N` tensors along the first dimension. -// -// The input tensors are all required to have size 1 in the first dimension. -// -// For example: -// -// ``` -// # 'x' is [[1, 4]] -// # 'y' is [[2, 5]] -// # 'z' is [[3, 6]] -// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// ``` -// -// The difference between concat and parallel_concat is that concat requires all -// of the inputs be computed before the operation will begin but doesn't require -// that the input shapes be known during graph construction. Parallel concat -// will copy pieces of the input into the output as they become available, in -// some situations this can provide a performance benefit. -// -// Arguments: -// values: Tensors to be concatenated. All must have size 1 in the first dimension -// and same shape. -// shape: the final shape of the result; should be equal to the shapes of any input -// but with the number of input values in the first dimension. -// -// Returns The concatenated tensor. -func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "ParallelConcat", - Input: []tf.Input{ - tf.OutputList(values), - }, - Attrs: attrs, + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("ShapeN", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return output } -// Concatenates tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. +// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler. +type UniformCandidateSamplerAttr func(optionalAttr) + +// UniformCandidateSamplerSeed sets the optional seed attribute to value. // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Concat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), - }, +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Compute the lower regularized incomplete Gamma function `Q(a, x)`. -// -// The lower regularized incomplete Gamma function is defined as: -// -// -// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) -// -// where -// -// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) -// -// is the lower incomplete Gamma function. +// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete -// Gamma function. -func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igamma", - Input: []tf.Input{ - a, x, - }, +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes offsets of concat inputs within its output. +// Generates labels for candidate sampling with a uniform distribution. // -// For example: +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // -// ``` -// # 'x' is [2, 2, 7] -// # 'y' is [2, 3, 7] -// # 'z' is [2, 5, 7] -// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] -// ``` +// For each batch, this op picks a single set of sampled candidate labels. // -// This is typically used by gradient computations for a concat operation. +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// concat_dim: The dimension along which to concatenate. -// shape: The `N` int32 vectors representing shape of tensors being concatenated. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). // -// Returns The `N` int32 vectors representing the starting offset -// of input tensors within the concatenated output. -func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ConcatOffset", + Type: "UniformCandidateSampler", Input: []tf.Input{ - concat_dim, tf.OutputList(shape), + true_classes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return + return op.Output(0), op.Output(1), op.Output(2) +} + +// CTCLossAttr is an optional argument to CTCLoss. +type CTCLossAttr func(optionalAttr) + +// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. +// +// value: Scalar, if true then repeated labels are +// collapsed prior to the CTC calculation. +// If not specified, defaults to false +func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["preprocess_collapse_repeated"] = value } - var idx int - var err error - if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { - scope.UpdateErr("ConcatOffset", err) - return +} + +// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. +// +// value: Scalar. If set to false, *during* CTC calculation +// repeated non-blank labels will not be merged and are interpreted as +// individual labels. This is a simplified version of CTC. +// If not specified, defaults to true +func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["ctc_merge_repeated"] = value + } +} + +// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. +// +// value: Scalar. If set to true, during CTC +// calculation, items that have longer output sequences than input sequences +// are skipped: they don't contribute to the loss term and have zero-gradient. +// If not specified, defaults to false +func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["ignore_longer_outputs_than_inputs"] = value } - return offset } -// Splits a tensor into `num_split` tensors along one dimension. +// Calculates the CTC Loss (log probability) for each batch entry. Also calculates +// +// the gradient. This class performs the softmax operation for you, so inputs +// should be e.g. linear projections of outputs by an LSTM. // // Arguments: -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. -// value: The tensor to split. -// num_split: The number of ways to split. Must evenly divide -// `value.shape[split_dim]`. +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// labels_indices: The indices of a `SparseTensor`. +// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for +// `(batch b, time t)`. +// labels_values: The values (labels) associated with the given batch and time. +// sequence_length: A vector containing sequence lengths (batch). // -// Returns They are identically shaped tensors, whose shape matches that of `value` -// except along `axis`, where their sizes are -// `values.shape[split_dim] / num_split`. -func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) { +// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: +// `(max_time x batch_size x num_classes)`. +func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_split": num_split} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Split", + Type: "CTCLoss", Input: []tf.Input{ - axis, value, + inputs, labels_indices, labels_values, sequence_length, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Split", err) - return + return op.Output(0), op.Output(1) +} + +// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. +type CTCGreedyDecoderAttr func(optionalAttr) + +// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// +// value: If True, merge repeated classes in output. +// If not specified, defaults to false +func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { + return func(m optionalAttr) { + m["merge_repeated"] = value } - return output } -// Splits a tensor into `num_split` tensors along one dimension. +// Performs greedy decoding on the logits given in inputs. // -// Arguments: -// value: The tensor to split. -// size_splits: list containing the sizes of each output tensor along the split -// dimension. Must sum to the dimension of value along split_dim. -// Can contain one -1 indicating that dimension is to be inferred. -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. +// A note about the attribute merge_repeated: if enabled, when +// consecutive logits' maximum indices are the same, only the first of +// these is emitted. Labeling the blank '*', the sequence "A B B * B B" +// becomes "A B B" if merge_repeated = True and "A B B B B" if +// merge_repeated = False. // +// Regardless of the value of merge_repeated, if the maximum index of a given +// time and batch corresponds to the blank, index `(num_classes - 1)`, no new +// element is emitted. // -// Returns Tensors whose shape matches that of `value` -// except along `axis`, where their sizes are -// `size_splits[i]`. -func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// +// Returns Indices matrix, size `(total_decoded_outputs x 2)`, +// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, +// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. +// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence +// log-probabilities. +func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_split": num_split} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SplitV", + Type: "CTCGreedyDecoder", Input: []tf.Input{ - value, size_splits, axis, + inputs, sequence_length, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("SplitV", err) - return - } - return output + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Gives a guarantee to the TF runtime that the input tensor is a constant. +// Forwards `data` to the output port determined by `pred`. // -// The runtime is then free to make optimizations based on this. +// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, +// the data goes to `output_false`. // -// Only accepts value typed tensors as inputs and rejects resource variable handles -// as input. +// See also `RefSwitch` and `Merge`. // -// Returns the input tensor without modification. -func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { +// Arguments: +// data: The tensor to be forwarded to the appropriate output. +// pred: A scalar that specifies which output port will receive data. +// +// Returns If `pred` is false, data will be forwarded to this output.If `pred` is true, data will be forwarded to this output. +func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GuaranteeConst", + Type: "Switch", Input: []tf.Input{ - input, + data, pred, }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Returns a tensor of zeros with the same shape and type as x. +// Add all input tensors element wise. // // Arguments: -// x: a tensor of type T. -// -// Returns a tensor of the same shape and type as x but filled with zeros. -func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { +// inputs: Must all be the same size and shape. +func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ZerosLike", + Type: "AddN", Input: []tf.Input{ - x, + tf.OutputList(inputs), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Flips all bits elementwise. +// EnterAttr is an optional argument to Enter. +type EnterAttr func(optionalAttr) + +// EnterIsConstant sets the optional is_constant attribute to value. // -// The result will have exactly those bits set, that are not set in `x`. The -// computation is performed on the underlying representation of x. -func Invert(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Invert", - Input: []tf.Input{ - x, - }, +// value: If true, the output is constant within the child frame. +// If not specified, defaults to false +func EnterIsConstant(value bool) EnterAttr { + return func(m optionalAttr) { + m["is_constant"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// DequantizeAttr is an optional argument to Dequantize. -type DequantizeAttr func(optionalAttr) - -// DequantizeMode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func DequantizeMode(value string) DequantizeAttr { +// EnterParallelIterations sets the optional parallel_iterations attribute to value. +// +// value: The number of iterations allowed to run in parallel. +// If not specified, defaults to 10 +func EnterParallelIterations(value int64) EnterAttr { return func(m optionalAttr) { - m["mode"] = value + m["parallel_iterations"] = value } } -// Dequantize the 'input' tensor into a float Tensor. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// if T == qint8, in[i] += (range(T) + 1)/ 2.0 -// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// If the input comes from a QuantizedRelu6, the output type is -// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is -// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. -// Dequantize on quint8 will take each value, cast to float, and multiply -// by 6 / 255. -// Note that if quantizedtype is qint8, the operation will additionally add -// each value by 128 prior to casting. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ```c++ -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = range / num_discrete_values -// const double offset_input = static_cast(input) - lowest_quantized; -// result = range_min + ((input - numeric_limits::min()) * range_scale) -// ``` -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` -// -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` -// -// From this we compute our scaling factor, s: -// ```c++ -// s = (2 * m) / (max_fixed - min_fixed) -// ``` +// Creates or finds a child frame, and makes `data` available to the child frame. // -// Now we can dequantize the elements of our tensor: -// ```c++ -// result = input * s -// ``` +// This op is used together with `Exit` to create loops in the graph. +// The unique `frame_name` is used by the `Executor` to identify frames. If +// `is_constant` is true, `output` is a constant in the child frame; otherwise +// it may be changed in the child frame. At most `parallel_iterations` iterations +// are run in parallel in the child frame. // // Arguments: +// data: The tensor to be made available to the child frame. +// frame_name: The name of the child frame. // -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. -func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { +// Returns The same tensor as `data`. +func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"frame_name": frame_name} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Dequantize", + Type: "Enter", Input: []tf.Input{ - input, min_range, max_range, + data, }, Attrs: attrs, } @@ -26209,136 +25962,94 @@ func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf return op.Output(0) } -// Returns the element-wise max of two SparseTensors. +// Produce a string tensor that encodes the state of a Reader. // -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// Not all Readers support being serialized, so this can produce an +// Unimplemented error. // // Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. -// -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. -func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { +// reader_handle: Handle to a Reader. +func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSparseMaximum", + Type: "ReaderSerializeStateV2", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, + reader_handle, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Returns a batched matrix tensor with new batched diagonal values. -// -// Given `input` and `diagonal`, this operation returns a tensor with the -// same shape and values as `input`, except for the main diagonal of the -// innermost matrices. These will be overwritten by the values in `diagonal`. -// -// The output is computed as follows: -// -// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has -// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a -// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: +// Exits the current frame to its parent frame. // -// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. -// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. +// Exit makes its input `data` available to the parent frame. // // Arguments: -// input: Rank `k+1`, where `k >= 1`. -// diagonal: Rank `k`, where `k >= 1`. +// data: The tensor to be made available to the parent frame. // -// Returns Rank `k+1`, with `output.shape = input.shape`. -func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { +// Returns The same tensor as `data`. +func Exit(scope *Scope, data tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixSetDiag", + Type: "Exit", Input: []tf.Input{ - input, diagonal, + data, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// EditDistanceAttr is an optional argument to EditDistance. -type EditDistanceAttr func(optionalAttr) +// Returns a copy of the input tensor. +func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Snapshot", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// EditDistanceNormalize sets the optional normalize attribute to value. -// -// value: boolean (if true, edit distances are normalized by length of truth). +// AbortAttr is an optional argument to Abort. +type AbortAttr func(optionalAttr) + +// AbortErrorMsg sets the optional error_msg attribute to value. // -// The output is: -// If not specified, defaults to true -func EditDistanceNormalize(value bool) EditDistanceAttr { +// value: A string which is the message associated with the exception. +// If not specified, defaults to "" +func AbortErrorMsg(value string) AbortAttr { return func(m optionalAttr) { - m["normalize"] = value + m["error_msg"] = value } } -// Computes the (possibly normalized) Levenshtein Edit Distance. -// -// The inputs are variable-length sequences provided by SparseTensors -// (hypothesis_indices, hypothesis_values, hypothesis_shape) -// and -// (truth_indices, truth_values, truth_shape). -// -// The inputs are: -// -// Arguments: -// hypothesis_indices: The indices of the hypothesis list SparseTensor. -// This is an N x R int64 matrix. -// hypothesis_values: The values of the hypothesis list SparseTensor. -// This is an N-length vector. -// hypothesis_shape: The shape of the hypothesis list SparseTensor. -// This is an R-length vector. -// truth_indices: The indices of the truth list SparseTensor. -// This is an M x R int64 matrix. -// truth_values: The values of the truth list SparseTensor. -// This is an M-length vector. -// truth_shape: truth indices, vector. -// -// Returns A dense float tensor with rank R - 1. -// -// For the example input: -// -// // hypothesis represents a 2x1 matrix with variable-length values: -// // (0,0) = ["a"] -// // (1,0) = ["b"] -// hypothesis_indices = [[0, 0, 0], -// [1, 0, 0]] -// hypothesis_values = ["a", "b"] -// hypothesis_shape = [2, 1, 1] +// AbortExitWithoutError sets the optional exit_without_error attribute to value. +// If not specified, defaults to false +func AbortExitWithoutError(value bool) AbortAttr { + return func(m optionalAttr) { + m["exit_without_error"] = value + } +} + +// Raise a exception to abort the process when called. // -// // truth represents a 2x2 matrix with variable-length values: -// // (0,0) = [] -// // (0,1) = ["a"] -// // (1,0) = ["b", "c"] -// // (1,1) = ["a"] -// truth_indices = [[0, 1, 0], -// [1, 0, 0], -// [1, 0, 1], -// [1, 1, 0]] -// truth_values = ["a", "b", "c", "a"] -// truth_shape = [2, 2, 2] -// normalize = true +// If exit_without_error is true, the process will exit normally, +// otherwise it will exit with a SIGABORT signal. // -// The output will be: +// Returns nothing but an exception. // -// // output is a 2x2 matrix with edit distances normalized by truth lengths. -// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis -// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis -func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) { +// Returns the created operation. +func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -26347,256 +26058,298 @@ func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values a(attrs) } opspec := tf.OpSpec{ - Type: "EditDistance", - Input: []tf.Input{ - hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape, - }, + Type: "Abort", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Gather slices from `params` into a Tensor with shape specified by `indices`. -// -// `indices` is an K-dimensional integer tensor, best thought of as a -// (K-1)-dimensional tensor of indices into `params`, where each element defines a -// slice of `params`: -// -// output[i_0, ..., i_{K-2}] = params[indices[i0, ..., i_{K-2}]] -// -// Whereas in @{tf.gather} `indices` defines slices into the first -// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the -// first `N` dimensions of `params`, where `N = indices.shape[-1]`. -// -// The last dimension of `indices` can be at most the rank of -// `params`: -// -// indices.shape[-1] <= params.rank -// -// The last dimension of `indices` corresponds to elements -// (if `indices.shape[-1] == params.rank`) or slices -// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]` -// of `params`. The output tensor has shape -// -// indices.shape[:-1] + params.shape[indices.shape[-1]:] -// -// Some examples below. -// -// Simple indexing into a matrix: -// -// ```python -// indices = [[0, 0], [1, 1]] -// params = [['a', 'b'], ['c', 'd']] -// output = ['a', 'd'] -// ``` -// -// Slice indexing into a matrix: -// -// ```python -// indices = [[1], [0]] -// params = [['a', 'b'], ['c', 'd']] -// output = [['c', 'd'], ['a', 'b']] -// ``` -// -// Indexing into a 3-tensor: +// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. +type FixedUnigramCandidateSamplerAttr func(optionalAttr) + +// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value. // -// ```python -// indices = [[1]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [[['a1', 'b1'], ['c1', 'd1']]] +// value: Each valid line in this file (which should have a CSV-like format) +// corresponds to a valid word ID. IDs are in sequential order, starting from +// num_reserved_ids. The last entry in each line is expected to be a value +// corresponding to the count or relative probability. Exactly one of vocab_file +// and unigrams needs to be passed to this op. +// If not specified, defaults to "" +func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["vocab_file"] = value + } +} + +// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value. // +// value: The distortion is used to skew the unigram probability distribution. +// Each weight is first raised to the distortion's power before adding to the +// internal unigram distribution. As a result, distortion = 1.0 gives regular +// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives +// a uniform distribution. +// If not specified, defaults to 1 +func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["distortion"] = value + } +} + +// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value. // -// indices = [[0, 1], [1, 0]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [['c0', 'd0'], ['a1', 'b1']] +// value: Optionally some reserved IDs can be added in the range [0, +// ..., num_reserved_ids) by the users. One use case is that a special unknown +// word token is used as ID 0. These IDs will have a sampling probability of 0. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["num_reserved_ids"] = value + } +} + +// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value. // +// value: A sampler can be used to sample from a subset of the original range +// in order to speed up the whole computation through parallelism. This parameter +// (together with 'shard') indicates the number of partitions that are being +// used in the overall computation. +// If not specified, defaults to 1 // -// indices = [[0, 0, 1], [1, 0, 1]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = ['b0', 'b1'] -// ``` +// REQUIRES: value >= 1 +func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["num_shards"] = value + } +} + +// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value. // -// Batched indexing into a matrix: +// value: A sampler can be used to sample from a subset of the original range +// in order to speed up the whole computation through parallelism. This parameter +// (together with 'num_shards') indicates the particular partition number of a +// sampler op, when partitioning is being used. +// If not specified, defaults to 0 // -// ```python -// indices = [[[0, 0]], [[0, 1]]] -// params = [['a', 'b'], ['c', 'd']] -// output = [['a'], ['b']] -// ``` +// REQUIRES: value >= 0 +func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["shard"] = value + } +} + +// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value. // -// Batched slice indexing into a matrix: +// value: A list of unigram counts or probabilities, one per ID in sequential +// order. Exactly one of vocab_file and unigrams should be passed to this op. +// If not specified, defaults to <> +func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["unigrams"] = value + } +} + +// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value. // -// ```python -// indices = [[[1]], [[0]]] -// params = [['a', 'b'], ['c', 'd']] -// output = [[['c', 'd']], [['a', 'b']]] -// ``` +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// Batched indexing into a 3-tensor: +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. // -// ```python -// indices = [[[1]], [[0]]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [[[['a1', 'b1'], ['c1', 'd1']]], -// [[['a0', 'b0'], ['c0', 'd0']]]] +// A unigram sampler could use a fixed unigram distribution read from a +// file or passed in as an in-memory array instead of building up the distribution +// from data on the fly. There is also an option to skew the distribution by +// applying a distortion power to the weights. // -// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [[['c0', 'd0'], ['a1', 'b1']], -// [['a0', 'b0'], ['c1', 'd1']]] +// The vocabulary file should be in CSV-like format, with the last field +// being the weight associated with the word. // +// For each batch, this op picks a single set of sampled candidate labels. // -// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [['b0', 'b1'], ['d0', 'c1']] -// ``` +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. // // Arguments: -// params: The tensor from which to gather values. -// indices: Index tensor. +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). // -// Returns Values from `params` gathered from indices given by `indices`, with -// shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`. -func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Output) { +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "GatherNd", + Type: "FixedUnigramCandidateSampler", Input: []tf.Input{ - params, indices, + true_classes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Eagerly executes a python function to compute func(input)->output. The +// Elementwise computes the bitwise AND of `x` and `y`. // -// semantics of the input, output, and attributes are the same as those for -// PyFunc. -func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) { +// The result will have those bits set, that are set in both `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"token": token, "Tout": Tout} opspec := tf.OpSpec{ - Type: "EagerPyFunc", + Type: "BitwiseAnd", Input: []tf.Input{ - tf.OutputList(input), + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Elementwise computes the bitwise left-shift of `x` and `y`. +// +// If `y` is negative, or greater than or equal to the width of `x` in bits the +// result is implementation defined. +func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("EagerPyFunc", err) - return + opspec := tf.OpSpec{ + Type: "LeftShift", + Input: []tf.Input{ + x, y, + }, } - return output + op := scope.AddOperation(opspec) + return op.Output(0) } -// Stops gradient computation. -// -// When executed in a graph, this op outputs its input tensor as-is. -// -// When building ops to compute gradients, this op prevents the contribution of -// its inputs to be taken into account. Normally, the gradient generator adds ops -// to a graph to compute the derivatives of a specified 'loss' by recursively -// finding out inputs that contributed to its computation. If you insert this op -// in the graph it inputs are masked from the gradient generator. They are not -// taken into account for computing gradients. +// Elementwise computes the bitwise right-shift of `x` and `y`. // -// This is useful any time you want to compute a value with TensorFlow but need -// to pretend that the value was a constant. Some examples include: +// Performs a logical shift for unsigned integer types, and an arithmetic shift +// for signed integer types. // -// * The *EM* algorithm where the *M-step* should not involve backpropagation -// through the output of the *E-step*. -// * Contrastive divergence training of Boltzmann machines where, when -// differentiating the energy function, the training must not backpropagate -// through the graph that generated the samples from the model. -// * Adversarial training, where no backprop should happen through the adversarial -// example generation process. -func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { +// If `y` is negative, or greater than or equal to than the width of `x` in bits +// the result is implementation defined. +func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "StopGradient", + Type: "RightShift", Input: []tf.Input{ - input, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes asin of x element-wise. -func Asin(scope *Scope, x tf.Output) (y tf.Output) { +// Adjust the hue of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A delta is then applied all the hue values, +// and then remapped back to RGB colorspace. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// delta: A float delta to add to the hue. +// +// Returns The hue-adjusted image or images. +func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Asin", + Type: "AdjustHue", Input: []tf.Input{ - x, + images, delta, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// PreventGradientAttr is an optional argument to PreventGradient. -type PreventGradientAttr func(optionalAttr) +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) -// PreventGradientMessage sets the optional message attribute to value. +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. // -// value: Will be printed in the error when anyone tries to differentiate -// this operation. -// If not specified, defaults to "" -func PreventGradientMessage(value string) PreventGradientAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { return func(m optionalAttr) { - m["message"] = value + m["data_format"] = value } } -// An identity op that triggers an error if a gradient is requested. -// -// When executed in a graph, this op outputs its input tensor as-is. -// -// When building ops to compute gradients, the TensorFlow gradient system -// will return an error when trying to lookup the gradient of this op, -// because no gradient must ever be registered for this function. This -// op exists to prevent subtle bugs from silently returning unimplemented -// gradients in some corner cases. +// Computes gradients of average pooling function. // // Arguments: -// input: any tensor. +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns the same input tensor. -func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "PreventGradient", + Type: "AvgPool3DGrad", Input: []tf.Input{ - input, + orig_input_shape, grad, }, Attrs: attrs, } @@ -26604,86 +26357,115 @@ func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientA return op.Output(0) } -// Checks a tensor for NaN and Inf values. -// -// When run, reports an `InvalidArgument` error if `tensor` has any values -// that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. +// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. +type ParseSingleSequenceExampleAttr func(optionalAttr) + +// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. // -// Arguments: +// value: A list of Ncontext_sparse types; the data types of data in +// each context Feature given in context_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> // -// message: Prefix of the error message. -func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"message": message} - opspec := tf.OpSpec{ - Type: "CheckNumerics", - Input: []tf.Input{ - tensor, - }, - Attrs: attrs, +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_sparse_types"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Shuffle dimensions of x according to a permutation and conjugate the result. +// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. +// If not specified, defaults to <> // -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -// `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])` -func ConjugateTranspose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConjugateTranspose", - Input: []tf.Input{ - x, perm, - }, +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_types"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// UniqueV2Attr is an optional argument to UniqueV2. -type UniqueV2Attr func(optionalAttr) - -// UniqueV2OutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func UniqueV2OutIdx(value tf.DataType) UniqueV2Attr { +// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. +// +// value: A list of Ncontext_dense shapes; the shapes of data in +// each context Feature given in context_dense_keys. +// The number of elements in the Feature corresponding to context_dense_key[j] +// must always equal context_dense_shapes[j].NumEntries(). +// The shape of context_dense_values[j] will match context_dense_shapes[j]. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { return func(m optionalAttr) { - m["out_idx"] = value + m["context_dense_shapes"] = value } } -// Finds unique elements in a 1-D tensor. +// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. // -// This operation returns a tensor `y` containing all of the unique elements of `x` -// sorted in the same order that they occur in `x`. This operation also returns a -// tensor `idx` the same size as `x` that contains the index of each value of `x` -// in the unique output `y`. In other words: +// value: A list of Nfeature_list_sparse types; the data types +// of data in each FeatureList given in feature_list_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> // -// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_sparse_types"] = value + } +} + +// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. // -// For example: +// value: A list of Nfeature_list_dense shapes; the shapes of +// data in each FeatureList given in feature_list_dense_keys. +// The shape of each Feature in the FeatureList corresponding to +// feature_list_dense_key[j] must always equal +// feature_list_dense_shapes[j].NumEntries(). +// If not specified, defaults to <> // -// ``` -// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] -// y, idx = unique(x) -// y ==> [1, 2, 4, 7, 8] -// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] -// ``` +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_shapes"] = value + } +} + +// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. // // Arguments: -// x: A `Tensor`. -// axis: A `Tensor` of type `int64` (default: 0). The axis of the Tensor to -// find the unique elements. -// -// Returns A `Tensor`. Unique elements along the `axis` of `Tensor` x.A 1-D Tensor. Has the same type as x that contains the index of each -// value of x in the output y. -func UniqueV2(scope *Scope, x tf.Output, axis tf.Output, optional ...UniqueV2Attr) (y tf.Output, idx tf.Output) { +// serialized: A scalar containing a binary serialized SequenceExample proto. +// feature_list_dense_missing_assumed_empty: A vector listing the +// FeatureList keys which may be missing from the SequenceExample. If the +// associated FeatureList is missing, it is treated as empty. By default, +// any FeatureList not listed in this vector must exist in the SequenceExample. +// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). +// The keys expected in the Examples' features associated with context_sparse +// values. +// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' context features associated with +// dense values. +// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors +// (scalars). The keys expected in the FeatureLists associated with sparse +// values. +// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' feature_lists associated +// with lists of dense values. +// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). +// context_dense_defaults[j] provides default values +// when the SequenceExample's context map lacks context_dense_key[j]. +// If an empty Tensor is provided for context_dense_defaults[j], +// then the Feature context_dense_keys[j] is required. +// The input type is inferred from context_dense_defaults[j], even when it's +// empty. If context_dense_defaults[j] is not empty, its shape must match +// context_dense_shapes[j]. +// debug_name: A scalar containing the name of the serialized proto. +// May contain, for example, table key (descriptive) name for the +// corresponding serialized proto. This is purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty scalar if no name is available. +func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { if scope.Err() != nil { return } @@ -26692,101 +26474,98 @@ func UniqueV2(scope *Scope, x tf.Output, axis tf.Output, optional ...UniqueV2Att a(attrs) } opspec := tf.OpSpec{ - Type: "UniqueV2", + Type: "ParseSingleSequenceExample", Input: []tf.Input{ - x, axis, + serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Return a slice from 'input'. -// -// The output tensor is a tensor with dimensions described by 'size' -// whose values are extracted from 'input' starting at the offsets in -// 'begin'. -// -// *Requirements*: -// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) -// -// Arguments: -// -// begin: begin[i] specifies the offset into the 'i'th dimension of -// 'input' to slice from. -// size: size[i] specifies the number of elements of the 'i'th dimension -// of 'input' to slice. If size[i] is -1, all remaining elements in dimension -// i are included in the slice (i.e. this is equivalent to setting -// size[i] = input.dim_size(i) - begin[i]). -func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Slice", - Input: []tf.Input{ - input, begin, size, - }, + var idx int + var err error + if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StridedSliceGradAttr is an optional argument to StridedSliceGrad. -type StridedSliceGradAttr func(optionalAttr) - -// StridedSliceGradBeginMask sets the optional begin_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["begin_mask"] = value + if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } -} - -// StridedSliceGradEndMask sets the optional end_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradEndMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["end_mask"] = value + if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } -} - -// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["ellipsis_mask"] = value + if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return } + return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values } -// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr { +// DecodeWavAttr is an optional argument to DecodeWav. +type DecodeWavAttr func(optionalAttr) + +// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// +// value: Number of sample channels wanted. +// If not specified, defaults to -1 +func DecodeWavDesiredChannels(value int64) DecodeWavAttr { return func(m optionalAttr) { - m["new_axis_mask"] = value + m["desired_channels"] = value } } -// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr { +// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// +// value: Length of audio requested. +// If not specified, defaults to -1 +func DecodeWavDesiredSamples(value int64) DecodeWavAttr { return func(m optionalAttr) { - m["shrink_axis_mask"] = value + m["desired_samples"] = value } } -// Returns the gradient of `StridedSlice`. +// Decode a 16-bit PCM WAV file to a float tensor. // -// Since `StridedSlice` cuts out pieces of its `input` which is size -// `shape`, its gradient will have the same shape (which is passed here -// as `shape`). The gradient will be zero in any element that the slice -// does not select. +// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. // -// Arguments are the same as StridedSliceGrad with the exception that -// `dy` is the input gradient to be propagated and `shape` is the -// shape of `StridedSlice`'s `input`. -func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) { +// When desired_channels is set, if the input contains fewer channels than this +// then the last channel will be duplicated to give the requested number, else if +// the input has more channels than requested then the additional channels will be +// ignored. +// +// If desired_samples is set, then the audio will be cropped or padded with zeroes +// to the requested length. +// +// The first output contains a Tensor with the content of the audio samples. The +// lowest dimension will be the number of channels, and the second will be the +// number of samples. For example, a ten-sample-long stereo WAV file should give an +// output shape of [10, 2]. +// +// Arguments: +// contents: The WAV-encoded audio, usually from a file. +// +// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. +func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { if scope.Err() != nil { return } @@ -26795,70 +26574,50 @@ func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "StridedSliceGrad", + Type: "DecodeWav", Input: []tf.Input{ - shape, begin, end, strides, dy, + contents, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the gradient of `Tile`. -// -// DEPRECATED at GraphDef version 3: TileGrad has been replaced with reduce_sum -// -// Since `Tile` takes an input and repeats the input `multiples` times -// along each dimension, `TileGrad` takes in `multiples` and aggregates -// each repeated tile of `input` into `output`. -func TileGrad(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TileGrad", - Input: []tf.Input{ - input, multiples, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. -type DataFormatDimMapAttr func(optionalAttr) +// UniqueAttr is an optional argument to Unique. +type UniqueAttr func(optionalAttr) -// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. -// -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { +// UniqueOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func UniqueOutIdx(value tf.DataType) UniqueAttr { return func(m optionalAttr) { - m["src_format"] = value + m["out_idx"] = value } } -// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. +// Finds unique elements in a 1-D tensor. // -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { - return func(m optionalAttr) { - m["dst_format"] = value - } -} - -// Returns the dimension index in the destination data format given the one in +// This operation returns a tensor `y` containing all of the unique elements of `x` +// sorted in the same order that they occur in `x`. This operation also returns a +// tensor `idx` the same size as `x` that contains the index of each value of `x` +// in the unique output `y`. In other words: // -// the source data format. +// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` +// +// For example: +// +// ``` +// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +// y, idx = unique(x) +// y ==> [1, 2, 4, 7, 8] +// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +// ``` // // Arguments: -// x: A Tensor with each element as a dimension index in source data format. -// Must be in the range [-4, 4). +// x: 1-D. // -// Returns A Tensor with each element as a dimension index in destination data format. -func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { +// Returns 1-D.1-D. +func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) { if scope.Err() != nil { return } @@ -26867,474 +26626,498 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt a(attrs) } opspec := tf.OpSpec{ - Type: "DataFormatDimMap", + Type: "Unique", Input: []tf.Input{ x, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Return the shape of s0 op s1 with broadcast. +// Concatenates a list of `N` tensors along the first dimension. // -// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the -// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. -func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { +// The input tensors are all required to have size 1 in the first dimension. +// +// For example: +// +// ``` +// # 'x' is [[1, 4]] +// # 'y' is [[2, 5]] +// # 'z' is [[3, 6]] +// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// ``` +// +// The difference between concat and parallel_concat is that concat requires all +// of the inputs be computed before the operation will begin but doesn't require +// that the input shapes be known during graph construction. Parallel concat +// will copy pieces of the input into the output as they become available, in +// some situations this can provide a performance benefit. +// +// Arguments: +// values: Tensors to be concatenated. All must have size 1 in the first dimension +// and same shape. +// shape: the final shape of the result; should be equal to the shapes of any input +// but with the number of input values in the first dimension. +// +// Returns The concatenated tensor. +func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "BroadcastArgs", + Type: "ParallelConcat", Input: []tf.Input{ - s0, s1, + tf.OutputList(values), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Return the reduction indices for computing gradients of s0 op s1 with broadcast. +// Compute the lower regularized incomplete Gamma function `Q(a, x)`. // -// This is typically used by gradient computations for a broadcasting operation. -func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output, r1 tf.Output) { +// The lower regularized incomplete Gamma function is defined as: +// +// +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// +// where +// +// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) +// +// is the lower incomplete Gamma function. +// +// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +// Gamma function. +func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BroadcastGradientArgs", + Type: "Igamma", Input: []tf.Input{ - s0, s1, + a, x, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Pads a tensor with mirrored values. -// -// This operation pads a `input` with mirrored values according to the `paddings` -// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is -// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many values to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many values to add after the contents of `input` -// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater -// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true -// (if false, respectively). -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// Computes offsets of concat inputs within its output. // // For example: // // ``` -// # 't' is [[1, 2, 3], [4, 5, 6]]. -// # 'paddings' is [[1, 1]], [2, 2]]. -// # 'mode' is SYMMETRIC. -// # rank of 't' is 2. -// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] -// [2, 1, 1, 2, 3, 3, 2] -// [5, 4, 4, 5, 6, 6, 5] -// [5, 4, 4, 5, 6, 6, 5]] +// # 'x' is [2, 2, 7] +// # 'y' is [2, 3, 7] +// # 'z' is [2, 5, 7] +// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] // ``` // +// This is typically used by gradient computations for a concat operation. +// // Arguments: -// input: The input tensor to be padded. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions -// do not include the borders, while in symmetric mode the padded regions -// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` -// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and -// it is `[1, 2, 3, 3, 2]` in symmetric mode. +// concat_dim: The dimension along which to concatenate. +// shape: The `N` int32 vectors representing shape of tensors being concatenated. +// +// Returns The `N` int32 vectors representing the starting offset +// of input tensors within the concatenated output. +func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatOffset", + Input: []tf.Input{ + concat_dim, tf.OutputList(shape), + }, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { + scope.UpdateErr("ConcatOffset", err) + return + } + return offset +} + +// Splits a tensor into `num_split` tensors along one dimension. +// +// Arguments: +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. +// value: The tensor to split. +// num_split: The number of ways to split. Must evenly divide +// `value.shape[split_dim]`. // -// Returns The padded tensor. -func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { +// Returns They are identically shaped tensors, whose shape matches that of `value` +// except along `axis`, where their sizes are +// `values.shape[split_dim] / num_split`. +func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"mode": mode} + attrs := map[string]interface{}{"num_split": num_split} opspec := tf.OpSpec{ - Type: "MirrorPad", + Type: "Split", Input: []tf.Input{ - input, paddings, + axis, value, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Split", err) + return + } + return output } -// A placeholder op for a value that will be fed into the computation. -// -// DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2. -// -// N.B. This operation will fail with an error if it is executed. It is -// intended as a way to represent a value that will always be fed, and to -// provide attrs that enable the fed value to be checked at runtime. +// Splits a tensor into `num_split` tensors along one dimension. // // Arguments: -// dtype: The type of elements in the tensor. -// shape: The shape of the tensor. The shape can be any partially-specified -// shape. To be unconstrained, pass in a shape with unknown rank. +// value: The tensor to split. +// size_splits: list containing the sizes of each output tensor along the split +// dimension. Must sum to the dimension of value along split_dim. +// Can contain one -1 indicating that dimension is to be inferred. +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. // -// Returns A placeholder tensor that must be replaced using the feed mechanism. -func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { +// +// Returns Tensors whose shape matches that of `value` +// except along `axis`, where their sizes are +// `size_splits[i]`. +func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + attrs := map[string]interface{}{"num_split": num_split} opspec := tf.OpSpec{ - Type: "PlaceholderV2", - + Type: "SplitV", + Input: []tf.Input{ + value, size_splits, axis, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. -type ResourceApplyAdadeltaAttr func(optionalAttr) - -// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var, accum and update_accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { - return func(m optionalAttr) { - m["use_locking"] = value + if scope.Err() != nil { + return } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("SplitV", err) + return + } + return output } -// Update '*var' according to the adadelta scheme. +// Gives a guarantee to the TF runtime that the input tensor is a constant. // -// accum = rho() * accum + (1 - rho()) * grad.square(); -// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; -// update_accum = rho() * update_accum + (1 - rho()) * update.square(); -// var -= update; +// The runtime is then free to make optimizations based on this. // -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// accum_update: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. +// Only accepts value typed tensors as inputs and rejects resource variable handles +// as input. // -// Returns the created operation. -func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { +// Returns the input tensor without modification. +func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceApplyAdadelta", + Type: "GuaranteeConst", Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, + input, }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// SqueezeAttr is an optional argument to Squeeze. -type SqueezeAttr func(optionalAttr) - -// SqueezeAxis sets the optional axis attribute to value. -// -// value: If specified, only squeezes the dimensions listed. The dimension -// index starts at 0. It is an error to squeeze a dimension that is not 1. Must -// be in the range `[-rank(input), rank(input))`. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func SqueezeAxis(value []int64) SqueezeAttr { - return func(m optionalAttr) { - m["squeeze_dims"] = value } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Removes dimensions of size 1 from the shape of a tensor. -// -// Given a tensor `input`, this operation returns a tensor of the same type with -// all dimensions of size 1 removed. If you don't want to remove all size 1 -// dimensions, you can remove specific size 1 dimensions by specifying -// `axis`. -// -// For example: -// -// ``` -// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -// shape(squeeze(t)) ==> [2, 3] -// ``` -// -// Or, to remove specific size 1 dimensions: -// -// ``` -// # 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -// shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] -// ``` +// Returns a tensor of zeros with the same shape and type as x. // // Arguments: -// input: The `input` to squeeze. +// x: a tensor of type T. // -// Returns Contains the same data as `input`, but has one or more dimensions of -// size 1 removed. -func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf.Output) { +// Returns a tensor of the same shape and type as x but filled with zeros. +func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Squeeze", + Type: "ZerosLike", Input: []tf.Input{ - input, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SpaceToBatch for N-D tensors of type T. -// -// This operation divides "spatial" dimensions `[1, ..., M]` of the input into a -// grid of blocks of shape `block_shape`, and interleaves these blocks with the -// "batch" dimension (0) such that in the output, the spatial dimensions -// `[1, ..., M]` correspond to the position within the grid, and the batch -// dimension combines both the position within a spatial block and the original -// batch position. Prior to division into blocks, the spatial dimensions of the -// input are optionally zero padded according to `paddings`. See below for a -// precise description. -// -// Arguments: -// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, -// where spatial_shape has `M` dimensions. -// block_shape: 1-D with shape `[M]`, all values must be >= 1. -// paddings: 2-D with shape `[M, 2]`, all values must be >= 0. -// `paddings[i] = [pad_start, pad_end]` specifies the padding for input dimension -// `i + 1`, which corresponds to spatial dimension `i`. It is required that -// `block_shape[i]` divides `input_shape[i + 1] + pad_start + pad_end`. -// -// This operation is equivalent to the following steps: -// -// 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the -// input according to `paddings` to produce `padded` of shape `padded_shape`. -// -// 2. Reshape `padded` to `reshaped_padded` of shape: -// -// [batch] + -// [padded_shape[1] / block_shape[0], -// block_shape[0], -// ..., -// padded_shape[M] / block_shape[M-1], -// block_shape[M-1]] + -// remaining_shape -// -// 3. Permute dimensions of `reshaped_padded` to produce -// `permuted_reshaped_padded` of shape: -// -// block_shape + -// [batch] + -// [padded_shape[1] / block_shape[0], -// ..., -// padded_shape[M] / block_shape[M-1]] + -// remaining_shape -// -// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the batch -// dimension, producing an output tensor of shape: -// -// [batch * prod(block_shape)] + -// [padded_shape[1] / block_shape[0], -// ..., -// padded_shape[M] / block_shape[M-1]] + -// remaining_shape -// -// Some examples: +// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. +type QuantizedInstanceNormAttr func(optionalAttr) + +// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. // -// (1) For the following input of shape `[1, 2, 2, 1]`, `block_shape = [2, 2]`, and -// `paddings = [[0, 0], [0, 0]]`: +// value: If True, `given_y_min` and `given_y_min` +// and `given_y_max` are used as the output range. Otherwise, +// the implementation computes the output range. +// If not specified, defaults to false +func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["output_range_given"] = value + } +} + +// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. // -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` +// value: Output in `y_min` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_min"] = value + } +} + +// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. // -// The output tensor has shape `[4, 1, 1, 1]` and value: +// value: Output in `y_max` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_max"] = value + } +} + +// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. // -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` +// value: A small float number to avoid dividing by 0. +// If not specified, defaults to 1e-05 +func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["variance_epsilon"] = value + } +} + +// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. // -// (2) For the following input of shape `[1, 2, 2, 3]`, `block_shape = [2, 2]`, and -// `paddings = [[0, 0], [0, 0]]`: +// value: Minimum value of `y_max - y_min` +// If not specified, defaults to 0.001 +func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["min_separation"] = value + } +} + +// Quantized Instance normalization. // -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` +// Arguments: +// x: A 4D input Tensor. +// x_min: The value represented by the lowest quantized input. +// x_max: The value represented by the highest quantized input. // -// The output tensor has shape `[4, 1, 1, 3]` and value: +// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. +func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedInstanceNorm", + Input: []tf.Input{ + x, x_min, x_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Returns the diagonal part of the tensor. // -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` +// This operation returns a tensor with the `diagonal` part +// of the `input`. The `diagonal` part is computed as follows: // -// (3) For the following input of shape `[1, 4, 4, 1]`, `block_shape = [2, 2]`, and -// `paddings = [[0, 0], [0, 0]]`: +// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a +// tensor of rank `k` with dimensions `[D1,..., Dk]` where: // -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` +// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. // -// The output tensor has shape `[4, 2, 2, 1]` and value: +// For example: // // ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// (4) For the following input of shape `[2, 2, 4, 1]`, block_shape = `[2, 2]`, and -// paddings = `[[0, 0], [2, 0]]`: +// # 'input' is [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] // +// tf.diag_part(input) ==> [1, 2, 3, 4] // ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]]], -// [[[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` -// -// The output tensor has shape `[8, 1, 3, 1]` and value: // -// ``` -// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], -// [[[0], [2], [4]]], [[[0], [10], [12]]], -// [[[0], [5], [7]]], [[[0], [13], [15]]], -// [[[0], [6], [8]]], [[[0], [14], [16]]]] -// ``` +// Arguments: +// input: Rank k tensor where k is even and not zero. // -// Among others, this operation is useful for reducing atrous convolution into -// regular convolution. -func SpaceToBatchND(scope *Scope, input tf.Output, block_shape tf.Output, paddings tf.Output) (output tf.Output) { +// Returns The extracted diagonal. +func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SpaceToBatchND", + Type: "DiagPart", Input: []tf.Input{ - input, block_shape, paddings, + input, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizeAndDequantizeV2Attr is an optional argument to QuantizeAndDequantizeV2. -type QuantizeAndDequantizeV2Attr func(optionalAttr) - -// QuantizeAndDequantizeV2SignedInput sets the optional signed_input attribute to value. +// Returns the element-wise max of two SparseTensors. // -// value: If the quantization is signed or unsigned. -// If not specified, defaults to true -func QuantizeAndDequantizeV2SignedInput(value bool) QuantizeAndDequantizeV2Attr { - return func(m optionalAttr) { - m["signed_input"] = value - } -} - -// QuantizeAndDequantizeV2NumBits sets the optional num_bits attribute to value. +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. // -// value: The bitwidth of the quantization. -// If not specified, defaults to 8 -func QuantizeAndDequantizeV2NumBits(value int64) QuantizeAndDequantizeV2Attr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// QuantizeAndDequantizeV2RangeGiven sets the optional range_given attribute to value. +// Arguments: +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. // -// value: If the range is given or should be computed from the tensor. -// If not specified, defaults to false -func QuantizeAndDequantizeV2RangeGiven(value bool) QuantizeAndDequantizeV2Attr { - return func(m optionalAttr) { - m["range_given"] = value +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSparseMaximum", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Quantizes then dequantizes a tensor. -// -// This op simulates the precision loss from the quantized forward pass by: -// 1. Quantizing the tensor to fixed point numbers, which should match the target -// quantization method when it is used in inference. -// 2. Dequantizing it back to floating point numbers for the following ops, most -// likely matmul. +// Returns a batched matrix tensor with new batched diagonal values. // -// There are different ways to quantize. This version does not use the full range -// of the output type, choosing to elide the lowest possible value for symmetry -// (e.g., output range is -127 to 127, not -128 to 127 for signed 8 bit -// quantization), so that 0.0 maps to 0. +// Given `input` and `diagonal`, this operation returns a tensor with the +// same shape and values as `input`, except for the main diagonal of the +// innermost matrices. These will be overwritten by the values in `diagonal`. // -// To perform this op, we first find the range of values in our tensor. The range -// we use is always centered on 0, so we find m such that +// The output is computed as follows: // -// 1. m = max(abs(input_min), abs(input_max)) if range_given is true, -// 2. m = max(abs(min_elem(input)), abs(max_elem(input))) otherwise. +// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has +// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a +// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: // -// Our input tensor range is then [-m, m]. +// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. +// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. // -// Next, we choose our fixed-point quantization buckets, [min_fixed, max_fixed]. -// If signed_input is true, this is +// Arguments: +// input: Rank `k+1`, where `k >= 1`. +// diagonal: Rank `k`, where `k >= 1`. // -// [min_fixed, max_fixed ] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]. +// Returns Rank `k+1`, with `output.shape = input.shape`. +func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixSetDiag", + Input: []tf.Input{ + input, diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EditDistanceAttr is an optional argument to EditDistance. +type EditDistanceAttr func(optionalAttr) + +// EditDistanceNormalize sets the optional normalize attribute to value. // -// Otherwise, if signed_input is false, the fixed-point range is +// value: boolean (if true, edit distances are normalized by length of truth). // -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. +// The output is: +// If not specified, defaults to true +func EditDistanceNormalize(value bool) EditDistanceAttr { + return func(m optionalAttr) { + m["normalize"] = value + } +} + +// Computes the (possibly normalized) Levenshtein Edit Distance. // -// From this we compute our scaling factor, s: +// The inputs are variable-length sequences provided by SparseTensors +// (hypothesis_indices, hypothesis_values, hypothesis_shape) +// and +// (truth_indices, truth_values, truth_shape). // -// s = (max_fixed - min_fixed) / (2 * m). +// The inputs are: // -// Now we can quantize and dequantize the elements of our tensor. An element e -// is transformed into e': +// Arguments: +// hypothesis_indices: The indices of the hypothesis list SparseTensor. +// This is an N x R int64 matrix. +// hypothesis_values: The values of the hypothesis list SparseTensor. +// This is an N-length vector. +// hypothesis_shape: The shape of the hypothesis list SparseTensor. +// This is an R-length vector. +// truth_indices: The indices of the truth list SparseTensor. +// This is an M x R int64 matrix. +// truth_values: The values of the truth list SparseTensor. +// This is an M-length vector. +// truth_shape: truth indices, vector. // -// e' = (e * s).round_to_nearest() / s. +// Returns A dense float tensor with rank R - 1. // -// Note that we have a different number of buckets in the signed vs. unsigned -// cases. For example, if num_bits == 8, we get 254 buckets in the signed case -// vs. 255 in the unsigned case. +// For the example input: // -// For example, suppose num_bits = 8 and m = 1. Then +// // hypothesis represents a 2x1 matrix with variable-length values: +// // (0,0) = ["a"] +// // (1,0) = ["b"] +// hypothesis_indices = [[0, 0, 0], +// [1, 0, 0]] +// hypothesis_values = ["a", "b"] +// hypothesis_shape = [2, 1, 1] // -// [min_fixed, max_fixed] = [-127, 127], and -// s = (127 + 127) / 2 = 127. +// // truth represents a 2x2 matrix with variable-length values: +// // (0,0) = [] +// // (0,1) = ["a"] +// // (1,0) = ["b", "c"] +// // (1,1) = ["a"] +// truth_indices = [[0, 1, 0], +// [1, 0, 0], +// [1, 0, 1], +// [1, 1, 0]] +// truth_values = ["a", "b", "c", "a"] +// truth_shape = [2, 2, 2] +// normalize = true // -// Given the vector {-1, -0.5, 0, 0.3}, this is quantized to -// {-127, -63, 0, 38}, and dequantized to {-1, -63.0/127, 0, 38.0/127}. +// The output will be: // -// Arguments: -// input: Tensor to quantize and then dequantize. -// input_min: If range_given, this is the min of the range, otherwise this input -// will be ignored. -// input_max: If range_given, this is the max of the range, otherwise this input -// will be ignored. -func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, optional ...QuantizeAndDequantizeV2Attr) (output tf.Output) { +// // output is a 2x2 matrix with edit distances normalized by truth lengths. +// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis +// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis +func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -27343,9 +27126,9 @@ func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantizeV2", + Type: "EditDistance", Input: []tf.Input{ - input, input_min, input_max, + hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape, }, Attrs: attrs, } @@ -27353,201 +27136,270 @@ func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, return op.Output(0) } -// SpaceToBatch for 4-D tensors of type T. -// -// This is a legacy version of the more general SpaceToBatchND. +// Gather slices from `params` into a Tensor with shape specified by `indices`. // -// Zero-pads and then rearranges (permutes) blocks of spatial data into batch. -// More specifically, this op outputs a copy of the input tensor where values from -// the `height` and `width` dimensions are moved to the `batch` dimension. After -// the zero-padding, both `height` and `width` of the input must be divisible by the -// block size. +// `indices` is an K-dimensional integer tensor, best thought of as a +// (K-1)-dimensional tensor of indices into `params`, where each element defines a +// slice of `params`: // -// Arguments: -// input: 4-D with shape `[batch, height, width, depth]`. -// paddings: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies -// the padding of the input with zeros across the spatial dimensions as follows: +// output[i_0, ..., i_{K-2}] = params[indices[i0, ..., i_{K-2}]] // -// paddings = [[pad_top, pad_bottom], [pad_left, pad_right]] +// Whereas in @{tf.gather} `indices` defines slices into the first +// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the +// first `N` dimensions of `params`, where `N = indices.shape[-1]`. // -// The effective spatial dimensions of the zero-padded input tensor will be: +// The last dimension of `indices` can be at most the rank of +// `params`: // -// height_pad = pad_top + height + pad_bottom -// width_pad = pad_left + width + pad_right +// indices.shape[-1] <= params.rank // -// The attr `block_size` must be greater than one. It indicates the block size. +// The last dimension of `indices` corresponds to elements +// (if `indices.shape[-1] == params.rank`) or slices +// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]` +// of `params`. The output tensor has shape // -// * Non-overlapping blocks of size `block_size x block size` in the height and -// width dimensions are rearranged into the batch dimension at each location. -// * The batch of the output tensor is `batch * block_size * block_size`. -// * Both height_pad and width_pad must be divisible by block_size. +// indices.shape[:-1] + params.shape[indices.shape[-1]:] // -// The shape of the output will be: +// Some examples below. // -// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, -// depth] +// Simple indexing into a matrix: // -// Some examples: +// ```python +// indices = [[0, 0], [1, 1]] +// params = [['a', 'b'], ['c', 'd']] +// output = ['a', 'd'] +// ``` // -// (1) For the following input of shape `[1, 2, 2, 1]` and block_size of 2: +// Slice indexing into a matrix: // -// ``` -// x = [[[[1], [2]], [[3], [4]]]] +// ```python +// indices = [[1], [0]] +// params = [['a', 'b'], ['c', 'd']] +// output = [['c', 'd'], ['a', 'b']] // ``` // -// The output tensor has shape `[4, 1, 1, 1]` and value: +// Indexing into a 3-tensor: // -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` +// ```python +// indices = [[1]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [[['a1', 'b1'], ['c1', 'd1']]] // -// (2) For the following input of shape `[1, 2, 2, 3]` and block_size of 2: // -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` +// indices = [[0, 1], [1, 0]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [['c0', 'd0'], ['a1', 'b1']] // -// The output tensor has shape `[4, 1, 1, 3]` and value: // -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// indices = [[0, 0, 1], [1, 0, 1]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = ['b0', 'b1'] // ``` // -// (3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2: +// Batched indexing into a matrix: // -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] +// ```python +// indices = [[[0, 0]], [[0, 1]]] +// params = [['a', 'b'], ['c', 'd']] +// output = [['a'], ['b']] // ``` // -// The output tensor has shape `[4, 2, 2, 1]` and value: +// Batched slice indexing into a matrix: // -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] +// ```python +// indices = [[[1]], [[0]]] +// params = [['a', 'b'], ['c', 'd']] +// output = [[['c', 'd']], [['a', 'b']]] // ``` // -// (4) For the following input of shape `[2, 2, 4, 1]` and block_size of 2: +// Batched indexing into a 3-tensor: // -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]]], -// [[[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` +// ```python +// indices = [[[1]], [[0]]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [[[['a1', 'b1'], ['c1', 'd1']]], +// [[['a0', 'b0'], ['c0', 'd0']]]] // -// The output tensor has shape `[8, 1, 2, 1]` and value: +// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [[['c0', 'd0'], ['a1', 'b1']], +// [['a0', 'b0'], ['c1', 'd1']]] // -// ``` -// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], -// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] +// +// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [['b0', 'b1'], ['d0', 'c1']] // ``` // -// Among others, this operation is useful for reducing atrous convolution into -// regular convolution. +// Arguments: +// params: The tensor from which to gather values. +// indices: Index tensor. // -func SpaceToBatch(scope *Scope, input tf.Output, paddings tf.Output, block_size int64) (output tf.Output) { +// Returns Values from `params` gathered from indices given by `indices`, with +// shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`. +func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GatherNd", + Input: []tf.Input{ + params, indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Eagerly executes a python function to compute func(input)->output. The +// +// semantics of the input, output, and attributes are the same as those for +// PyFunc. +func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"token": token, "Tout": Tout} + opspec := tf.OpSpec{ + Type: "EagerPyFunc", + Input: []tf.Input{ + tf.OutputList(input), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("EagerPyFunc", err) + return + } + return output +} + +// Stops gradient computation. +// +// When executed in a graph, this op outputs its input tensor as-is. +// +// When building ops to compute gradients, this op prevents the contribution of +// its inputs to be taken into account. Normally, the gradient generator adds ops +// to a graph to compute the derivatives of a specified 'loss' by recursively +// finding out inputs that contributed to its computation. If you insert this op +// in the graph it inputs are masked from the gradient generator. They are not +// taken into account for computing gradients. +// +// This is useful any time you want to compute a value with TensorFlow but need +// to pretend that the value was a constant. Some examples include: +// +// * The *EM* algorithm where the *M-step* should not involve backpropagation +// through the output of the *E-step*. +// * Contrastive divergence training of Boltzmann machines where, when +// differentiating the energy function, the training must not backpropagate +// through the graph that generated the samples from the model. +// * Adversarial training, where no backprop should happen through the adversarial +// example generation process. +func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StopGradient", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes asin of x element-wise. +func Asin(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} opspec := tf.OpSpec{ - Type: "SpaceToBatch", + Type: "Asin", Input: []tf.Input{ - input, paddings, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnpackAttr is an optional argument to Unpack. -type UnpackAttr func(optionalAttr) +// PreventGradientAttr is an optional argument to PreventGradient. +type PreventGradientAttr func(optionalAttr) -// UnpackAxis sets the optional axis attribute to value. +// PreventGradientMessage sets the optional message attribute to value. // -// value: Dimension along which to unpack. Negative values wrap around, so the -// valid range is `[-R, R)`. -// If not specified, defaults to 0 -func UnpackAxis(value int64) UnpackAttr { +// value: Will be printed in the error when anyone tries to differentiate +// this operation. +// If not specified, defaults to "" +func PreventGradientMessage(value string) PreventGradientAttr { return func(m optionalAttr) { - m["axis"] = value + m["message"] = value } } -// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. -// -// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. -// For example, given a tensor of shape `(A, B, C, D)`; -// -// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` -// and each tensor in `output` will have shape `(B, C, D)`. (Note that the -// dimension unpacked along is gone, unlike `split`). +// An identity op that triggers an error if a gradient is requested. // -// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` -// and each tensor in `output` will have shape `(A, C, D)`. -// Etc. +// When executed in a graph, this op outputs its input tensor as-is. // -// This is the opposite of `pack`. +// When building ops to compute gradients, the TensorFlow gradient system +// will return an error when trying to lookup the gradient of this op, +// because no gradient must ever be registered for this function. This +// op exists to prevent subtle bugs from silently returning unimplemented +// gradients in some corner cases. // // Arguments: -// value: 1-D or higher, with `axis` dimension size equal to `num`. -// +// input: any tensor. // -// Returns The list of tensors unpacked from `value`. -func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { +// Returns the same input tensor. +func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num": num} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Unpack", + Type: "PreventGradient", Input: []tf.Input{ - value, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Unpack", err) - return - } - return output + return op.Output(0) } -// Increments variable pointed to by 'resource' until it reaches 'limit'. +// Checks a tensor for NaN and Inf values. // -// Arguments: -// resource: Should be from a scalar `Variable` node. -// limit: If incrementing ref would bring it above limit, instead generates an -// 'OutOfRange' error. +// When run, reports an `InvalidArgument` error if `tensor` has any values +// that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. // +// Arguments: // -// Returns A copy of the input before increment. If nothing else modifies the -// input, the values produced will all be distinct. -func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { +// message: Prefix of the error message. +func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"limit": limit, "T": T} + attrs := map[string]interface{}{"message": message} opspec := tf.OpSpec{ - Type: "ResourceCountUpTo", + Type: "CheckNumerics", Input: []tf.Input{ - resource, + tensor, }, Attrs: attrs, } @@ -27555,197 +27407,176 @@ func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataT return op.Output(0) } -// Delete the stack from its resource container. -// -// Arguments: -// handle: The handle to a stack. +// Shuffle dimensions of x according to a permutation and conjugate the result. // -// Returns the created operation. -func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { +// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +// `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])` +func ConjugateTranspose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "StackCloseV2", + Type: "ConjugateTranspose", Input: []tf.Input{ - handle, + x, perm, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// BatchToSpace for N-D tensors of type T. -// -// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape -// `block_shape + [batch]`, interleaves these blocks back into the grid defined by -// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as -// the input. The spatial dimensions of this intermediate result are then -// optionally cropped according to `crops` to produce the output. This is the -// reverse of SpaceToBatch. See below for a precise description. -// -// Arguments: -// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, -// where spatial_shape has M dimensions. -// block_shape: 1-D with shape `[M]`, all values must be >= 1. -// crops: 2-D with shape `[M, 2]`, all values must be >= 0. -// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input -// dimension `i + 1`, which corresponds to spatial dimension `i`. It is -// required that -// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. -// -// This operation is equivalent to the following steps: -// -// 1. Reshape `input` to `reshaped` of shape: -// [block_shape[0], ..., block_shape[M-1], -// batch / prod(block_shape), -// input_shape[1], ..., input_shape[N-1]] -// -// 2. Permute dimensions of `reshaped` to produce `permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1], block_shape[0], -// ..., -// input_shape[M], block_shape[M-1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// 3. Reshape `permuted` to produce `reshaped_permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0], -// ..., -// input_shape[M] * block_shape[M-1], -// -// input_shape[M+1], -// ..., -// input_shape[N-1]] -// -// 4. Crop the start and end of dimensions `[1, ..., M]` of -// `reshaped_permuted` according to `crops` to produce the output of shape: -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], -// ..., -// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` +// UniqueV2Attr is an optional argument to UniqueV2. +type UniqueV2Attr func(optionalAttr) + +// UniqueV2OutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func UniqueV2OutIdx(value tf.DataType) UniqueV2Attr { + return func(m optionalAttr) { + m["out_idx"] = value + } +} + +// Finds unique elements in a 1-D tensor. // -// The output tensor has shape `[1, 2, 2, 3]` and value: +// This operation returns a tensor `y` containing all of the unique elements of `x` +// sorted in the same order that they occur in `x`. This operation also returns a +// tensor `idx` the same size as `x` that contains the index of each value of `x` +// in the unique output `y`. In other words: // -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` +// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` // -// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: +// For example: // // ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] +// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +// y, idx = unique(x) +// y ==> [1, 2, 4, 7, 8] +// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] // ``` // -// The output tensor has shape `[1, 4, 4, 1]` and value: -// -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` +// Arguments: +// x: A `Tensor`. +// axis: A `Tensor` of type `int64` (default: 0). The axis of the Tensor to +// find the unique elements. // -// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [2, 0]]`: +// Returns A `Tensor`. Unique elements along the `axis` of `Tensor` x.A 1-D Tensor. Has the same type as x that contains the index of each +// value of x in the output y. +func UniqueV2(scope *Scope, x tf.Output, axis tf.Output, optional ...UniqueV2Attr) (y tf.Output, idx tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UniqueV2", + Input: []tf.Input{ + x, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Return a slice from 'input'. // -// ``` -// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], -// [[[0], [2], [4]]], [[[0], [10], [12]]], -// [[[0], [5], [7]]], [[[0], [13], [15]]], -// [[[0], [6], [8]]], [[[0], [14], [16]]]] -// ``` +// The output tensor is a tensor with dimensions described by 'size' +// whose values are extracted from 'input' starting at the offsets in +// 'begin'. // -// The output tensor has shape `[2, 2, 4, 1]` and value: +// *Requirements*: +// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) // -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]]], -// [[[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` -func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { +// Arguments: +// +// begin: begin[i] specifies the offset into the 'i'th dimension of +// 'input' to slice from. +// size: size[i] specifies the number of elements of the 'i'th dimension +// of 'input' to slice. If size[i] is -1, all remaining elements in dimension +// i are included in the slice (i.e. this is equivalent to setting +// size[i] = input.dim_size(i) - begin[i]). +func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BatchToSpaceND", + Type: "Slice", Input: []tf.Input{ - input, block_shape, crops, + input, begin, size, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Extract `patches` from `images` and put them in the "depth" output dimension. -// -// Arguments: -// images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`. -// ksizes: The size of the sliding window for each dimension of `images`. -// strides: 1-D of length 4. How far the centers of two consecutive patches are in -// the images. Must be: `[1, stride_rows, stride_cols, 1]`. -// rates: 1-D of length 4. Must be: `[1, rate_rows, rate_cols, 1]`. This is the -// input stride, specifying how far two consecutive patch samples are in the -// input. Equivalent to extracting patches with -// `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by -// subsampling them spatially by a factor of `rates`. This is equivalent to -// `rate` in dilated (a.k.a. Atrous) convolutions. -// padding: The type of padding algorithm to use. -// -// We specify the size-related attributes as: +// StridedSliceGradAttr is an optional argument to StridedSliceGrad. +type StridedSliceGradAttr func(optionalAttr) + +// StridedSliceGradBeginMask sets the optional begin_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["begin_mask"] = value + } +} + +// StridedSliceGradEndMask sets the optional end_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradEndMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["end_mask"] = value + } +} + +// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} + +// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["new_axis_mask"] = value + } +} + +// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Returns the gradient of `StridedSlice`. // -// ```python -// ksizes = [1, ksize_rows, ksize_cols, 1] -// strides = [1, strides_rows, strides_cols, 1] -// rates = [1, rates_rows, rates_cols, 1] -// ``` +// Since `StridedSlice` cuts out pieces of its `input` which is size +// `shape`, its gradient will have the same shape (which is passed here +// as `shape`). The gradient will be zero in any element that the slice +// does not select. // -// Returns 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows * -// ksize_cols * depth]` containing image patches with size -// `ksize_rows x ksize_cols x depth` vectorized in the "depth" dimension. Note -// `out_rows` and `out_cols` are the dimensions of the output patches. -func ExtractImagePatches(scope *Scope, images tf.Output, ksizes []int64, strides []int64, rates []int64, padding string) (patches tf.Output) { +// Arguments are the same as StridedSliceGrad with the exception that +// `dy` is the input gradient to be propagated and `shape` is the +// shape of `StridedSlice`'s `input`. +func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ExtractImagePatches", + Type: "StridedSliceGrad", Input: []tf.Input{ - images, + shape, begin, end, strides, dy, }, Attrs: attrs, } @@ -27753,148 +27584,74 @@ func ExtractImagePatches(scope *Scope, images tf.Output, ksizes []int64, strides return op.Output(0) } -// Bitcasts a tensor from one type to another without copying data. -// -// Given a tensor `input`, this operation returns a tensor that has the same buffer -// data as `input` with datatype `type`. -// -// If the input datatype `T` is larger than the output datatype `type` then the -// shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. +// Returns the gradient of `Tile`. // -// If `T` is smaller than `type`, the operator requires that the rightmost -// dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from -// [..., sizeof(`type`)/sizeof(`T`)] to [...]. +// DEPRECATED at GraphDef version 3: TileGrad has been replaced with reduce_sum // -// *NOTE*: Bitcast is implemented as a low-level cast, so machines with different -// endian orderings will give different results. -func Bitcast(scope *Scope, input tf.Output, type_ tf.DataType) (output tf.Output) { +// Since `Tile` takes an input and repeats the input `multiples` times +// along each dimension, `TileGrad` takes in `multiples` and aggregates +// each repeated tile of `input` into `output`. +func TileGrad(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"type": type_} opspec := tf.OpSpec{ - Type: "Bitcast", + Type: "TileGrad", Input: []tf.Input{ - input, + input, multiples, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// OneHotAttr is an optional argument to OneHot. -type OneHotAttr func(optionalAttr) +// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. +type QuantizeAndDequantizeAttr func(optionalAttr) -// OneHotAxis sets the optional axis attribute to value. -// -// value: The axis to fill (default: -1, a new inner-most axis). -// If not specified, defaults to -1 -func OneHotAxis(value int64) OneHotAttr { +// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { return func(m optionalAttr) { - m["axis"] = value + m["signed_input"] = value } } -// Returns a one-hot tensor. -// -// The locations represented by indices in `indices` take value `on_value`, -// while all other locations take value `off_value`. -// -// If the input `indices` is rank `N`, the output will have rank `N+1`, -// The new axis is created at dimension `axis` (default: the new axis is -// appended at the end). -// -// If `indices` is a scalar the output shape will be a vector of length `depth`. -// -// If `indices` is a vector of length `features`, the output shape will be: -// ``` -// features x depth if axis == -1 -// depth x features if axis == 0 -// ``` -// -// If `indices` is a matrix (batch) with shape `[batch, features]`, -// the output shape will be: -// ``` -// batch x features x depth if axis == -1 -// batch x depth x features if axis == 1 -// depth x batch x features if axis == 0 -// ``` -// -// -// Examples -// ========= -// -// Suppose that -// -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 5.0 -// off_value = 0.0 -// axis = -1 -// ``` -// -// Then output is `[4 x 3]`: -// -// ```output = -// [5.0 0.0 0.0] // one_hot(0) -// [0.0 0.0 5.0] // one_hot(2) -// [0.0 0.0 0.0] // one_hot(-1) -// [0.0 5.0 0.0] // one_hot(1) -// ``` -// -// Suppose that -// -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 0.0 -// off_value = 3.0 -// axis = 0 -// ``` -// -// Then output is `[3 x 4]`: -// -// ```output = -// [0.0 3.0 3.0 3.0] -// [3.0 3.0 3.0 0.0] -// [3.0 3.0 3.0 3.0] -// [3.0 0.0 3.0 3.0] -// // ^ one_hot(0) -// // ^ one_hot(2) -// // ^ one_hot(-1) -// // ^ one_hot(1) -// ``` -// Suppose that -// -// ``` -// indices = [[0, 2], [1, -1]] -// depth = 3 -// on_value = 1.0 -// off_value = 0.0 -// axis = -1 -// ``` -// -// Then output is `[2 x 2 x 3]`: -// -// ```output = -// [ -// [1.0, 0.0, 0.0] // one_hot(0) -// [0.0, 0.0, 1.0] // one_hot(2) -// ][ -// [0.0, 1.0, 0.0] // one_hot(1) -// [0.0, 0.0, 0.0] // one_hot(-1) -// ]``` -// -// Arguments: -// indices: A tensor of indices. -// depth: A scalar defining the depth of the one hot dimension. -// on_value: A scalar defining the value to fill in output when `indices[j] = i`. -// off_value: A scalar defining the value to fill in output when `indices[j] != i`. +// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to false +func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_min"] = value + } +} + +// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_max"] = value + } +} + +// Use QuantizeAndDequantizeV2 instead. // -// Returns The one-hot tensor. -func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { +// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 +func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -27903,9 +27660,9 @@ func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "OneHot", + Type: "QuantizeAndDequantize", Input: []tf.Input{ - indices, depth, on_value, off_value, + input, }, Attrs: attrs, } @@ -28045,53 +27802,39 @@ func Where(scope *Scope, condition tf.Output) (index tf.Output) { return op.Output(0) } -// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. -type QuantizeAndDequantizeAttr func(optionalAttr) - -// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["signed_input"] = value - } -} - -// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to false -func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["range_given"] = value - } -} +// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. +type DataFormatDimMapAttr func(optionalAttr) -// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { +// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. +// +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { return func(m optionalAttr) { - m["input_min"] = value + m["src_format"] = value } } -// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { +// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. +// +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { return func(m optionalAttr) { - m["input_max"] = value + m["dst_format"] = value } } -// Use QuantizeAndDequantizeV2 instead. +// Returns the dimension index in the destination data format given the one in // -// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 -func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { +// the source data format. +// +// Arguments: +// x: A Tensor with each element as a dimension index in source data format. +// Must be in the range [-4, 4). +// +// Returns A Tensor with each element as a dimension index in destination data format. +func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { if scope.Err() != nil { return } @@ -28100,9 +27843,9 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantize", + Type: "DataFormatDimMap", Input: []tf.Input{ - input, + x, }, Attrs: attrs, } @@ -28110,173 +27853,37 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn return op.Output(0) } -// Returns the diagonal part of the tensor. -// -// This operation returns a tensor with the `diagonal` part -// of the `input`. The `diagonal` part is computed as follows: -// -// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a -// tensor of rank `k` with dimensions `[D1,..., Dk]` where: -// -// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. -// -// For example: -// -// ``` -// # 'input' is [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] -// -// tf.diag_part(input) ==> [1, 2, 3, 4] -// ``` -// -// Arguments: -// input: Rank k tensor where k is even and not zero. +// Return the shape of s0 op s1 with broadcast. // -// Returns The extracted diagonal. -func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { +// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. +func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DiagPart", + Type: "BroadcastArgs", Input: []tf.Input{ - input, + s0, s1, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. -type QuantizedInstanceNormAttr func(optionalAttr) - -// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. -// -// value: If True, `given_y_min` and `given_y_min` -// and `given_y_max` are used as the output range. Otherwise, -// the implementation computes the output range. -// If not specified, defaults to false -func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["output_range_given"] = value - } -} - -// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. -// -// value: Output in `y_min` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_min"] = value - } -} - -// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. -// -// value: Output in `y_max` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_max"] = value - } -} - -// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. -// -// value: A small float number to avoid dividing by 0. -// If not specified, defaults to 1e-05 -func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["variance_epsilon"] = value - } -} - -// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. -// -// value: Minimum value of `y_max - y_min` -// If not specified, defaults to 0.001 -func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["min_separation"] = value - } -} - -// Quantized Instance normalization. -// -// Arguments: -// x: A 4D input Tensor. -// x_min: The value represented by the lowest quantized input. -// x_max: The value represented by the highest quantized input. -// -// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. -func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedInstanceNorm", - Input: []tf.Input{ - x, x_min, x_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. -type FakeQuantWithMinMaxVarsAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` -// -// and `max` to 'outputs' tensor of same shape as `inputs`. -// -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. +// Return the reduction indices for computing gradients of s0 op s1 with broadcast. // -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { +// This is typically used by gradient computations for a broadcasting operation. +func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output, r1 tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVars", + Type: "BroadcastGradientArgs", Input: []tf.Input{ - inputs, min, max, + s0, s1, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 9dee1aa72bf0d76ee35931f1e852bfd22556a540..7296205e2403f68587991e1d4c9ce57899eece92 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -311,9 +311,11 @@ tf_cc_test( srcs = [ "src/gen/cc/source_writer_test.cc", ], + data = [ + "src/gen/resources/test.snippet.java", + ], deps = [ ":java_op_gen_lib", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index d35bb4111271c11839a160517dc9695ead5b46e9..0b69a8cbe530a13dc35aad3a5c859f77f0deca2a 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc1 + 1.7.0-rc1 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index d9ba1bbbfb91170257f64a56f47c6c980e8a9570..541876f7f5e4fadcbc9336f15b319389dcddbf51 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc1 + 1.7.0-rc1 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index f6f532c2c10d0a4dad9fc2d7750ea708652000b1..d8933e5238149337b08e70b3f407385887aef0a0 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc1 + 1.7.0-rc1 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 0a6b3d23d7d37515cf275e6a46842e32ada4fee1..6286fd73df6dec5643fceda8f6f652220d75e1a7 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.6.0-rc1 + 1.7.0-rc1 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 1d8e8723731f959c8142f0648fc805593d7beac8..4e881f5a631f0b2e389b31a9b24028902eac6301 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc1 + 1.7.0-rc1 ../ proto diff --git a/tensorflow/java/maven/tensorflow-android/pom-android.xml.template b/tensorflow/java/maven/tensorflow-android/pom-android.xml.template index 5cbd0c898dc52ec5dfb72f0a2ac893d492a7d4be..37d2372d7b09f6f144e7abb145cb75bf98356615 100644 --- a/tensorflow/java/maven/tensorflow-android/pom-android.xml.template +++ b/tensorflow/java/maven/tensorflow-android/pom-android.xml.template @@ -20,10 +20,8 @@ UTF-8 - ${build_number} ${build_commit_id} ${build_type} - ${build_url} diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py index 4ae666e4e5351f1bdaf79d1b5cfdb63b0f811e2b..2206d800ca1fe82c5596ff39e56518bc5aea6211 100644 --- a/tensorflow/java/maven/tensorflow-android/update.py +++ b/tensorflow/java/maven/tensorflow-android/update.py @@ -45,6 +45,9 @@ def get_json(url): def get_commit_id(build_info): """Fetch the git commit id from the build info json object.""" + release_commit_id = build_info.get('build_commit_id') + if release_commit_id: + return release_commit_id actions = build_info.get('actions') build_data = next( a for a in actions @@ -95,20 +98,12 @@ def main(): release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) - build_type = 'release-matrix-android2' + build_type = 'release-android' # Retrieve build information build_info = get_json(info_url) # Check all required build info is present - if build_info.get('result') != 'SUCCESS': - raise ValueError('Invalid json: %s' % build_info) - build_url = build_info.get('url') - if not build_url: - raise ValueError('Missing url: %s' % build_info) - build_number = build_info.get('number') - if not build_number: - raise ValueError('Missing build number: %s' % build_info) build_commit_id = get_commit_id(build_info) if not build_commit_id: raise ValueError('Missing commit id: %s' % build_info) @@ -119,9 +114,7 @@ def main(): f.write( template.substitute({ 'build_commit_id': build_commit_id, - 'build_number': build_number, 'build_type': build_type, - 'build_url': build_url, 'version': args.version })) diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 5c1b55085c5df1ec473a3f4e0bf750b236cfc264..d512a7eda9638d428e02beda442ba4d4db9adf62 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc1 + 1.7.0-rc1 ../ tensorflow diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 615cdc165b36abdc3cf5e717ddb8b385367c067f..59f8beaee78a2f40f6743ca10f72435e757db090 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -17,10 +17,7 @@ limitations under the License. #define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ #include -#include -#include - -#include "tensorflow/core/platform/env.h" +#include namespace tensorflow { namespace java { @@ -104,17 +101,17 @@ class Type { description_ = description; return *this; } - const std::vector& parameters() const { return parameters_; } + const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); return *this; } - const std::vector& annotations() const { return annotations_; } + const std::list& annotations() const { return annotations_; } Type& add_annotation(const Annotation& annotation) { annotations_.push_back(annotation); return *this; } - const std::deque& supertypes() const { return supertypes_; } + const std::list& supertypes() const { return supertypes_; } Type& add_supertype(const Type& type) { if (type.kind_ == CLASS) { supertypes_.push_front(type); // keep superclass at the front of the list @@ -141,9 +138,9 @@ class Type { string name_; string package_; string description_; - std::vector parameters_; - std::vector annotations_; - std::deque supertypes_; + std::list parameters_; + std::list annotations_; + std::list supertypes_; }; // Definition of a Java annotation @@ -223,16 +220,12 @@ class Method { return_description_ = description; return *this; } - const std::vector& arguments() const { return arguments_; } - Method& add_arguments(const std::vector& args) { - arguments_.insert(arguments_.cend(), args.cbegin(), args.cend()); - return *this; - } + const std::list& arguments() const { return arguments_; } Method& add_argument(const Variable& var) { arguments_.push_back(var); return *this; } - const std::vector& annotations() const { return annotations_; } + const std::list& annotations() const { return annotations_; } Method& add_annotation(const Annotation& annotation) { annotations_.push_back(annotation); return *this; @@ -244,29 +237,13 @@ class Method { bool constructor_; string description_; string return_description_; - std::vector arguments_; - std::vector annotations_; + std::list arguments_; + std::list annotations_; Method(const string& name, const Type& return_type, bool constructor) : name_(name), return_type_(return_type), constructor_(constructor) {} }; -// A piece of code to read from a file. -class Snippet { - public: - static Snippet Create(const string& fname, Env* env = Env::Default()) { - return Snippet(fname, env); - } - const string& data() const { return data_; } - - private: - string data_; - - Snippet(const string& fname, Env* env) { - TF_CHECK_OK(ReadFileToString(env, fname, &data_)); - } -}; - } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 2da81f2911e60be6a47ac13fe8be6142fa283780..214999af9a6f9ee244d336a64830238e6b7ea872 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -14,49 +14,318 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include "tensorflow/java/src/gen/cc/source_writer.h" namespace tensorflow { +namespace java { -SourceWriter& SourceWriter::Append(const StringPiece& str) { - if (!str.empty()) { - if (newline_) { - DoAppend(left_margin_ + line_prefix_); - newline_ = false; - } - DoAppend(str); - } +SourceWriter::SourceWriter() { + // push an empty generic namespace at start, for simplification + generic_namespaces_.push(new GenericNamespace()); +} + +SourceWriter& SourceWriter::Indent(int tab) { + left_margin_.resize( + std::max(static_cast(left_margin_.size() + tab), 0), ' '); + return *this; +} + +SourceWriter& SourceWriter::Prefix(const char* line_prefix) { + line_prefix_ = line_prefix; return *this; } -SourceWriter& SourceWriter::Write(const string& str) { +SourceWriter& SourceWriter::Write(const StringPiece& str) { size_t line_pos = 0; do { size_t start_pos = line_pos; line_pos = str.find('\n', start_pos); if (line_pos != string::npos) { ++line_pos; - Append(StringPiece(str.data() + start_pos, line_pos - start_pos)); + Append(str.substr(start_pos, line_pos - start_pos)); newline_ = true; } else { - Append(StringPiece(str.data() + start_pos, str.size() - start_pos)); + Append(str.substr(start_pos, str.size() - start_pos)); } } while (line_pos != string::npos && line_pos < str.size()); return *this; } +SourceWriter& SourceWriter::WriteFromFile(const string& fname, Env* env) { + string data_; + TF_CHECK_OK(ReadFileToString(env, fname, &data_)); + return Write(data_); +} + +SourceWriter& SourceWriter::Append(const StringPiece& str) { + if (!str.empty()) { + if (newline_) { + DoAppend(left_margin_ + line_prefix_); + newline_ = false; + } + DoAppend(str); + } + return *this; +} + +SourceWriter& SourceWriter::AppendType(const Type& type) { + if (type.kind() == Type::Kind::GENERIC && type.name().empty()) { + Append("?"); + } else { + Append(type.name()); + } + if (!type.parameters().empty()) { + Append("<"); + for (const Type& t : type.parameters()) { + if (&t != &type.parameters().front()) { + Append(", "); + } + AppendType(t); + } + Append(">"); + } + return *this; +} + SourceWriter& SourceWriter::EndLine() { Append("\n"); newline_ = true; return *this; } -SourceWriter& SourceWriter::Indent(int tab) { - left_margin_.resize(std::max(static_cast(left_margin_.size() + tab), 0), - ' '); +SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) { + GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); + if (!method.constructor()) { + generic_namespace->Visit(method.return_type()); + } + for (const Variable& v : method.arguments()) { + generic_namespace->Visit(v.type()); + } + EndLine(); + WriteDoc(method.description(), method.return_description(), + &method.arguments()); + if (!method.annotations().empty()) { + WriteAnnotations(method.annotations()); + } + WriteModifiers(modifiers); + if (!generic_namespace->declared_types().empty()) { + WriteGenerics(generic_namespace->declared_types()); + Append(" "); + } + if (!method.constructor()) { + AppendType(method.return_type()).Append(" "); + } + Append(method.name()).Append("("); + for (const Variable& v : method.arguments()) { + if (&v != &method.arguments().front()) { + Append(", "); + } + AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name()); + } + return Append(")").BeginBlock(); +} + +SourceWriter& SourceWriter::EndMethod() { + EndBlock(); + PopGenericNamespace(); return *this; } +SourceWriter& SourceWriter::BeginType(const Type& type, + const std::list* dependencies, int modifiers) { + if (!type.package().empty()) { + Append("package ").Append(type.package()).Append(";").EndLine(); + } + if (dependencies != nullptr && !dependencies->empty()) { + TypeImporter type_importer(type.package()); + for (const Type& t : *dependencies) { + type_importer.Visit(t); + } + EndLine(); + for (const string& s : type_importer.imports()) { + Append("import ").Append(s).Append(";").EndLine(); + } + } + return BeginInnerType(type, modifiers); +} + +SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) { + GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); + generic_namespace->Visit(type); + EndLine(); + WriteDoc(type.description()); + if (!type.annotations().empty()) { + WriteAnnotations(type.annotations()); + } + WriteModifiers(modifiers); + CHECK_EQ(Type::Kind::CLASS, type.kind()) << ": Not supported yet"; + Append("class ").Append(type.name()); + if (!generic_namespace->declared_types().empty()) { + WriteGenerics(generic_namespace->declared_types()); + } + if (!type.supertypes().empty()) { + bool first_interface = true; + for (const Type& t : type.supertypes()) { + if (t.kind() == Type::CLASS) { // superclass is always first in list + Append(" extends "); + } else if (first_interface) { + Append(" implements "); + first_interface = false; + } else { + Append(", "); + } + AppendType(t); + } + } + return BeginBlock(); +} + +SourceWriter& SourceWriter::EndType() { + EndBlock(); + PopGenericNamespace(); + return *this; +} + +SourceWriter& SourceWriter::WriteFields(const std::list& fields, + int modifiers) { + EndLine(); + for (const Variable& v : fields) { + WriteModifiers(modifiers); + AppendType(v.type()).Append(" ").Append(v.name()).Append(";"); + EndLine(); + } + return *this; +} + +SourceWriter& SourceWriter::WriteModifiers(int modifiers) { + if (modifiers & PUBLIC) { + Append("public "); + } else if (modifiers & PROTECTED) { + Append("protected "); + } else if (modifiers & PRIVATE) { + Append("private "); + } + if (modifiers & STATIC) { + Append("static "); + } + if (modifiers & FINAL) { + Append("final "); + } + return *this; +} + +SourceWriter& SourceWriter::WriteDoc(const string& description, + const string& return_description, const std::list* parameters) { + if (description.empty() && return_description.empty() + && (parameters == nullptr || parameters->empty())) { + return *this; // no doc to write + } + bool do_line_break = false; + Append("/**").EndLine().Prefix(" * "); + if (!description.empty()) { + Write(description).EndLine(); + do_line_break = true; + } + if (parameters != nullptr && !parameters->empty()) { + if (do_line_break) { + EndLine(); + do_line_break = false; + } + for (const Variable& v : *parameters) { + Append("@param ").Append(v.name()); + if (!v.description().empty()) { + Append(" ").Write(v.description()); + } + EndLine(); + } + } + if (!return_description.empty()) { + if (do_line_break) { + EndLine(); + do_line_break = false; + } + Append("@return ").Write(return_description).EndLine(); + } + return Prefix("").Append(" **/").EndLine(); +} + +SourceWriter& SourceWriter::WriteAnnotations( + const std::list& annotations) { + for (const Annotation& a : annotations) { + Append("@" + a.name()); + if (!a.attributes().empty()) { + Append("(").Append(a.attributes()).Append(")"); + } + EndLine(); + } + return *this; +} + +SourceWriter& SourceWriter::WriteGenerics( + const std::list& generics) { + Append("<"); + for (const Type* pt : generics) { + if (pt != generics.front()) { + Append(", "); + } + Append(pt->name()); + if (!pt->supertypes().empty()) { + Append(" extends ").AppendType(pt->supertypes().front()); + } + } + return Append(">"); +} + +SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace( + int modifiers) { + GenericNamespace* generic_namespace; + if (modifiers & STATIC) { + generic_namespace = new GenericNamespace(); + } else { + generic_namespace = new GenericNamespace(generic_namespaces_.top()); + } + generic_namespaces_.push(generic_namespace); + return generic_namespace; +} + +void SourceWriter::PopGenericNamespace() { + GenericNamespace* generic_namespace = generic_namespaces_.top(); + generic_namespaces_.pop(); + delete generic_namespace; +} + +void SourceWriter::TypeVisitor::Visit(const Type& type) { + DoVisit(type); + for (const Type& t : type.parameters()) { + DoVisit(t); + } + for (const Annotation& t : type.annotations()) { + DoVisit(t); + } + for (const Type& t : type.supertypes()) { + DoVisit(t); + } +} + +void SourceWriter::GenericNamespace::DoVisit(const Type& type) { + // ignore non-generic parameters, wildcards and generics already declared + if (type.kind() == Type::GENERIC + && !type.IsWildcard() + && generic_names_.find(type.name()) == generic_names_.end()) { + declared_types_.push_back(&type); + generic_names_.insert(type.name()); + } +} + +void SourceWriter::TypeImporter::DoVisit(const Type& type) { + if (!type.package().empty() && type.package() != current_package_) { + imports_.insert(type.package() + '.' + type.name()); + } +} + +} // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index bff26eb185db0cf933632f33f916b87d8a757edd..6abe13b5d217b30d826d013e14a590eeb91719fb 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -17,45 +17,23 @@ limitations under the License. #define TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_ #include +#include +#include +#include #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" namespace tensorflow { +namespace java { -// A utility class for writing source code, normally generated at -// compile-time. -// -// Source writers are language-agnostic and therefore only expose generic -// methods common to most languages. Extend or wrap this class to implement -// language-specific features. -// -// Note: if you are looking to reuse this class for generating code in another -// language than Java, please do by moving it at the '//tensorflow/core/lib/io' -// level. +// A class for writing Java source code. class SourceWriter { public: + SourceWriter(); virtual ~SourceWriter() = default; - // Returns true if the writer is at the beginnig of a new line - bool newline() const { return newline_; } - - // Appends a piece of code or text. - // - // It is expected that no newline character is present in the data provided, - // otherwise Write() must be used. - SourceWriter& Append(const StringPiece& str); - - // Writes a block of code or text. - // - // The data might potentially contain newline characters, therefore it will - // be scanned to ensure that each line is indented and prefixed properly, - // making it a bit slower than Append(). - SourceWriter& Write(const string& text); - - // Appends a newline character and start writing on a new line. - SourceWriter& EndLine(); - // Indents following lines with white spaces. // // Indentation is cumulative, i.e. the provided tabulation is added to the @@ -75,18 +53,166 @@ class SourceWriter { // Indent(2)->Prefix("//") will result in prefixing lines with " //". // // An empty value ("") will remove any line prefix that was previously set. - SourceWriter& Prefix(const char* line_prefix) { - line_prefix_ = line_prefix; - return *this; + SourceWriter& Prefix(const char* line_prefix); + + // Writes a source code snippet. + // + // The data might potentially contain newline characters, therefore it will + // be scanned to ensure that each line is indented and prefixed properly, + // making it a bit slower than Append(). + SourceWriter& Write(const StringPiece& text); + + // Writes a source code snippet read from a file. + // + // All lines of the file at the provided path will be read and written back + // to the output of this writer in regard of its current attributes (e.g. + // the indentation, prefix, etc.) + SourceWriter& WriteFromFile(const string& fname, Env* env = Env::Default()); + + // Appends a piece of source code. + // + // It is expected that no newline character is present in the data provided, + // otherwise Write() must be used. + SourceWriter& Append(const StringPiece& str); + + // Appends a type to the current line. + // + // The type is written in its simple form (i.e. not prefixed by its package) + // and followed by any parameter types it has enclosed in brackets (<>). + SourceWriter& AppendType(const Type& type); + + // Appends a newline character. + // + // Data written after calling this method will start on a new line, in respect + // of the current indentation. + SourceWriter& EndLine(); + + // Begins a block of source code. + // + // This method appends a new opening brace to the current data and indent the + // next lines according to Google Java Style Guide. The block can optionally + // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) + SourceWriter& BeginBlock() { + return Append(newline_ ? "{" : " {").EndLine().Indent(2); + } + + // Ends the current block of source code. + // + // This method appends a new closing brace to the current data and outdent the + // next lines back to the margin used before BeginBlock() was invoked. + SourceWriter& EndBlock() { + return Indent(-2).Append("}").EndLine(); } + // Begins to write a method. + // + // This method outputs the signature of the Java method from the data passed + // in the 'method' parameter and starts a new block. Additionnal modifiers can + // also be passed in parameter to define the accesses and the scope of this + // method. + SourceWriter& BeginMethod(const Method& method, int modifiers = 0); + + // Ends the current method. + // + // This method ends the block of code that has begun when invoking + // BeginMethod() prior to this. + SourceWriter& EndMethod(); + + // Begins to write the main type of a source file. + // + // This method outputs the declaration of the Java type from the data passed + // in the 'type' parameter and starts a new block. Additionnal modifiers can + // also be passed in parameter to define the accesses and the scope of this + // type. + // + // If not null, all types found in the 'dependencies' list will be imported + // before declaring the new type. + SourceWriter& BeginType(const Type& clazz, + const std::list* dependencies, int modifiers = 0); + + // Begins to write a new inner type. + // + // This method outputs the declaration of the Java type from the data passed + // in the 'type' parameter and starts a new block. Additionnal modifiers can + // also be passed in parameter to define the accesses and the scope of this + // type. + SourceWriter& BeginInnerType(const Type& type, int modifiers = 0); + + // Ends the current type. + // + // This method ends the block of code that has begun when invoking + // BeginType() or BeginInnerType() prior to this. + SourceWriter& EndType(); + + // Writes a list of variables as fields of a type. + // + // This method must be called within the definition of a type (see BeginType() + // or BeginInnerType()). Additional modifiers can also be passed in parameter + // to define the accesses and the scope of those fields. + SourceWriter& WriteFields(const std::list& fields, + int modifiers = 0); + protected: virtual void DoAppend(const StringPiece& str) = 0; private: + // A utility base class for visiting elements of a type. + class TypeVisitor { + public: + virtual ~TypeVisitor() = default; + void Visit(const Type& type); + + protected: + virtual void DoVisit(const Type& type) = 0; + }; + + // A utility class for keeping track of declared generics in a given scope. + class GenericNamespace : public TypeVisitor { + public: + GenericNamespace() = default; + explicit GenericNamespace(const GenericNamespace* parent) + : generic_names_(parent->generic_names_) {} + std::list declared_types() { + return declared_types_; + } + protected: + virtual void DoVisit(const Type& type); + + private: + std::list declared_types_; + std::set generic_names_; + }; + + // A utility class for collecting a list of import statements to declare. + class TypeImporter : public TypeVisitor { + public: + explicit TypeImporter(const string& current_package) + : current_package_(current_package) {} + virtual ~TypeImporter() = default; + const std::set imports() { + return imports_; + } + protected: + virtual void DoVisit(const Type& type); + + private: + string current_package_; + std::set imports_; + }; + string left_margin_; string line_prefix_; bool newline_ = true; + std::stack generic_namespaces_; + + SourceWriter& WriteModifiers(int modifiers); + SourceWriter& WriteDoc(const string& description, + const string& return_description = "", + const std::list* parameters = nullptr); + SourceWriter& WriteAnnotations(const std::list& annotations); + SourceWriter& WriteGenerics(const std::list& generics); + GenericNamespace* PushGenericNamespace(int modifiers); + void PopGenericNamespace(); }; // A writer that outputs source code into a file. @@ -128,6 +254,7 @@ class SourceBufferWriter : public SourceWriter { string* buffer_; }; +} // namespace java } // namespace tensorflow #endif // TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_ diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index e9738957548184726395c4e6634ba12a5a9a0109..6926a5a411d070e25f2382c72589d879d3ca2180 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/java/src/gen/cc/source_writer.h" +#include + #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" namespace tensorflow { +namespace java { namespace { TEST(AppendTest, SingleLineText) { @@ -211,5 +215,366 @@ TEST(MarginTest, EmptyPrefix) { ASSERT_STREQ(expected, writer.str().data()); } +TEST(StreamTest, BlocksAndLines) { + SourceBufferWriter writer; + + writer.Append("int i = 0;").EndLine() + .Append("int j = 10;").EndLine() + .Append("if (true)") + .BeginBlock() + .Append("int aLongWayToTen = 0;").EndLine() + .Append("while (++i <= j)") + .BeginBlock() + .Append("++aLongWayToTen;").EndLine() + .EndBlock() + .EndBlock(); + + const char* expected = + "int i = 0;\n" + "int j = 10;\n" + "if (true) {\n" + " int aLongWayToTen = 0;\n" + " while (++i <= j) {\n" + " ++aLongWayToTen;\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(StreamTest, Types) { + SourceBufferWriter writer; + Type generic = Type::Generic("T").add_supertype(Type::Class("Number")); + + writer.AppendType(Type::Int()).Append(", ") + .AppendType(Type::Class("String")).Append(", ") + .AppendType(generic).Append(", ") + .AppendType(Type::ListOf(generic)).Append(", ") + .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") + .AppendType(Type::ListOf(Type::Generic())); + + const char* expected = + "int, String, T, List, List>, List"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(StreamTest, FileSnippet) { + SourceBufferWriter writer; + const string& fname = "tensorflow/java/src/gen/resources/test.snippet.java"; + + writer.WriteFromFile(fname) + .BeginBlock() + .WriteFromFile(fname) + .EndBlock(); + + const char* expected = + "// Here is a little snippet\n" + "System.out.println(\"Hello!\");\n" + "{\n" + " // Here is a little snippet\n" + " System.out.println(\"Hello!\");\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, SimpleClass) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + + writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, SimpleClassWithDependencies) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + std::list deps; + deps.push_back(Type::Class("TypeA", "org.test.sub")); + deps.push_back(Type::Class("TypeA", "org.test.sub")); // a second time + deps.push_back(Type::Class("TypeB", "org.other")); + deps.push_back(Type::Class("SamePackageType", "org.tensorflow")); + deps.push_back(Type::Class("NoPackageType")); + + writer.BeginType(clazz, &deps, PUBLIC).EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "import org.other.TypeB;\n" + "import org.test.sub.TypeA;\n\n" + "public class Test {\n}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, AnnotatedAndDocumentedClass) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + clazz.description("This class has a\n

\nmultiline description."); + clazz.add_annotation(Annotation::Create("Bean")); + clazz.add_annotation(Annotation::Create("SuppressWarnings") + .attributes("\"rawtypes\"")); + + writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "/**\n" + " * This class has a\n" + " *

\n" + " * multiline description.\n" + " **/\n" + "@Bean\n" + "@SuppressWarnings(\"rawtypes\")\n" + "public class Test {\n}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, ParameterizedClass) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + clazz.add_parameter(Type::Generic("T")); + clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number"))); + + writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, ParameterizedClassAndSupertypes) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Type type_t = Type::Generic("T"); + clazz.add_parameter(type_t); + Type type_u = Type::Generic("U").add_supertype(Type::Class("Number")); + clazz.add_parameter(type_u); + clazz.add_supertype(Type::Interface("Parametrizable").add_parameter(type_u)); + clazz.add_supertype(Type::Interface("Runnable")); + clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t)); + + writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test" + " extends SuperTest implements Parametrizable, Runnable {\n}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, ParameterizedClassFields) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); + clazz.add_parameter(type_t); + std::list static_fields; + static_fields.push_back(Variable::Create("field1", Type::Class("String"))); + std::list member_fields; + member_fields.push_back(Variable::Create("field2", Type::Class("String"))); + member_fields.push_back(Variable::Create("field3", type_t)); + + writer.BeginType(clazz, nullptr, PUBLIC) + .WriteFields(static_fields, STATIC | PUBLIC | FINAL) + .WriteFields(member_fields, PRIVATE) + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " public static final String field1;\n" + " \n" + " private String field2;\n" + " private T field3;\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, SimpleInnerClass) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Type inner_class = Type::Class("InnerTest"); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginInnerType(inner_class, PUBLIC) + .EndType() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " public class InnerTest {\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteType, StaticParameterizedInnerClass) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); + clazz.add_parameter(type_t); + Type inner_class = Type::Class("InnerTest"); + inner_class.add_parameter(type_t); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginInnerType(inner_class, PUBLIC | STATIC) + .EndType() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " public static class InnerTest {\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteMethod, SimpleMethod) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Method method = Method::Create("doNothing", Type::Void()); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginMethod(method, PUBLIC).EndMethod() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " public void doNothing() {\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteMethod, AnnotatedAndDocumentedMethod) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Method method = Method::Create("doNothing", Type::Void()); + method.description("This method has a\n

\nmultiline description."); + method.add_annotation(Annotation::Create("Override")); + method.add_annotation(Annotation::Create("SuppressWarnings") + .attributes("\"rawtypes\"")); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginMethod(method, PUBLIC).EndMethod() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " /**\n" + " * This method has a\n" + " *

\n" + " * multiline description.\n" + " **/\n" + " @Override\n" + " @SuppressWarnings(\"rawtypes\")\n" + " public void doNothing() {\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteMethod, DocumentedMethodWithArguments) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Method method = Method::Create("boolToInt", Type::Int()); + method.description("Converts a boolean to an int"); + method.return_description("int value for this boolean"); + method.add_argument(Variable::Create("b", Type::Boolean())); + Variable reverse = Variable::Create("reverse", Type::Boolean()); + reverse.description("if true, value is reversed"); + method.add_argument(reverse); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginMethod(method, PUBLIC) + .Append("if (b && !reverse)") + .BeginBlock() + .Append("return 1;").EndLine() + .EndBlock() + .Append("return 0;").EndLine() + .EndMethod() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " /**\n" + " * Converts a boolean to an int\n" + " * \n" + " * @param b\n" + " * @param reverse if true, value is reversed\n" + " * @return int value for this boolean\n" + " **/\n" + " public int boolToInt(boolean b, boolean reverse) {\n" + " if (b && !reverse) {\n" + " return 1;\n" + " }\n" + " return 0;\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteMethod, ParameterizedMethod) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); + clazz.add_parameter(type_t); + Method method = Method::Create("doNothing", type_t); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginMethod(method, PUBLIC) + .Append("return null;").EndLine() + .EndMethod() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " public T doNothing() {\n" + " return null;\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + +TEST(WriteMethod, StaticParameterizedMethod) { + SourceBufferWriter writer; + Type clazz = Type::Class("Test", "org.tensorflow"); + Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); + clazz.add_parameter(type_t); + Method method = Method::Create("doNothing", type_t); + + writer.BeginType(clazz, nullptr, PUBLIC) + .BeginMethod(method, PUBLIC | STATIC) + .Append("return null;").EndLine() + .EndMethod() + .EndType(); + + const char* expected = + "package org.tensorflow;\n\n" + "public class Test {\n" + " \n" + " public static T doNothing() {\n" + " return null;\n" + " }\n" + "}\n"; + ASSERT_STREQ(expected, writer.str().data()); +} + } // namespace +} // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/resources/test.snippet.java b/tensorflow/java/src/gen/resources/test.snippet.java new file mode 100644 index 0000000000000000000000000000000000000000..5e412a9aef436bb73a4d013d1b698b75ad9fbab4 --- /dev/null +++ b/tensorflow/java/src/gen/resources/test.snippet.java @@ -0,0 +1,2 @@ +// Here is a little snippet +System.out.println("Hello!"); diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4c8c73548cac261131c2950d725ac41c6af3dab0..ae7e3e73aed1e43bd78e9f1d4b02bb02c854580d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -58,6 +58,18 @@ py_library( "//tensorflow/tools/api/generator:__pkg__", "//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed ], + deps = [":no_contrib"] + if_not_windows([ + "//tensorflow/contrib:contrib_py", + ]), +) + +py_library( + name = "no_contrib", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow:__pkg__", + ], deps = [ ":array_ops", ":bitwise_ops", @@ -66,6 +78,7 @@ py_library( ":client_testlib", ":confusion_matrix", ":control_flow_ops", + ":cudnn_rnn_ops_gen", ":errors", ":framework", ":framework_for_generated_wrappers", @@ -86,39 +99,38 @@ py_library( ":ops", ":platform", ":pywrap_tensorflow", + ":saver_test_utils", ":script_ops", ":session_ops", ":sets", ":sparse_ops", ":spectral_ops", + ":spectral_ops_test_util", ":standard_ops", ":state_ops", ":string_ops", + ":subscribe", ":summary", ":tensor_array_ops", - ":training", - ":saver_test_utils", - ":subscribe", ":test_ops", # TODO: Break testing code out into separate rule. - ":tf_item", ":tf_cluster", + ":tf_item", ":tf_optimizer", + ":training", ":util", ":weights_broadcast_ops", - "//third_party/py/numpy", "//tensorflow/core:protos_all_py", "//tensorflow/python/data", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras", - "//tensorflow/python/ops/losses", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", + "//tensorflow/python/ops/losses", "//tensorflow/python/profiler", "//tensorflow/python/saved_model", - ] + if_not_windows([ - "//tensorflow/contrib:contrib_py", - ]), + "//third_party/py/numpy", + ], ) tf_py_build_info_genrule() @@ -765,6 +777,31 @@ py_library( ], ) +py_library( + name = "smart_cond", + srcs = ["framework/smart_cond.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":tensor_util", + ], +) + +py_test( + name = "smart_cond_test", + size = "small", + srcs = ["framework/smart_cond_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":constant_op", + ":framework_ops", + ":math_ops", + ":session", + ":smart_cond", + ], +) + py_library( name = "sparse_tensor", srcs = ["framework/sparse_tensor.py"], @@ -1007,6 +1044,11 @@ cuda_py_tests( "//third_party/py/numpy", "//tensorflow/core:protos_all_py", ], + shard_count = 10, + tags = [ + "noasan", + "optonly", + ], ) py_test( @@ -1023,7 +1065,7 @@ py_test( py_test( name = "framework_importer_test", - size = "medium", + size = "large", srcs = ["framework/importer_test.py"], main = "framework/importer_test.py", srcs_version = "PY2AND3", @@ -1331,6 +1373,12 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "summary_ops_gen", + visibility = ["//tensorflow:__subpackages__"], + deps = ["//tensorflow/core:summary_ops_op_lib"], +) + tf_gen_op_wrapper_private_py( name = "audio_ops_gen", require_shape_functions = True, @@ -1340,6 +1388,13 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "cudnn_rnn_ops_gen", + visibility = [ + "//tensorflow:__subpackages__", + ], +) + tf_gen_op_wrapper_private_py( name = "candidate_sampling_ops_gen", visibility = ["//learning/brain/python/ops:__pkg__"], @@ -1750,6 +1805,7 @@ py_library( py_library( name = "gradients", srcs = [ + "ops/custom_gradient.py", "ops/gradients.py", "ops/gradients_impl.py", ], @@ -1763,6 +1819,7 @@ py_library( ":control_flow_util", ":framework", ":framework_for_generated_wrappers", + ":framework_ops", ":functional_ops", ":image_grad", ":linalg_grad", @@ -1775,6 +1832,9 @@ py_library( ":platform", ":spectral_grad", ":util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", "//third_party/py/numpy", "@six_archive//:six", ], @@ -1817,13 +1877,16 @@ py_library( ":control_flow_ops", ":framework", ":framework_for_generated_wrappers", + ":gradients", ":image_ops_gen", ":math_ops", + ":nn", ":nn_ops_gen", ":random_ops", ":string_ops", ":util", ":variables", + "//third_party/py/numpy", ], ) @@ -2552,6 +2615,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":user_ops_gen", + ":util", "@six_archive//:six", ], ) @@ -2820,9 +2884,11 @@ py_library( ":client", ":control_flow_ops", ":data_flow_ops", + ":device", ":errors", ":framework", ":framework_for_generated_wrappers", + ":framework_ops", ":gradients", ":init_ops", ":io_ops", @@ -2847,6 +2913,8 @@ py_library( ":variable_scope", ":variables", "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -2857,6 +2925,7 @@ py_library( srcs = ["training/checkpointable.py"], srcs_version = "PY2AND3", deps = [ + ":array_ops", ":dtypes", ":io_ops_gen", ":ops", @@ -2875,6 +2944,18 @@ py_test( ], ) +py_test( + name = "distribute_test", + size = "small", + srcs = ["training/distribute_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":training", + ":variable_scope", + ], +) + py_test( name = "evaluation_test", size = "small", @@ -3050,7 +3131,6 @@ tf_proto_library( "framework/cpp_shape_inference.proto", ], ), - go_api_version = 2, ) tf_proto_library_py( @@ -3629,6 +3709,7 @@ py_test( ":framework_for_generated_wrappers", ":math_ops", ":state_ops_gen", + ":variable_scope", ":variables", "//tensorflow/core:protos_all_py", ], @@ -3919,7 +4000,13 @@ py_test( size = "small", srcs = ["training/checkpoint_utils_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "manual", + "no_cuda_on_cpu_tap", + "no_oss", + "no_windows", + "notap", + ], deps = [ ":client", ":client_testlib", @@ -3928,6 +4015,7 @@ py_test( ":partitioned_variables", ":platform", ":pywrap_tensorflow", + ":resource_variable_ops", ":state_ops", ":training", ":variable_scope", @@ -3957,6 +4045,25 @@ py_test( ], ) +py_test( + name = "warm_starting_util_test", + size = "small", + srcs = ["training/warm_starting_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":client_testlib", + ":dtypes", + ":framework_ops", + ":init_ops", + ":training", + ":variable_scope", + ":variables", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], +) + py_test( name = "monitored_session_test", size = "medium", @@ -4047,6 +4154,7 @@ py_library( ":pywrap_tensorflow", ":summary_op_util", ":summary_ops", + ":summary_ops_gen", ":util", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -4091,6 +4199,7 @@ py_library( ":control_flow_ops", ":framework_for_generated_wrappers", ":platform", + ":smart_cond", ":tensor_util", ":util", ":variable_scope", @@ -4711,6 +4820,7 @@ py_test( srcs_version = "PY2AND3", tags = [ "grappler", + "no_cuda_on_cpu_tap", "no_pip", ], deps = [ diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 02ed5517ca895ab070a89f8810f77dadcff9212b..3346937904885c216d7a8de86fc6036604376173 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -99,6 +99,10 @@ from tensorflow.python.user_ops import user_ops from tensorflow.python.util import compat +# Import cudnn rnn ops to make sure their ops are registered. +from tensorflow.python.ops import gen_cudnn_rnn_ops as _ + + # Import the names from python/training.py as train.Name. from tensorflow.python.training import training as train @@ -139,6 +143,10 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import tensor_array_ops +# Eager execution +from tensorflow.python.eager.context import executing_eagerly +from tensorflow.python.framework.ops import enable_eager_execution + # Symbols whitelisted for export without documentation. # TODO(cwhipkey): review these and move to contrib, expose through # documentation, or remove. @@ -198,13 +206,9 @@ tf_export('TensorInfo')(TensorInfo) _allowed_symbols.extend([ 'arg_max', 'arg_min', - 'mul', # use tf.multiply instead. - 'neg', # use tf.negative instead. - 'sub', # use tf.subtract instead. 'create_partitioned_variables', 'deserialize_many_sparse', 'lin_space', - 'list_diff', # Use tf.listdiff instead. 'listdiff', # Use tf.listdiff instead. 'parse_single_sequence_example', 'serialize_many_sparse', @@ -294,6 +298,12 @@ _allowed_symbols.extend([ 'MONOLITHIC_BUILD', ]) +# Eager execution +_allowed_symbols.extend([ + 'enable_eager_execution', + 'executing_eagerly', +]) + # Remove all extra symbols that don't have a docstring or are not explicitly # referenced in the whitelist. remove_undocumented(__name__, _allowed_symbols, [ diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index f3c4fecdc0fde0436bea76cc774edaabe1bc07dd..da5dc6f5998bd6f63445dc3694e53d1032e3d1ab 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -21,6 +21,7 @@ from __future__ import print_function import functools import re import threading +import warnings import numpy as np @@ -888,6 +889,8 @@ class BaseSession(SessionInterface): Either a single value if `fetches` is a single graph element, or a list of values if `fetches` is a list, or a dictionary with the same keys as `fetches` if that is a dictionary (described above). + Order in which `fetches` operations are evaluated inside the call + is undefined. Raises: RuntimeError: If this `Session` is in an invalid state (e.g. has been @@ -1085,7 +1088,10 @@ class BaseSession(SessionInterface): if isinstance(subfeed_val, ops.Tensor): raise TypeError('The value of a feed cannot be a tf.Tensor object. ' 'Acceptable feed values include Python scalars, ' - 'strings, lists, numpy ndarrays, or TensorHandles.') + 'strings, lists, numpy ndarrays, or TensorHandles.' + 'For reference, the tensor object was ' + + str(feed_val) + ' which was passed to the ' + 'feed with key ' + str(feed) + '.') subfeed_dtype = subfeed_t.dtype.as_numpy_dtype if isinstance(subfeed_val, int) and _convert_to_numpy_obj( @@ -1217,19 +1223,12 @@ class BaseSession(SessionInterface): compat.as_bytes(options.SerializeToString())) if options else None run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None try: - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - results = tf_session.TF_SessionRun_wrapper( - self._session, options_ptr, {}, fetch_list, target_list, - run_metadata_ptr, status) - else: - results = tf_session.TF_Run(self._session, options_ptr, {}, - fetch_list, target_list, status, - run_metadata_ptr) - if fetch_handler: - results = fetch_handler.build_results(self, results) - else: - results = results[0] if results else None + results = self._call_tf_sessionrun( + options_ptr, {}, fetch_list, target_list, run_metadata_ptr) + if fetch_handler: + results = fetch_handler.build_results(self, results) + else: + results = results[0] if results else None if run_metadata: proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) run_metadata.ParseFromString(compat.as_bytes(proto_data)) @@ -1250,13 +1249,7 @@ class BaseSession(SessionInterface): assert len(target_list) == 1 def _single_operation_run(): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - tf_session.TF_SessionRun_wrapper(self._session, None, {}, [], - target_list, None, status) - else: - tf_session.TF_Run(self._session, None, {}, [], target_list, status, - None) + self._call_tf_sessionrun(None, {}, [], target_list, None) return _single_operation_run elif isinstance(fetches, ops.Tensor): @@ -1266,13 +1259,7 @@ class BaseSession(SessionInterface): assert not target_list def _single_tensor_run(): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - results = tf_session.TF_SessionRun_wrapper( - self._session, None, {}, fetch_list, [], None, status) - else: - results = tf_session.TF_Run(self._session, None, {}, fetch_list, [], - status, None) + results = self._call_tf_sessionrun(None, {}, fetch_list, [], None) return results[0] return _single_tensor_run @@ -1280,13 +1267,8 @@ class BaseSession(SessionInterface): # In all other cases, we must use `fetch_handler` to build the # results for us. def _fetch_handler_run(): - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - results = tf_session.TF_SessionRun_wrapper( - self._session, None, {}, fetch_list, target_list, None, status) - else: - results = tf_session.TF_Run(self._session, None, {}, fetch_list, - target_list, status, None) + results = self._call_tf_sessionrun( + None, {}, fetch_list, target_list, None) return fetch_handler.build_results(self, results) return _fetch_handler_run @@ -1326,35 +1308,22 @@ class BaseSession(SessionInterface): fetches = _name_list(fetch_list) targets = _name_list(target_list) - def _run_fn(session, feed_dict, fetch_list, target_list, options, - run_metadata): + def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): # Ensure any changes to the graph are reflected in the runtime. self._extend_graph() - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionRun_wrapper(session, options, feed_dict, - fetch_list, target_list, - run_metadata, status) - else: - return tf_session.TF_Run(session, options, feed_dict, fetch_list, - target_list, status, run_metadata) + return self._call_tf_sessionrun( + options, feed_dict, fetch_list, target_list, run_metadata) - def _prun_fn(session, handle, feed_dict, fetch_list): + def _prun_fn(handle, feed_dict, fetch_list): if target_list: raise RuntimeError('partial_run() requires empty target_list.') - with errors.raise_exception_on_not_ok_status() as status: - if self._created_with_new_api: - return tf_session.TF_SessionPRun_wrapper(session, handle, feed_dict, - fetch_list, status) - else: - return tf_session.TF_PRun(session, handle, feed_dict, fetch_list, - status) + return self._call_tf_sessionprun(handle, feed_dict, fetch_list) if handle is None: - return self._do_call(_run_fn, self._session, feeds, fetches, targets, - options, run_metadata) + return self._do_call(_run_fn, feeds, fetches, targets, options, + run_metadata) else: - return self._do_call(_prun_fn, self._session, handle, feeds, fetches) + return self._do_call(_prun_fn, handle, feeds, fetches) def _do_call(self, fn, *args): try: @@ -1374,23 +1343,23 @@ class BaseSession(SessionInterface): raise type(e)(node_def, op, message) def _extend_graph(self): - # Nothing to do if we're using the new session interface - # TODO(skyewm): remove this function altogether eventually if self._created_with_new_api: - return - - # Ensure any changes to the graph are reflected in the runtime. - with self._extend_lock: - if self._graph.version > self._current_version: - # pylint: disable=protected-access - graph_def, self._current_version = self._graph._as_graph_def( - from_version=self._current_version, add_shapes=self._add_shapes) - # pylint: enable=protected-access - + with self._graph._lock: # pylint: disable=protected-access with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_ExtendGraph(self._session, - graph_def.SerializeToString(), status) - self._opened = True + tf_session.ExtendSession(self._session, status) + else: + # Ensure any changes to the graph are reflected in the runtime. + with self._extend_lock: + if self._graph.version > self._current_version: + # pylint: disable=protected-access + graph_def, self._current_version = self._graph._as_graph_def( + from_version=self._current_version, add_shapes=self._add_shapes) + # pylint: enable=protected-access + + with errors.raise_exception_on_not_ok_status() as status: + tf_session.TF_ExtendGraph(self._session, + graph_def.SerializeToString(), status) + self._opened = True # The threshold to run garbage collection to delete dead tensors. _DEAD_HANDLES_THRESHOLD = 10 @@ -1441,6 +1410,27 @@ class BaseSession(SessionInterface): feed_dict[feed_tensor] = np_val return handles + def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, + run_metadata): + with errors.raise_exception_on_not_ok_status() as status: + if self._created_with_new_api: + return tf_session.TF_SessionRun_wrapper( + self._session, options, feed_dict, fetch_list, target_list, + run_metadata, status) + else: + return tf_session.TF_Run( + self._session, options, feed_dict, fetch_list, target_list, + status, run_metadata) + + def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): + with errors.raise_exception_on_not_ok_status() as status: + if self._created_with_new_api: + return tf_session.TF_SessionPRun_wrapper( + self._session, handle, feed_dict, fetch_list, status) + else: + return tf_session.TF_PRun( + self._session, handle, feed_dict, fetch_list, status) + @tf_export('Session') class Session(BaseSession): @@ -1637,6 +1627,9 @@ class InteractiveSession(BaseSession): ``` """ + _count_lock = threading.Lock() + _active_session_count = 0 # GUARDED_BY(_count_lock) + def __init__(self, target='', graph=None, config=None): """Creates a new interactive TensorFlow session. @@ -1665,6 +1658,19 @@ class InteractiveSession(BaseSession): config.graph_options.place_pruned_graph = True super(InteractiveSession, self).__init__(target, graph, config) + with InteractiveSession._count_lock: + if InteractiveSession._active_session_count > 0: + warnings.warn('An interactive session is already active. This can ' + 'cause out-of-memory errors in some cases. You must ' + 'explicitly call `InteractiveSession.close()` to release ' + 'resources held by the other session(s).') + InteractiveSession._active_session_count += 1 + # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful + # semantics (in particular, it is not set to true if `Session.close()` is + # called on a session that has not been "opened" by running a step) and we + # cannot change those semantics without breaking existing code. + self._explicitly_closed = False + self._default_session = self.as_default() self._default_session.enforce_nesting = False self._default_session.__enter__() @@ -1677,6 +1683,14 @@ class InteractiveSession(BaseSession): def close(self): """Closes an `InteractiveSession`.""" super(InteractiveSession, self).close() + with InteractiveSession._count_lock: + if not self._explicitly_closed: + InteractiveSession._active_session_count -= 1 + self._explicitly_closed = True + else: + return if self._explicit_graph is not None: self._default_graph.__exit__(None, None, None) + self._default_graph = None self._default_session.__exit__(None, None, None) + self._default_session = None diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 490572254b0be6a110ef06cea15d20d780f732cf..6e2640efd1d58ab524e42b62f62ad3d38f360c0e 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -22,13 +22,13 @@ import os import sys import threading import time +import warnings import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.framework import attr_value_pb2 -from tensorflow.core.framework import types_pb2 from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -37,6 +37,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function +from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util @@ -46,6 +47,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import @@ -63,6 +65,10 @@ ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) @test_util.with_c_api class SessionTest(test_util.TensorFlowTestCase): + def setUp(self): + super(SessionTest, self).setUp() + warnings.simplefilter('always') + def testUseExistingGraph(self): with ops.Graph().as_default() as g, ops.device('/cpu:0'): a = constant_op.constant(6.0, shape=[1, 1]) @@ -187,12 +193,10 @@ class SessionTest(test_util.TensorFlowTestCase): a = constant_op.constant(0.0, shape=[2, 3]) # NOTE(mrry): The original_op is nonsense, but used here to test that the # errors are reported correctly. - # pylint: disable=protected-access with sess.graph._original_op(a.op): b = array_ops.identity(a, name='id') with sess.graph._original_op(b.op): c = array_ops.placeholder(dtypes.float32) - # pylint: enable=protected-access def exc_predicate(e): return (e.op == c.op and e.op._original_op == b.op and @@ -1052,6 +1056,43 @@ class SessionTest(test_util.TensorFlowTestCase): for t in threads: t.join() + def testParallelRunAndBuild(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + stop = threading.Event() + + def run_loop(): + while not stop.is_set(): + self.assertEqual(sess.run(c), 5.0) + + threads = [self.checkedThread(target=run_loop) for _ in range(100)] + for t in threads: + t.start() + + # Do some graph construction. Try to exercise non-trivial paths. + graph = ops.get_default_graph() + gdef = None + for _ in range(10): + x = array_ops.placeholder(dtype=dtypes.float32) + with ops.colocate_with(x): + y = array_ops.placeholder(dtype=dtypes.float32) + with ops.device('/cpu:0'): + z = control_flow_ops.while_loop( + lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) + with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): + gradients_impl.gradients(z, [x, y]) + if gdef is None: + gdef = graph.as_graph_def() + else: + # NOTE(skyewm): import_graph_def breaks the running threads without + # the C API enabled. This is not a regression so I didn't fix it. + if ops._USE_C_API: + importer.import_graph_def(gdef, name='import') + + stop.set() + for t in threads: + t.join() + def testRunFeedDict(self): with session.Session() as s: x = array_ops.zeros([2]) @@ -1153,6 +1194,33 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual([[24.0]], e.eval()) sess.close() + def testMultipleInteractiveSessionsWarning(self): + # Reinitialize the global state to ensure that the expected warnings will + # be emitted. + session.InteractiveSession._active_session_count = 0 # pylint: disable=protected-access + + sess = session.InteractiveSession() + sess.run(constant_op.constant(4.0)) # Run so that the session is "opened". + sess.close() + # Opening and closing interactive sessions serially should not warn. + with warnings.catch_warnings(record=True) as w: + sess = session.InteractiveSession() + sess.close() + self.assertEqual(0, len(w)) + + with warnings.catch_warnings(record=True) as w: + sess = session.InteractiveSession() + self.assertEqual(0, len(w)) + with warnings.catch_warnings(record=True) as w: + sess2 = session.InteractiveSession() + self.assertEqual(1, len(w)) + self.assertTrue('An interactive session is already active. This can cause ' + 'out-of-memory errors in some cases. You must explicitly ' + 'call `InteractiveSession.close()` to release resources ' + 'held by the other session(s).' in str(w[0].message)) + sess2.close() + sess.close() + def testInteractivePlacePrunedGraph(self): sess = session.InteractiveSession() @@ -1745,8 +1813,8 @@ class SessionTest(test_util.TensorFlowTestCase): # Ensure that errors from building the graph get propagated. data = array_ops.placeholder(dtypes.float32, shape=[]) # pylint: disable=protected-access - enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False) - enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False) + enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False) + enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False) # pylint: enable=protected-access res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError('has inputs from different frames'): @@ -1815,144 +1883,5 @@ class SessionTest(test_util.TensorFlowTestCase): sess.run(a, feed_dict={a: 1}) -class GraphMutationTest(test_util.TensorFlowTestCase): - - def setUp(self): - self._original_use_c_api_value = ops._USE_C_API - ops._USE_C_API = True - super(GraphMutationTest, self).setUp() - - def tearDown(self): - ops._USE_C_API = self._original_use_c_api_value - super(GraphMutationTest, self).tearDown() - - def testUpdateInputAfterRunning(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(1.0) - b = constant_op.constant(2.0) - c = a + b - - with session.Session(graph=g) as sess: - self.assertAllEqual(3.0, sess.run(c)) - c.op._update_input(1, a) # pylint: disable=protected-access - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'add.*was changed by updating input tensor after it was run'): - sess.run(c) - - # Check that running the graph with a new session is fine - with session.Session(graph=g) as sess2: - self.assertAllEqual(2.0, sess2.run(c)) - - def testSetDeviceAfterRunning(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(1.0) - b = constant_op.constant(2.0) - c = a + b - - with session.Session(graph=g) as sess: - self.assertAllEqual(3.0, sess.run(c)) - c.op._set_device('/cpu:0') # pylint: disable=protected-access - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'add.*was changed by setting device after it was run'): - sess.run(c) - - def testSetAttrAfterRunning(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(1.0, dtype=dtypes.float32) - b = math_ops.cast(a, dtypes.float64) - - with session.Session(graph=g) as sess: - self.assertAllEqual(1.0, sess.run(b)) - b.op._set_attr('DstT', attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT)) - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'Cast.*was changed by setting attribute after it was run'): - sess.run(b) - - def testRunModifyRun(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(1.0) - b = constant_op.constant(2.0) - c = a + b - - with session.Session(graph=g) as sess: - self.assertAllEqual(3.0, sess.run(c)) - - d = b + c - d.op._update_input(0, a) # pylint: disable=protected-access - self.assertAllEqual(3.0, sess.run(c)) - self.assertAllEqual(4.0, sess.run(d)) - - def testRunModifyRunTwoSessions(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(1.0) - b = constant_op.constant(2.0) - c = a + b - - with session.Session(graph=g) as sess1: - with session.Session(graph=g) as sess2: - self.assertAllEqual(3.0, sess1.run(c)) - self.assertAllEqual(3.0, sess2.run(c)) - - d = b + c - d.op._update_input(0, a) # pylint: disable=protected-access - self.assertAllEqual(3.0, sess2.run(c)) - self.assertAllEqual(4.0, sess2.run(d)) - - d.op._update_input(0, b) # pylint: disable=protected-access - self.assertAllEqual(3.0, sess1.run(c)) - self.assertAllEqual(5.0, sess1.run(d)) - - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'add.*was changed by updating input tensor after it was run'): - sess2.run(c) - - def testTwoSessionsOneRunBeforeModification(self): - with ops.Graph().as_default() as g, ops.device('/cpu:0'): - a = constant_op.constant(1.0) - b = constant_op.constant(2.0) - c = a + b - - with session.Session(graph=g) as sess1: - with session.Session(graph=g) as sess2: - sess1.run(c) - - c.op._set_device('/cpu:0') # pylint: disable=protected-access - - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'add.*was changed by setting device after it was run'): - sess1.run(c) - - # sess2 was not run before modification - self.assertAllEqual(3.0, sess2.run(c)) - - def testTwoSessionsBothRunBeforeModification(self): - with ops.Graph().as_default() as g, ops.device('/cpu:0'): - a = constant_op.constant(1.0) - b = constant_op.constant(2.0) - c = a + b - - with session.Session(graph=g) as sess1: - with session.Session(graph=g) as sess2: - sess1.run(c) - sess2.run(c) - - c.op._set_device('/cpu:0') # pylint: disable=protected-access - - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'add.*was changed by setting device after it was run'): - sess1.run(c) - - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - 'add.*was changed by setting device after it was run'): - sess2.run(c) - - if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index f305cd271f98bea697ea8ff15be799d3e80db0bf..e88fc0c01a8bb7534f47e2a0389965c102bbad7b 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -720,6 +720,9 @@ def TF_Reset(target, containers=None, config=None): } %unignore SetRequireShapeInferenceFns; +%unignore TF_TryEvaluateConstant_wrapper; +%noexception TF_TryEvaluateConstant_wrapper; +%unignore ExtendSession; %include "tensorflow/python/client/tf_session_helper.h" diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 361dbc22b097a9bc82f656d7416b88c4a3a1ec2d..a8ab91749a86749a1eef25e2674634334682d0f3 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -493,4 +493,19 @@ std::vector TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( return input_strs; } +PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output, + TF_Status* status) { + TF_Tensor* result_tensor; + bool evaluated = + TF_TryEvaluateConstant(graph, output, &result_tensor, status); + if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE; + + Safe_TF_TensorPtr safe_result_tensor(result_tensor); + PyObject* out; + Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out); + Set_TF_Status_from_Status(status, s); + if (!s.ok()) Py_RETURN_NONE; + return out; +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 29d5b28f40a7c07c199eec8c8cd85de626f6b068..83318dc178f6da3828a8dc41e81b7fc3e2e19e22 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -213,6 +213,11 @@ std::vector TF_GraphGetTensorShape_wrapper(TF_Graph* graph, std::vector TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( TF_ImportGraphDefResults* results); +// If evaluation was possible, returns the numpy ndarray of the evaluated +// result. Otherwise returns None. +PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py index 9641b8b7f2735e2e0477aec59edd539e999fa969..5e6b5acdb02e4c8c167485520a8d84ac43db7511 100644 --- a/tensorflow/python/client/timeline_test.py +++ b/tensorflow/python/client/timeline_test.py @@ -155,9 +155,12 @@ class TimelineTest(test.TestCase): ctf = step_analysis.chrome_trace.format_to_string() self._validateTrace(ctf) maximums = step_analysis.allocator_maximums - self.assertTrue('cpu' in maximums) + cpuname = 'cpu' + if 'mklcpu' in maximums: + cpuname = 'mkl' + cpuname + self.assertTrue(cpuname in maximums) cpu_max = maximums[ - 'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums['cpu'] + 'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums[cpuname] # At least num1 + num2, both float32s (4 bytes each) self.assertGreater(cpu_max.num_bytes, 8) self.assertGreater(cpu_max.timestamp, 0) diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index 02720a2e985914d3a6774dc6f64d1316890c46bf..25269dc810ae2e3107f8b5317496a35a8ff59d0c 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -297,6 +297,21 @@ class MemoryCacheDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(i2.get_next()) + def testCacheTakeRepeat(self): + dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2) + itr = dataset.make_one_shot_iterator() + n = itr.get_next() + + expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + + with self.test_session() as sess: + for i, expected in enumerate(expected_values): + self.assertEqual(expected, sess.run(n), + "Unexpected value at index %s" % i) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py index 14627810b57f68fd96e3e3cc7b51b4fbf7365299..ea5b41e5d819743ad03f3148d654329aea51dab7 100644 --- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py @@ -263,7 +263,7 @@ class DatasetConstructorTest(test.TestCase): for i in range(3): results = sess.run(get_next) for component, result_component in zip( - (zip(*components[:3])[i] + expected[i]), results): + (list(zip(*components[:3]))[i] + expected[i]), results): if sparse_tensor.is_sparse(component): self.assertSparseValuesEqual(component, result_component) else: diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index b9258b720edd4ecd620c61eed18f6f975cb7f439..4f2216f0a340acb582c2d09523b0c78af99bdd90 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -17,11 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + import numpy as np +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops @@ -156,6 +160,65 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testReturnComponent(self): + iterator = ( + dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(10), + dataset_ops.Dataset.from_tensors(True).repeat(None))) + .filter(lambda x, y: y).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + self.assertEqual((i, True), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testParallelFilters(self): + dataset = dataset_ops.Dataset.range(10).filter( + lambda x: math_ops.equal(x % 2, 0)) + iterators = [dataset.make_one_shot_iterator() for _ in range(10)] + next_elements = [iterator.get_next() for iterator in iterators] + with self.test_session() as sess: + self.assertEqual([0 for _ in range(10)], sess.run(next_elements)) + + +class FilterDatasetBenchmark(test.Benchmark): + + def _benchmark(self, predicate, name): + with ops.Graph().as_default(): + dataset = ( + dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate)) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + print("Filter dataset using %s. Median wall time: %f" % + (name, median_wall_time)) + self.report_benchmark( + iters=100, + wall_time=median_wall_time, + name="benchmark_filter_dataset_%s" % name) + + def benchmarkSimpleFunction(self): + self._benchmark(array_ops.identity, "simple_function") + + def benchmarkReturnComponentOptimization(self): + self._benchmark(lambda x: x, "return_component") + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index 23c6d7385f8d4a12019fa514f349f2598d9629de..4a14a915bdb33f1ac6e8fc1839b32bc81fa8de05 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -22,6 +22,7 @@ import warnings import numpy as np +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops @@ -44,6 +45,7 @@ from tensorflow.python.ops import script_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import server_lib +from tensorflow.python.util import compat class IteratorTest(test.TestCase): @@ -63,8 +65,9 @@ class IteratorTest(test.TestCase): def testCapturingStateInOneShotRaisesException(self): var = variables.Variable(37.0, name="myvar") - dataset = (dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]) - .map(lambda x: x + var)) + dataset = ( + dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]) + .map(lambda x: x + var)) with self.assertRaisesRegexp( ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support " "datasets that capture stateful objects.+myvar"): @@ -78,8 +81,9 @@ class IteratorTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(14).make_one_shot_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(14).make_one_shot_iterator()) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], @@ -103,8 +107,9 @@ class IteratorTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(tensor_components) + .map(_map_fn).repeat(14).make_one_shot_iterator()) get_next = iterator.get_next() self.assertEqual([c.shape[1:] for c in components], @@ -125,10 +130,13 @@ class IteratorTest(test.TestCase): np.array(37.0) * np.arange(7)) def within_container(): + def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(_map_fn).repeat(14).make_one_shot_iterator()) return iterator.get_next() server = server_lib.Server.create_local_server() @@ -159,8 +167,8 @@ class IteratorTest(test.TestCase): # Create a session with a single thread to ensure that the # one-shot iterator initializer does not deadlock. - config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, - use_per_session_threads=True) + config = config_pb2.ConfigProto( + inter_op_parallelism_threads=1, use_per_session_threads=True) with session.Session(config=config) as sess: self.assertAllEqual([1, 4, 9], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -169,6 +177,7 @@ class IteratorTest(test.TestCase): # Test with multiple threads invoking the one-shot iterator concurrently. with session.Session(config=config) as sess: results = [] + def consumer_thread(): try: results.append(sess.run(next_element)) @@ -177,7 +186,8 @@ class IteratorTest(test.TestCase): num_threads = 8 threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] + self.checkedThread(consumer_thread) for _ in range(num_threads) + ] for t in threads: t.start() for t in threads: @@ -205,24 +215,24 @@ class IteratorTest(test.TestCase): sess.run(next_element) with self.test_session() as sess: + def consumer_thread(): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) num_threads = 8 threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] + self.checkedThread(consumer_thread) for _ in range(num_threads) + ] for t in threads: t.start() for t in threads: t.join() def testSimpleSharedResource(self): - components = ( - np.array(1, dtype=np.int64), - np.array([1, 2, 3], dtype=np.int64), - np.array(37.0, dtype=np.float64) - ) + components = (np.array(1, dtype=np.int64), + np.array([1, 2, 3], dtype=np.int64), + np.array(37.0, dtype=np.float64)) server = server_lib.Server.create_local_server() @@ -231,9 +241,10 @@ class IteratorTest(test.TestCase): # first session (initializing the iterator) is visible in the # second session. with ops.Graph().as_default(): - iterator = (dataset_ops.Dataset.from_tensors(components) - .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( - shared_name="shared_iterator")) + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( + shared_name="shared_iterator")) init_op = iterator.initializer get_next = iterator.get_next() @@ -269,8 +280,9 @@ class IteratorTest(test.TestCase): def testNotInitializedError(self): components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - iterator = (dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) get_next = iterator.get_next() with self.test_session() as sess: @@ -320,8 +332,8 @@ class IteratorTest(test.TestCase): def testReinitializableIteratorStaticErrors(self): # Non-matching structure for types and shapes. with self.assertRaises(TypeError): - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64), [None]) + iterator = iterator_ops.Iterator.from_structure( + (dtypes.int64, dtypes.float64), [None]) # Test validation of dataset argument. iterator = iterator_ops.Iterator.from_structure((dtypes.int64, @@ -337,18 +349,18 @@ class IteratorTest(test.TestCase): # Incompatible types. with self.assertRaises(TypeError): iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int32), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float32)))) + dataset_ops.Dataset.from_tensors( + (constant_op.constant([1, 2, 3], dtype=dtypes.int32), + constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32)))) # Incompatible shapes. iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), ([None], [])) with self.assertRaises(TypeError): iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64)))) + dataset_ops.Dataset.from_tensors( + (constant_op.constant([1, 2, 3], dtype=dtypes.int64), + constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64)))) def testIteratorStringHandle(self): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) @@ -370,33 +382,40 @@ class IteratorTest(test.TestCase): iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) - self.assertEqual( - 10, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 1, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 20, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 2, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 30, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 3, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 40, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual(10, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual(1, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_3_handle})) + self.assertEqual(20, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual(2, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_3_handle})) + self.assertEqual(30, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual(3, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_3_handle})) + self.assertEqual(40, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle}) + sess.run( + next_element, feed_dict={handle_placeholder: iterator_3_handle}) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle}) + sess.run( + next_element, feed_dict={handle_placeholder: iterator_4_handle}) def testIteratorStringHandleReuseTensorObject(self): dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) @@ -427,8 +446,8 @@ class IteratorTest(test.TestCase): self.assertIsNot(handle_with_name, handle_with_same_name) def testIteratorStringHandleError(self): - dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2, - 3]).repeat()) + dataset_int_scalar = ( + dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat()) dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) @@ -522,6 +541,58 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" }) + def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self): + s1 = server_lib.Server.create_local_server() + s2 = server_lib.Server.create_local_server() + s3 = server_lib.Server.create_local_server() + + cluster_def = cluster_pb2.ClusterDef() + workers = cluster_def.job.add() + workers.name = "worker" + workers.tasks[0] = s1.target[len("grpc://"):] + workers.tasks[1] = s2.target[len("grpc://"):] + client = cluster_def.job.add() + client.name = "client" + client.tasks[0] = s3.target[len("grpc://"):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + worker_devices = [ + "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2) + ] + itr_handles = [] + for device in worker_devices: + with ops.device(device): + src = dataset_ops.Dataset.from_tensor_slices([device]) + itr = src.make_one_shot_iterator() + itr_handles.append(itr.string_handle()) + + targets = dataset_ops.Dataset.from_tensor_slices(worker_devices) + handles = dataset_ops.Dataset.from_tensor_slices(itr_handles) + + @function.Defun(dtypes.string) + def loading_func(h): + remote_itr = iterator_ops.Iterator.from_string_handle( + h, itr.output_types, itr.output_shapes) + return remote_itr.get_next() + + def map_fn(target, handle): + return functional_ops.remote_call( + args=[handle], Tout=[dtypes.string], f=loading_func, target=target) + + with ops.device("/job:client"): + client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn) + itr = client_dataset.make_initializable_iterator() + n = itr.get_next() + + with session.Session(s3.target, config=config) as sess: + sess.run(itr.initializer) + expected_values = worker_devices + for expected in expected_values: + self.assertEqual((compat.as_bytes(expected),), sess.run(n)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -641,8 +712,7 @@ class IteratorTest(test.TestCase): with warnings.catch_warnings(record=True) as w: for _ in range(100): iterator.get_next() - self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, - len(w)) + self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w)) for warning in w: self.assertTrue( iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE in str(warning.message)) diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py index 4e7691ee8144a19a62476281d86fb5df46dd3e4b..6442eb9ff554e61829796fb904342072d1846a32 100644 --- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py @@ -46,8 +46,9 @@ class ListFilesDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) with self.test_session() as sess: itr = dataset.make_one_shot_iterator() + next_element = itr.get_next() with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) + sess.run(next_element) def testSimpleDirectory(self): filenames = ['a', 'b', 'c'] @@ -56,13 +57,14 @@ class ListFilesDatasetOpTest(test.TestCase): dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) with self.test_session() as sess: itr = dataset.make_one_shot_iterator() + next_element = itr.get_next() full_filenames = [] produced_filenames = [] for filename in filenames: full_filenames.append( compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) self.assertItemsEqual(full_filenames, produced_filenames) with self.assertRaises(errors.OutOfRangeError): sess.run(itr.get_next()) @@ -73,12 +75,13 @@ class ListFilesDatasetOpTest(test.TestCase): with self.test_session() as sess: itr = dataset.make_initializable_iterator() + next_element = itr.get_next() sess.run( itr.initializer, feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) + sess.run(next_element) def testSimpleDirectoryInitializer(self): filenames = ['a', 'b', 'c'] @@ -89,6 +92,7 @@ class ListFilesDatasetOpTest(test.TestCase): with self.test_session() as sess: itr = dataset.make_initializable_iterator() + next_element = itr.get_next() sess.run( itr.initializer, feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) @@ -98,7 +102,7 @@ class ListFilesDatasetOpTest(test.TestCase): for filename in filenames: full_filenames.append( compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) self.assertItemsEqual(full_filenames, produced_filenames) @@ -114,6 +118,7 @@ class ListFilesDatasetOpTest(test.TestCase): with self.test_session() as sess: itr = dataset.make_initializable_iterator() + next_element = itr.get_next() sess.run( itr.initializer, feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')}) @@ -123,7 +128,7 @@ class ListFilesDatasetOpTest(test.TestCase): for filename in filenames[1:-1]: full_filenames.append( compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) self.assertItemsEqual(full_filenames, produced_filenames) with self.assertRaises(errors.OutOfRangeError): @@ -138,6 +143,7 @@ class ListFilesDatasetOpTest(test.TestCase): with self.test_session() as sess: itr = dataset.make_initializable_iterator() + next_element = itr.get_next() sess.run( itr.initializer, feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')}) @@ -147,13 +153,44 @@ class ListFilesDatasetOpTest(test.TestCase): for filename in filenames[1:]: full_filenames.append( compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) self.assertItemsEqual(full_filenames, produced_filenames) with self.assertRaises(errors.OutOfRangeError): sess.run(itr.get_next()) + def testNoShuffle(self): + filenames = ['a', 'b', 'c'] + self._touchTempFiles(filenames) + + # Repeat the list twice and ensure that the order is the same each time. + # NOTE(mrry): This depends on an implementation detail of `list_files()`, + # which is that the list of files is captured when the iterator is + # initialized. Otherwise, or if e.g. the iterator were initialized more than + # once, it's possible that the non-determinism of `tf.matching_files()` + # would cause this test to fail. However, it serves as a useful confirmation + # that the `shuffle=False` argument is working as intended. + # TODO(b/73959787): Provide some ordering guarantees so that this test is + # more meaningful. + dataset = dataset_ops.Dataset.list_files( + path.join(self.tmp_dir, '*'), shuffle=False).repeat(2) + with self.test_session() as sess: + itr = dataset.make_one_shot_iterator() + next_element = itr.get_next() + + full_filenames = [] + produced_filenames = [] + for filename in filenames * 2: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + self.assertItemsEqual(full_filenames, produced_filenames) + self.assertEqual(produced_filenames[:len(filenames)], + produced_filenames[len(filenames):]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index d7140088c310767d40bd2cf3413c899375acab15..1ddedfda4e1c9d6b6949f796be1870f167435763 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,7 @@ import gzip import os import zlib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op @@ -736,12 +737,43 @@ class TFRecordDatasetTest(test.TestCase): one_mebibyte = 2**20 d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte) iterator = d.make_one_shot_iterator() + next_element = iterator.get_next() with self.test_session() as sess: for j in range(self._num_files): for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next())) + self.assertAllEqual(self._record(j, i), sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) + sess.run(next_element) + + def testReadFromDatasetOfFiles(self): + files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames) + d = readers.TFRecordDataset(files) + iterator = d.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for j in range(self._num_files): + for i in range(self._num_records): + self.assertAllEqual(self._record(j, i), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testReadTenEpochsFromDatasetOfFilesInParallel(self): + files = dataset_ops.Dataset.from_tensor_slices( + self.test_filenames).repeat(10) + d = readers.TFRecordDataset(files, num_parallel_reads=4) + iterator = d.make_one_shot_iterator() + next_element = iterator.get_next() + expected = [] + actual = [] + with self.test_session() as sess: + for _ in range(10): + for j in range(self._num_files): + for i in range(self._num_records): + expected.append(self._record(j, i)) + actual.append(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self.assertEqual(sorted(expected), sorted(actual)) if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py index c089fb08c1082c1cf74d492796550980d6755591..5fcc48831f3ca744e015c92760f12ea4dbef2ff7 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py @@ -132,6 +132,33 @@ class ShuffleDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testSeedZero(self): + """Test for same behavior when the seed is a Python or Tensor zero.""" + iterator = ( + dataset_ops.Dataset.range(10).shuffle(10, seed=0) + .make_one_shot_iterator()) + get_next = iterator.get_next() + + elems = [] + with self.test_session() as sess: + for _ in range(10): + elems.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = ( + dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder) + .make_initializable_iterator()) + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer, feed_dict={seed_placeholder: 0}) + for elem in elems: + self.assertEqual(elem, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testDefaultArguments(self): components = [0, 1, 2, 3, 4] iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index f12b358a7dc35c18338171e489fa88ba1a82d11b..3119ab003794cb9bc0c748dfeb47597e0877f5fd 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -23,6 +23,7 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:random_seed", "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", ], @@ -34,6 +35,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", + "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -50,9 +52,11 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 3fb1f8d5479fc461a8d1f509c5eec2d0ed4a44c9..c0a6283be433aba80eab2375cbaed6f187e3c4c3 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -26,16 +26,17 @@ import six from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops @@ -90,7 +91,7 @@ class Dataset(object): Raises: RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "dataset.make_initializable_iterator is not supported when eager " "execution is enabled.") @@ -110,11 +111,11 @@ class Dataset(object): self.output_types, self.output_shapes, self.output_classes) - def make_one_shot_iterator(self): + def __iter__(self): """Creates an `Iterator` for enumerating the elements of this dataset. - Note: The returned iterator will be initialized automatically. - A "one-shot" iterator does not currently support re-initialization. + The returned iterator implements the Python iterator protocol and therefore + can only be used in eager mode. Returns: An `Iterator` over the elements of this dataset. @@ -122,10 +123,23 @@ class Dataset(object): Raises: RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): - raise RuntimeError( - "dataset.make_one_shot_iterator is not supported when eager " - "execution is enabled.") + if context.executing_eagerly(): + return iterator_ops.EagerIterator(self) + else: + raise RuntimeError("dataset.__iter__() is only supported when eager " + "execution is enabled.") + + def make_one_shot_iterator(self): + """Creates an `Iterator` for enumerating the elements of this dataset. + + Note: The returned iterator will be initialized automatically. + A "one-shot" iterator does not currently support re-initialization. + + Returns: + An `Iterator` over the elements of this dataset. + """ + if context.executing_eagerly(): + return iterator_ops.EagerIterator(self) # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is # a 0-argument function. @function.Defun(capture_by_value=True) @@ -549,7 +563,7 @@ class Dataset(object): Args: buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - maximum number elements that will be buffered when prefetching. + maximum number of elements that will be buffered when prefetching. Returns: Dataset: A `Dataset`. @@ -557,7 +571,7 @@ class Dataset(object): return PrefetchDataset(self, buffer_size) @staticmethod - def list_files(file_pattern): + def list_files(file_pattern, shuffle=None): """A dataset of all files matching a pattern. Example: @@ -570,16 +584,31 @@ class Dataset(object): - /path/to/dir/b.py - /path/to/dir/c.py - NOTE: The order of the file names returned can be non-deterministic. + NOTE: The order of the file names returned can be non-deterministic even + when `shuffle` is `False`. Args: file_pattern: A string or scalar string `tf.Tensor`, representing the filename pattern that will be matched. + shuffle: (Optional.) If `True`, the file names will be shuffled randomly. + Defaults to `True`. Returns: Dataset: A `Dataset` of strings corresponding to file names. """ - return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) + # TODO(b/73959787): Add a `seed` argument and make the `shuffle=False` + # behavior deterministic (e.g. by sorting the filenames). + if shuffle is None: + shuffle = True + matching_files = gen_io_ops.matching_files(file_pattern) + dataset = Dataset.from_tensor_slices(matching_files) + if shuffle: + # NOTE(mrry): The shuffle buffer size must be greater than zero, but the + # list of files might be empty. + buffer_size = math_ops.maximum( + array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1) + dataset = dataset.shuffle(buffer_size) + return dataset def repeat(self, count=None): """Repeats this dataset `count` times. @@ -758,11 +787,31 @@ class Dataset(object): def padded_batch(self, batch_size, padded_shapes, padding_values=None): """Combines consecutive elements of this dataset into padded batches. - Like `Dataset.dense_to_sparse_batch()`, this method combines - multiple consecutive elements of this dataset, which might have - different shapes, into a single element. The tensors in the - resulting element have an additional outer dimension, and are - padded to the respective shape in `padded_shapes`. + This transformation combines multiple consecutive elements of the input + dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors + in the resulting element have an additional outer dimension, which will be + `batch_size` for all but the last element, and `N % batch_size` for the + last element (where `N` is the number of elements in this dataset). Unlike + @{tf.data.Dataset.batch}, the elements may have different shapes for some + of their components, and this transformation will pad each component to + the respective shape in `padding_shapes`. The `padding_shapes` argument + determines the resulting shape for each dimension of each component in an + output element: + + * If the dimension is a constant (e.g. `tf.Dimension(37)`), the component + will be padded out to that length in that dimension. + * If the dimension is unknown (e.g. `tf.Dimension(None)`), the component + will be padded out to the maximum length of all elements in that + dimension. + + NOTE: If the number of elements (`N`) in this dataset is not an exact + multiple of `batch_size`, the final batch contain smaller tensors with + shape `N % batch_size` in the batch dimension. If your program depends on + the batches having the same shape, consider using the + @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead. + + See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements + that may have different shapes into a @{tf.SparseTensor}. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of @@ -1484,16 +1533,7 @@ class ShuffleDataset(Dataset): self._input_dataset = input_dataset self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) if reshuffle_each_iteration is None: self._reshuffle_each_iteration = True else: @@ -1910,47 +1950,13 @@ class FlatMapDataset(Dataset): return self._output_types -class InterleaveDataset(Dataset): +class InterleaveDataset(FlatMapDataset): """A `Dataset` that maps a function over its input and interleaves the result. """ def __init__(self, input_dataset, map_func, cycle_length, block_length): """See `Dataset.interleave()` for details.""" - super(InterleaveDataset, self).__init__() - self._input_dataset = input_dataset - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - dataset = map_func(*nested_args) - else: - dataset = map_func(nested_args) - - if not isinstance(dataset, Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - - return dataset._as_variant_tensor() # pylint: disable=protected-access - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) - + super(InterleaveDataset, self).__init__(input_dataset, map_func) self._cycle_length = ops.convert_to_tensor( cycle_length, dtype=dtypes.int64, name="cycle_length") self._block_length = ops.convert_to_tensor( @@ -1959,27 +1965,15 @@ class InterleaveDataset(Dataset): def _as_variant_tensor(self): return gen_dataset_ops.interleave_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._map_func.captured_inputs, + self._map_func.captured_inputs, # pylint: disable=protected-access self._cycle_length, self._block_length, - f=self._map_func, + f=self._map_func, # pylint: disable=protected-access output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - class FilterDataset(Dataset): """A `Dataset` that filters its input according to a predicate function.""" diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 4756ec74820bace5bea4e1f41ebe214420fe5c3d..d79b9d6011b6ebd00a47d572165cdbba8a31bd32 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -17,14 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading import warnings from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util.tf_export import tf_export @@ -412,3 +416,147 @@ class Iterator(object): of an element of this dataset. """ return self._output_types + + +_uid_counter = 0 +_uid_lock = threading.Lock() + + +def _generate_shared_name(prefix): + with _uid_lock: + global _uid_counter + uid = _uid_counter + _uid_counter += 1 + return "{}{}".format(prefix, uid) + + +class EagerIterator(object): + """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" + + def __init__(self, dataset): + """Creates a new iterator over the given dataset. + + For example: + ```python + dataset = tf.data.Dataset.range(4) + for x in Iterator(dataset): + print(x) + ``` + + Tensors produced will be placed on the device on which this iterator object + was created. + + Args: + dataset: A `tf.data.Dataset` object. + + Raises: + RuntimeError: When invoked without eager execution enabled. + """ + + if not context.executing_eagerly(): + raise RuntimeError( + "{} objects can only be used when eager execution is enabled, use " + "tf.data.Dataset.make_initializable_iterator or " + "tf.data.Dataset.make_one_shot_iterator for graph construction". + format(type(self))) + with ops.device("/device:CPU:0"): + ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access + self._output_classes = dataset.output_classes + self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)) + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._output_shapes, self._output_classes)) + self._resource = gen_dataset_ops.iterator( + shared_name="", + container=_generate_shared_name("eageriterator"), + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) + # Delete the resource when this object is deleted + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device="/device:CPU:0") + self._device = context.context().device_name + + def __iter__(self): + return self + + def __next__(self): # For Python 3 compatibility + return self.next() + + def _next_internal(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + with ops.device(self._device): + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` + # because in eager mode this code will run synchronously on the calling + # thread. Therefore we do not need to make a defensive context switch + # to a background thread, and can achieve a small constant performance + # boost by invoking the iterator synchronously. + ret = gen_dataset_ops.iterator_get_next_sync( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) + + def next(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + try: + return self._next_internal() + except errors.OutOfRangeError: + raise StopIteration + + @property + def output_classes(self): + """Returns the class of each component of an element of this iterator. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of an element of this dataset. + """ + return self._output_classes + + @property + def output_shapes(self): + """Returns the shape of each component of an element of this iterator. + + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of an element of this dataset. + """ + return self._output_shapes + + @property + def output_types(self): + """Returns the type of each component of an element of this iterator. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of an element of this dataset. + """ + return self._output_types + + def get_next(self, name=None): + """Returns a nested structure of `tf.Tensor`s containing the next element. + + Args: + name: (Optional.) A name for the created operation. Currently unused. + + Returns: + A nested structure of `tf.Tensor` objects. + + Raises: + `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. + """ + del name + return self._next_internal() diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index fa7601741b11f018e9b53ed3b77a7561be50d3f4..fe033f5546498d57dd98289d2cda1a8bbb1c7822 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -17,11 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import convert +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util.tf_export import tf_export @@ -31,7 +34,7 @@ _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB @tf_export("data.TextLineDataset") -class TextLineDataset(Dataset): +class TextLineDataset(dataset_ops.Dataset): """A `Dataset` comprising lines from one or more text files.""" def __init__(self, filenames, compression_type=None, buffer_size=None): @@ -73,8 +76,7 @@ class TextLineDataset(Dataset): return dtypes.string -@tf_export("data.TFRecordDataset") -class TFRecordDataset(Dataset): +class _TFRecordDataset(dataset_ops.Dataset): """A `Dataset` comprising records from one or more TFRecord files.""" def __init__(self, filenames, compression_type=None, buffer_size=None): @@ -87,7 +89,7 @@ class TFRecordDataset(Dataset): buffer_size: (Optional.) A `tf.int64` scalar representing the number of bytes in the read buffer. 0 means no buffering. """ - super(TFRecordDataset, self).__init__() + super(_TFRecordDataset, self).__init__() # Force the type to string even if filenames is an empty list. self._filenames = ops.convert_to_tensor( filenames, dtypes.string, name="filenames") @@ -118,8 +120,112 @@ class TFRecordDataset(Dataset): return dtypes.string +class ParallelInterleaveDataset(dataset_ops.InterleaveDataset): + """A `Dataset` that maps a function over its input and flattens the result.""" + + def __init__(self, input_dataset, map_func, cycle_length, block_length, + sloppy, buffer_output_elements, prefetch_input_elements): + """See `tf.contrib.data.parallel_interleave()` for details.""" + super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func, + cycle_length, block_length) + self._sloppy = ops.convert_to_tensor( + sloppy, dtype=dtypes.bool, name="sloppy") + self._buffer_output_elements = convert.optional_param_to_tensor( + "buffer_output_elements", + buffer_output_elements, + argument_default=2 * block_length) + self._prefetch_input_elements = convert.optional_param_to_tensor( + "prefetch_input_elements", + prefetch_input_elements, + argument_default=2 * cycle_length) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_dataset_ops.parallel_interleave_dataset( + self._input_dataset._as_variant_tensor(), + self._map_func.captured_inputs, + self._cycle_length, + self._block_length, + self._sloppy, + self._buffer_output_elements, + self._prefetch_input_elements, + f=self._map_func, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # pylint: enable=protected-access + + +@tf_export("data.TFRecordDataset") +class TFRecordDataset(dataset_ops.Dataset): + """A `Dataset` comprising records from one or more TFRecord files.""" + + def __init__(self, filenames, compression_type=None, buffer_size=None, + num_parallel_reads=None): + """Creates a `TFRecordDataset` to read for one or more TFRecord files. + + NOTE: The `num_parallel_reads` argument can be used to improve performance + when reading from a remote filesystem. + + Args: + filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or + more filenames. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. + buffer_size: (Optional.) A `tf.int64` scalar representing the number of + bytes in the read buffer. 0 means no buffering. + num_parallel_reads: (Optional.) A `tf.int64` scalar representing the + number of files to read in parallel. Defaults to reading files + sequentially. + + Raises: + TypeError: If any argument does not have the expected type. + ValueError: If any argument does not have the expected shape. + """ + super(TFRecordDataset, self).__init__() + if isinstance(filenames, dataset_ops.Dataset): + if filenames.output_types != dtypes.string: + raise TypeError( + "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.") + if not filenames.output_shapes.is_compatible_with(tensor_shape.scalar()): + raise ValueError( + "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` " + "elements.") + else: + filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string) + filenames = array_ops.reshape(filenames, [-1], name="flat_filenames") + filenames = dataset_ops.Dataset.from_tensor_slices(filenames) + + def read_one_file(filename): + return _TFRecordDataset(filename, compression_type, buffer_size) + + if num_parallel_reads is None: + self._impl = filenames.flat_map(read_one_file) + else: + self._impl = ParallelInterleaveDataset( + filenames, read_one_file, cycle_length=num_parallel_reads, + block_length=1, sloppy=False, buffer_output_elements=None, + prefetch_input_elements=None) + + def _as_variant_tensor(self): + return self._impl._as_variant_tensor() # pylint: disable=protected-access + + @property + def output_classes(self): + return self._impl.output_classes + + @property + def output_shapes(self): + return self._impl.output_shapes + + @property + def output_types(self): + return self._impl.output_types + + @tf_export("data.FixedLengthRecordDataset") -class FixedLengthRecordDataset(Dataset): +class FixedLengthRecordDataset(dataset_ops.Dataset): """A `Dataset` of fixed-length records from one or more binary files.""" def __init__(self, diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index e32c7b54a48dd887c2748897c3ce3661aab9f497..b1bdbdab37b63667b475c732df7a47d9e57f2b19 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -86,6 +86,30 @@ py_test( ], ) +py_library( + name = "random_seed", + srcs = ["random_seed.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + ], +) + +py_test( + name = "random_seed_test", + size = "small", + srcs = ["random_seed_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":random_seed", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:util", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/data/util/random_seed.py b/tensorflow/python/data/util/random_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c9d8672f94587fd3164f25f97b44a97526be07 --- /dev/null +++ b/tensorflow/python/data/util/random_seed.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for generating Tensor-valued random seeds.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def get_seed(seed): + """Returns the local seeds an operation should use given an op-specific seed. + + See @{tf.get_seed} for more details. This wrapper adds support for the case + where `seed` may be a tensor. + + Args: + seed: An integer or a @{tf.int64} scalar tensor. + + Returns: + A tuple of two @{tf.int64} scalar tensors that should be used for the local + seed of the calling dataset. + """ + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + with ops.name_scope("seed2") as scope: + seed2 = ops.convert_to_tensor(seed2, dtype=dtypes.int64) + seed2 = array_ops.where( + math_ops.logical_and( + math_ops.equal(seed, 0), math_ops.equal(seed2, 0)), + constant_op.constant(2**31 - 1, dtype=dtypes.int64), + seed2, + name=scope) + return seed, seed2 diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33227e82afe6fe1c748693d107d4e9844abb8e09 --- /dev/null +++ b/tensorflow/python/data/util/random_seed_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities working with arbitrarily nested structures.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import random_seed as data_random_seed +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class RandomSeedTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testRandomSeed(self): + zero_t = constant_op.constant(0, dtype=dtypes.int64, name='zero') + one_t = constant_op.constant(1, dtype=dtypes.int64, name='one') + intmax_t = constant_op.constant( + 2**31 - 1, dtype=dtypes.int64, name='intmax') + test_cases = [ + # Each test case is a tuple with input to get_seed: + # (input_graph_seed, input_op_seed) + # and output from get_seed: + # (output_graph_seed, output_op_seed) + ((None, None), (0, 0)), + ((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)), + ((1, 1), (1, 1)), + ((0, 0), (0, 2**31 - 1)), # Avoid nondeterministic (0, 0) output + ((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either + ((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument + # Once more, with tensor-valued arguments + ((None, one_t), (random_seed.DEFAULT_GRAPH_SEED, 1)), + ((1, one_t), (1, 1)), + ((0, zero_t), (0, 2**31 - 1)), # Avoid nondeterministic (0, 0) output + ((2**31 - 1, zero_t), (0, 2**31 - 1)), # Don't wrap to (0, 0) either + ((0, intmax_t), (0, 2**31 - 1)), # Wrapping for the other argument + ] + for tc in test_cases: + tinput, toutput = tc[0], tc[1] + random_seed.set_random_seed(tinput[0]) + g_seed, op_seed = data_random_seed.get_seed(tinput[1]) + g_seed = self.evaluate(g_seed) + op_seed = self.evaluate(op_seed) + msg = 'test_case = {0}, got {1}, want {2}'.format( + tinput, (g_seed, op_seed), toutput) + self.assertEqual((g_seed, op_seed), toutput, msg=msg) + random_seed.set_random_seed(None) + + if not context.executing_eagerly(): + random_seed.set_random_seed(1) + tinput = (1, None) + toutput = (1, ops.get_default_graph()._last_id) # pylint: disable=protected-access + random_seed.set_random_seed(tinput[0]) + g_seed, op_seed = data_random_seed.get_seed(tinput[1]) + g_seed = self.evaluate(g_seed) + op_seed = self.evaluate(op_seed) + msg = 'test_case = {0}, got {1}, want {2}'.format(1, (g_seed, op_seed), + toutput) + self.assertEqual((g_seed, op_seed), toutput, msg=msg) + random_seed.set_random_seed(None) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 253588fc3b2986af3ab8c6be5b0b85f178c06336..512d292ee2ffa3e61cca0952c0d530c5ec9b3d2a 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -957,7 +957,7 @@ cuda_py_test( cuda_py_test( name = "session_debug_grpc_test", - size = "large", + size = "medium", srcs = ["lib/session_debug_grpc_test.py"], additional_deps = [ ":debug_data", @@ -967,7 +967,6 @@ cuda_py_test( ":grpc_wrapper", ":hooks", ":session_debug_testlib", - "//third_party/py/numpy", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -983,6 +982,29 @@ cuda_py_test( ], ) +cuda_py_test( + name = "grpc_large_data_test", + size = "medium", + srcs = ["lib/grpc_large_data_test.py"], + additional_deps = [ + ":dumping_wrapper", + ":grpc_debug_test_server", + ":grpc_wrapper", + ":session_debug_testlib", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], + tags = [ + "no_oss", # Test flaky due to port collisions. + "no_windows", + "oss_serial", + ], +) + # TODO(cais): Run the test in OSS, perhaps through a sh_test. cuda_py_test( name = "dist_session_debug_grpc_test", diff --git a/tensorflow/python/debug/README.md b/tensorflow/python/debug/README.md index a2273b050bb1ecd5a35938c3de57fb8562f1d26d..269bbb19bdb898d1d81d0b9c618a284a437e68b9 100644 --- a/tensorflow/python/debug/README.md +++ b/tensorflow/python/debug/README.md @@ -37,12 +37,18 @@ models: * Association of nodes and tensors in graphs with Python source lines * Profiling of models at the level of graph nodes and Python source lines. (Omitted internal-only feature) +* A [gRPC](https://grpc.io/)-based remote debugging protocol, which allows us to + build a browser-based graphical user interface (GUI) for TFDBG: the + [TensorBoard Debugger Plugin](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). ## How to use TFDBG? * For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/programmers_guide/debugger. +* For information on the web GUI of TFDBG (TensorBoard Debugger Plugin), see + [this README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). * For programmatic use of the API of TFDBG, see https://www.tensorflow.org/api_docs/python/tfdbg. + ## Related Publications * Cai, S., Breck E., Nielsen E., Salib M., Sculley D. (2016) TensorFlow Debugger: diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py index 156afdfd4c44f2f1a07ffdd1e68ad48bbbe31cba..9a47cd12b47b35d0a85cfc1a211fdfee7cfa25bc 100644 --- a/tensorflow/python/debug/cli/analyzer_cli.py +++ b/tensorflow/python/debug/cli/analyzer_cli.py @@ -185,6 +185,15 @@ class DebugAnalyzer(object): type=str, default="", help="List only Tensors passing the filter of the specified name") + ap.add_argument( + "-fenn", + "--filter_exclude_node_names", + dest="filter_exclude_node_names", + type=str, + default="", + help="When applying the tensor filter, exclude node with names " + "matching the regular expression. Applicable only if --tensor_filter " + "or -f is used.") ap.add_argument( "-n", "--node_name_filter", @@ -484,6 +493,10 @@ class DebugAnalyzer(object): Returns: Output text lines as a RichTextLines object. + + Raises: + ValueError: If `--filter_exclude_node_names` is used without `-f` or + `--tensor_filter` being used. """ # TODO(cais): Add annotations of substrings for dumped tensor names, to @@ -520,8 +533,15 @@ class DebugAnalyzer(object): _add_main_menu(output, node_name=None, enable_list_tensors=False) return output - data_to_show = self._debug_dump.find(filter_callable) + data_to_show = self._debug_dump.find( + filter_callable, + exclude_node_names=parsed.filter_exclude_node_names) else: + if parsed.filter_exclude_node_names: + raise ValueError( + "The flag --filter_exclude_node_names is valid only when " + "the flag -f or --tensor_filter is used.") + data_to_show = self._debug_dump.dumped_tensor_data # TODO(cais): Implement filter by lambda on tensor value. diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index 6b110fda9eba301f298e84b63d091bb300549bee..55231954d1c8ea987bbf87755dfde83d5efd03f0 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -820,6 +820,32 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): op_type_regex="(Add|MatMul)") check_main_menu(self, out, list_tensors_enabled=False) + def testListTensorWithFilterAndNodeNameExclusionWorks(self): + # First, create and register the filter. + def is_2x1_vector(datum, tensor): + del datum # Unused. + return list(tensor.shape) == [2, 1] + self._analyzer.add_tensor_filter("is_2x1_vector", is_2x1_vector) + + # Use shorthand alias for the command prefix. + out = self._registry.dispatch_command( + "lt", ["-f", "is_2x1_vector", "--filter_exclude_node_names", ".*v.*"]) + + # If the --filter_exclude_node_names were not used, then the matching + # tensors would be: + # - simple_mul_add/v:0 + # - simple_mul_add/v/read:0 + # - simple_mul_add/matmul:0 + # - simple_mul_add/add:0 + # + # With the --filter_exclude_node_names option, only the last two should + # show up in the result. + assert_listed_tensors( + self, + out, ["simple_mul_add/matmul:0", "simple_mul_add/add:0"], + ["MatMul", "Add"], tensor_filter_name="is_2x1_vector") + check_main_menu(self, out, list_tensors_enabled=False) + def testListTensorsFilterNanOrInf(self): """Test register and invoke a tensor filter.""" diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py index bb52f9051250625836b0d7a0f8e30265d9b34e92..f66cefb427c9ccfa0769655415193e8d2535e53c 100644 --- a/tensorflow/python/debug/cli/curses_ui.py +++ b/tensorflow/python/debug/cli/curses_ui.py @@ -1185,6 +1185,22 @@ class CursesUI(base_ui.BaseUI): self._main_menu = None self._main_menu_pad = None + def _pad_line_end_with_whitespace(self, pad, row, line_end_x): + """Pad the whitespace at the end of a line with the default color pair. + + Prevents spurious color pairs from appearing at the end of the lines in + certain text terimnals. + + Args: + pad: The curses pad object to operate on. + row: (`int`) row index. + line_end_x: (`int`) column index of the end of the line (beginning of + the whitespace). + """ + if line_end_x < self._max_x - 2: + pad.addstr(row, line_end_x, " " * (self._max_x - 3 - line_end_x), + self._default_color_pair) + def _screen_add_line_to_output_pad(self, pad, row, txt, color_segments=None): """Render a line in a text pad. @@ -1208,6 +1224,7 @@ class CursesUI(base_ui.BaseUI): if not color_segments: pad.addstr(row, 0, txt, self._default_color_pair) + self._pad_line_end_with_whitespace(pad, row, len(txt)) return if not isinstance(color_segments, list): @@ -1248,6 +1265,8 @@ class CursesUI(base_ui.BaseUI): for segment, color_pair in zip(all_segments, all_color_pairs): if segment[1] < self._max_x: pad.addstr(row, segment[0], txt[segment[0]:segment[1]], color_pair) + if all_segments: + self._pad_line_end_with_whitespace(pad, row, all_segments[-1][1]) def _screen_scroll_output_pad(self, pad, viewport_top, viewport_left, screen_location_top, screen_location_left, diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index 8d355aa27f6fa10a1889420a9087800be12a81ce..8a65ad087b3002d8ad93f3a64f48715d26ff62d8 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -23,6 +23,7 @@ import glob import json import os import platform +import re import numpy as np import six @@ -1411,7 +1412,11 @@ class DebugDumpDir(object): return self._watch_key_to_datum[device_name].get(debug_watch_key, []) - def find(self, predicate, first_n=0, device_name=None): + def find(self, + predicate, + first_n=0, + device_name=None, + exclude_node_names=None): """Find dumped tensor data by a certain predicate. Args: @@ -1430,17 +1435,24 @@ class DebugDumpDir(object): time order) for which the predicate returns True. To return all the `DebugTensotDatum` instances, let first_n be <= 0. device_name: optional device name. + exclude_node_names: Optional regular expression to exclude nodes with + names matching the regular expression. Returns: A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object for which predicate returns True, sorted in ascending order of the timestamp. """ + if exclude_node_names: + exclude_node_names = re.compile(exclude_node_names) matched_data = [] for device in (self._dump_tensor_data if device_name is None else (self._dump_tensor_data[device_name],)): for datum in self._dump_tensor_data[device]: + if exclude_node_names and exclude_node_names.match(datum.node_name): + continue + if predicate(datum, datum.get_tensor()): matched_data.append(datum) diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py index 16f51a4b32f711b97077643cec669bb8970e0b21..589a13db7f798aef3bb82dfbd442deabfbcf2a41 100644 --- a/tensorflow/python/debug/lib/debug_gradients.py +++ b/tensorflow/python/debug/lib/debug_gradients.py @@ -156,11 +156,12 @@ class GradientsDebugger(object): # TODO(cais): Implement value_stack. grad_debug_op_name = _tensor_to_grad_debug_op_name(input_tensor, self._uuid) # pylint: disable=protected-access - identity_op = (gen_array_ops._debug_gradient_ref_identity - if input_tensor.dtype._is_ref_dtype - else gen_array_ops._debug_gradient_identity) - debug_grad_identity = identity_op(input_tensor, name=grad_debug_op_name) + identity_op = ( + gen_array_ops.debug_gradient_ref_identity + if input_tensor.dtype._is_ref_dtype else + gen_array_ops.debug_gradient_identity) # pylint: enable=protected-access + debug_grad_identity = identity_op(input_tensor, name=grad_debug_op_name) assert debug_grad_identity.dtype == input_tensor.dtype if debug_grad_identity.op.name != grad_debug_op_name: raise ValueError( diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc477a9baeb7116530fc9122b926458c1a6c08e --- /dev/null +++ b/tensorflow/python/debug/lib/grpc_large_data_test.py @@ -0,0 +1,210 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sending large-size data through tfdbg grpc channels. + +"Large-size data" includes large GraphDef protos and large Tensor protos. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.debug.lib import grpc_debug_test_server +from tensorflow.python.debug.lib import session_debug_testlib +from tensorflow.python.debug.wrappers import framework +from tensorflow.python.debug.wrappers import grpc_wrapper +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): + + @classmethod + def setUpClass(cls): + (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, + cls.debug_server + ) = grpc_debug_test_server.start_server_on_separate_thread( + dump_to_filesystem=False) + tf_logging.info("debug server url: %s", cls.debug_server_url) + + @classmethod + def tearDownClass(cls): + cls.debug_server.stop_server().wait() + cls.debug_server_thread.join() + + def tearDown(self): + ops.reset_default_graph() + self.debug_server.clear_data() + + def testSendingLargeGraphDefsWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u = variables.Variable(42.0, name="original_u") + for _ in xrange(50 * 1000): + u = array_ops.identity(u) + sess.run(variables.global_variables_initializer()) + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"original_u") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + self.assertAllClose(42.0, sess.run(u)) + + self.assertAllClose( + [42.0], + self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"]) + self.assertEqual(2 if test.is_gpu_available() else 1, + len(self.debug_server.partition_graph_defs)) + max_graph_def_size = max([ + len(graph_def.SerializeToString()) + for graph_def in self.debug_server.partition_graph_defs]) + self.assertGreater(max_graph_def_size, 4 * 1024 * 1024) + + def testSendingLargeFloatTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init_val_array = list(xrange(1200 * 1024)) + # Size: 4 * 1200 * 1024 = 4800k > 4M + + u_init = constant_op.constant( + u_init_val_array, dtype=dtypes.float32, name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds # Unused by this watch_fn. + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + self.assertAllEqual( + u_init_val_array, + self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) + + def testSendingStringTensorWithAlmostTooLargeStringsWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init_val = [ + b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""] + u_init = constant_op.constant( + u_init_val, dtype=dtypes.string, name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + self.assertAllEqual( + u_init_val, + self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) + + def testSendingLargeStringTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + strs_total_size_threshold = 5000 * 1024 + cum_size = 0 + u_init_val_array = [] + while cum_size < strs_total_size_threshold: + strlen = np.random.randint(200) + u_init_val_array.append(b"A" * strlen) + cum_size += strlen + + u_init = constant_op.constant( + u_init_val_array, dtype=dtypes.string, name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + self.assertAllEqual( + u_init_val_array, + self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) + + def testSendingEmptyFloatTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init = constant_op.constant( + [], dtype=dtypes.float32, shape=[0], name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + u_init_value = self.debug_server.debug_tensor_values[ + "u_init:0:DebugIdentity"][0] + self.assertEqual(np.float32, u_init_value.dtype) + self.assertEqual(0, len(u_init_value)) + + def testSendingEmptyStringTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init = constant_op.constant( + [], dtype=dtypes.string, shape=[0], name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + u_init_value = self.debug_server.debug_tensor_values[ + "u_init:0:DebugIdentity"][0] + self.assertEqual(np.object, u_init_value.dtype) + self.assertEqual(0, len(u_init_value)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py index 1a6bedbbcbf94eb95e49d43e2d03c85b53bebb7b..ba0f15b4e2ff23295eae764088144a3d1b533f01 100644 --- a/tensorflow/python/debug/lib/session_debug_file_test.py +++ b/tensorflow/python/debug/lib/session_debug_file_test.py @@ -22,7 +22,6 @@ import shutil import tempfile from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_utils @@ -36,13 +35,6 @@ from tensorflow.python.platform import googletest class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase): - def _no_rewrite_session_config(self): - rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) - graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) - return config_pb2.ConfigProto(graph_options=graph_options) - def _debug_urls(self, run_number=None): return ["file://%s" % self._debug_dump_dir(run_number=run_number)] @@ -55,7 +47,8 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase): def testAllowsDifferentWatchesOnDifferentRuns(self): """Test watching different tensors on different runs of the same graph.""" - with session.Session(config=self._no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: u_init_val = [[5.0, 3.0], [-1.0, 0.0]] v_init_val = [[2.0], [-1.0]] diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index b623ee31c5dc59894373ec7952e53acd0f6e1126..ff49b6954776264ccb2eceeceab7da5a881081f0 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -24,11 +24,9 @@ from __future__ import print_function import os import shutil -import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_utils @@ -38,28 +36,15 @@ from tensorflow.python.debug.wrappers import framework from tensorflow.python.debug.wrappers import grpc_wrapper from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging from tensorflow.python.training import monitored_session -def no_rewrite_session_config(): - rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, - dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) - graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) - return config_pb2.ConfigProto(graph_options=graph_options) - - class GrpcDebugServerTest(test_util.TensorFlowTestCase): def testRepeatedRunServerRaisesException(self): @@ -142,19 +127,22 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): return os.path.join(self._dump_root, "run_%d" % run_number) def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self): - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) with self.assertRaisesRegexp( TypeError, "Expected type str or list in grpc_debug_server_addresses"): grpc_wrapper.GrpcDebugWrapperSession(sess, 1337) def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self): - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) with self.assertRaisesRegexp( TypeError, "Expected type str in list grpc_debug_server_addresses"): grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338]) def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self): - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) with self.assertRaises(TypeError): grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self._server_port, watch_fn="foo") @@ -164,7 +152,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -190,7 +179,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -223,7 +213,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -254,7 +245,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -298,7 +290,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(variables.global_variables_initializer()) grpc_debug_hook = hooks.TensorBoardDebugHook( @@ -324,168 +317,6 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): hooks.GrpcDebugHook(["foo:42424"]) -class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): - - @classmethod - def setUpClass(cls): - (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, - cls.debug_server - ) = grpc_debug_test_server.start_server_on_separate_thread( - dump_to_filesystem=False) - tf_logging.info("debug server url: %s", cls.debug_server_url) - - @classmethod - def tearDownClass(cls): - cls.debug_server.stop_server().wait() - cls.debug_server_thread.join() - - def tearDown(self): - ops.reset_default_graph() - self.debug_server.clear_data() - - def testSendingLargeGraphDefsWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u = variables.Variable(42.0, name="original_u") - for _ in xrange(50 * 1000): - u = array_ops.identity(u) - sess.run(variables.global_variables_initializer()) - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"original_u") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - self.assertAllClose(42.0, sess.run(u)) - - self.assertAllClose( - [42.0], - self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"]) - self.assertEqual(2 if test.is_gpu_available() else 1, - len(self.debug_server.partition_graph_defs)) - max_graph_def_size = max([ - len(graph_def.SerializeToString()) - for graph_def in self.debug_server.partition_graph_defs]) - self.assertGreater(max_graph_def_size, 4 * 1024 * 1024) - - def testSendingLargeFloatTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init_val_array = list(xrange(1200 * 1024)) - # Size: 4 * 1200 * 1024 = 4800k > 4M - - u_init = constant_op.constant( - u_init_val_array, dtype=dtypes.float32, name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds # Unused by this watch_fn. - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - self.assertAllEqual( - u_init_val_array, - self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) - - def testSendingStringTensorWithAlmostTooLargeStringsWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init_val = [ - b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""] - u_init = constant_op.constant( - u_init_val, dtype=dtypes.string, name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - self.assertAllEqual( - u_init_val, - self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) - - def testSendingLargeStringTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - strs_total_size_threshold = 5000 * 1024 - cum_size = 0 - u_init_val_array = [] - while cum_size < strs_total_size_threshold: - strlen = np.random.randint(200) - u_init_val_array.append(b"A" * strlen) - cum_size += strlen - - u_init = constant_op.constant( - u_init_val_array, dtype=dtypes.string, name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - self.assertAllEqual( - u_init_val_array, - self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) - - def testSendingEmptyFloatTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init = constant_op.constant( - [], dtype=dtypes.float32, shape=[0], name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - u_init_value = self.debug_server.debug_tensor_values[ - "u_init:0:DebugIdentity"][0] - self.assertEqual(np.float32, u_init_value.dtype) - self.assertEqual(0, len(u_init_value)) - - def testSendingEmptyStringTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init = constant_op.constant( - [], dtype=dtypes.string, shape=[0], name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - u_init_value = self.debug_server.debug_tensor_values[ - "u_init:0:DebugIdentity"][0] - self.assertEqual(np.object, u_init_value.dtype) - self.assertEqual(0, len(u_init_value)) - - class SessionDebugConcurrentTest( session_debug_testlib.DebugConcurrentRunCallsTest): @@ -548,7 +379,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self._server_2.clear_data() def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_1") delta_1 = constant_op.constant(5.0, name="delta_1") @@ -617,7 +449,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): ("toggled_2", 0, "DebugIdentity")]) self._servers_and_threads.append((server, server_thread)) - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_1") # These two nodes have names that match those in the @@ -656,7 +489,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self.assertEqual(0, len(server.debug_tensor_values)) def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v = variables.Variable(50.0, name="v") delta = constant_op.constant(5.0, name="delta") inc_v = state_ops.assign_add(v, delta, name="inc_v") @@ -698,7 +532,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self.assertEqual(0, len(self._server_2.debug_tensor_values)) def testToggleBreakpointsWorks(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_2") delta_1 = constant_op.constant(5.0, name="delta_1") @@ -755,7 +590,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self.assertSetEqual(set(), self._server_1.breakpoints) def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_2") delta_1 = constant_op.constant(5.0, name="delta_1") @@ -827,7 +663,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self._server_1.query_source_file_line(__file__, 1) def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_2") delta_1 = constant_op.constant(5.0, name="delta_1") diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index f4fac1401918ccacd38aae5ad2ef8d686c9204b9..070d9c4cd7094c81b18192e75885ae6dd6729cbf 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -669,6 +669,55 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): self.assertEqual(1, len(first_bad_datum)) self.assertEqual(x_name, first_bad_datum[0].node_name) + def testFindInfOrNanWithOpNameExclusion(self): + with session.Session() as sess: + u_name = "testFindInfOrNanWithOpNameExclusion/u" + v_name = "testFindInfOrNanWithOpNameExclusion/v" + w_name = "testFindInfOrNanWithOpNameExclusion/w" + x_name = "testFindInfOrNanWithOpNameExclusion/x" + y_name = "testFindInfOrNanWithOpNameExclusion/y" + z_name = "testFindInfOrNanWithOpNameExclusion/z" + + u_init = constant_op.constant([2.0, 4.0]) + u = variables.Variable(u_init, name=u_name) + v_init = constant_op.constant([2.0, 1.0]) + v = variables.Variable(v_init, name=v_name) + + # Expected output: [0.0, 3.0] + w = math_ops.subtract(u, v, name=w_name) + + # Expected output: [inf, 1.3333] + x = math_ops.div(u, w, name=x_name) + + # Expected output: [nan, 4.0] + y = math_ops.multiply(w, x, name=y_name) + + z = math_ops.multiply(y, y, name=z_name) + + u.initializer.run() + v.initializer.run() + + _, dump = self._debug_run_and_get_dump( + sess, z, + expected_partition_graph_count=self._expected_partition_graph_count) + + # Find all "offending tensors". + bad_data = dump.find(debug_data.has_inf_or_nan, + exclude_node_names=".*/x$") + + # Verify that the nodes with bad values are caught through running find + # on the debug dump. + self.assertEqual(2, len(bad_data)) + # Assert that the node `x` should have been excluded. + self.assertEqual(y_name, bad_data[0].node_name) + self.assertEqual(z_name, bad_data[1].node_name) + + first_bad_datum = dump.find( + debug_data.has_inf_or_nan, first_n=1, exclude_node_names=".*/x$") + + self.assertEqual(1, len(first_bad_datum)) + self.assertEqual(y_name, first_bad_datum[0].node_name) + def _session_run_for_graph_structure_lookup(self): with session.Session(config=no_rewrite_session_config()) as sess: u_name = "testDumpGraphStructureLookup/u" diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 1465cb72950c8fa6a453ebd4290bbf6382173ff8..c8625655e51a43a222addedd4beecdd3515d7fb6 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -115,6 +115,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): # unavailable (i.e., is None), the run-start CLI will be launched to ask # the user. This is the case, e.g., right before the first run starts. self._active_tensor_filter = None + self._active_filter_exclude_node_names = None self._active_tensor_filter_run_start_response = None self._run_through_times = 1 self._skip_debug = False @@ -148,6 +149,15 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): type=str, default="", help="Run until a tensor in the graph passes the specified filter.") + ap.add_argument( + "-fenn", + "--filter_exclude_node_names", + dest="filter_exclude_node_names", + type=str, + default="", + help="When applying the tensor filter, exclude node with names " + "matching the regular expression. Applicable only if --tensor_filter " + "or -f is used.") ap.add_argument( "--node_name_filter", dest="node_name_filter", @@ -324,9 +334,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): debug_dump.set_python_graph(self._sess.graph) passed_filter = None + passed_filter_exclude_node_names = None if self._active_tensor_filter: if not debug_dump.find( - self._tensor_filters[self._active_tensor_filter], first_n=1): + self._tensor_filters[self._active_tensor_filter], first_n=1, + exclude_node_names=self._active_filter_exclude_node_names): # No dumped tensor passes the filter in this run. Clean up the dump # directory and move on. self._remove_dump_root() @@ -334,10 +346,14 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): else: # Some dumped tensor(s) from this run passed the filter. passed_filter = self._active_tensor_filter + passed_filter_exclude_node_names = ( + self._active_filter_exclude_node_names) self._active_tensor_filter = None + self._active_filter_exclude_node_names = None self._prep_debug_cli_for_run_end( - debug_dump, request.tf_error, passed_filter) + debug_dump, request.tf_error, passed_filter, + passed_filter_exclude_node_names) self._run_start_response = self._launch_cli() @@ -358,7 +374,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if os.path.isdir(self._dump_root): shutil.rmtree(self._dump_root) - def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter): + def _prep_debug_cli_for_run_end(self, + debug_dump, + tf_error, + passed_filter, + passed_filter_exclude_node_names): """Prepare (but not launch) CLI for run-end, with debug dump from the run. Args: @@ -368,6 +388,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): (if any). passed_filter: (None or str) Name of the tensor filter that just passed and caused the preparation of this run-end CLI (if any). + passed_filter_exclude_node_names: (None or str) Regular expression used + with the tensor filter to exclude ops with names matching the regular + expresssion. """ if tf_error: @@ -383,6 +406,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if passed_filter is not None: # Some dumped tensor(s) from this run passed the filter. self._init_command = "lt -f %s" % passed_filter + if passed_filter_exclude_node_names: + self._init_command += (" --filter_exclude_node_names %s" % + passed_filter_exclude_node_names) self._title_color = "red_on_white" self._run_cli = analyzer_cli.create_analyzer_ui( @@ -496,6 +522,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): parsed.op_type_filter = parsed.op_type_filter or None parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None + if parsed.filter_exclude_node_names and not parsed.till_filter_pass: + raise ValueError( + "The --filter_exclude_node_names (or -feon) flag is valid only if " + "the --till_filter_pass (or -f) flag is used.") + if parsed.profile: raise debugger_cli_common.CommandLineExit( exit_token=framework.OnRunStartResponse( @@ -525,6 +556,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if parsed.till_filter_pass in self._tensor_filters: action = framework.OnRunStartAction.DEBUG_RUN self._active_tensor_filter = parsed.till_filter_pass + self._active_filter_exclude_node_names = ( + parsed.filter_exclude_node_names) self._active_tensor_filter_run_start_response = run_start_response else: # Handle invalid filter name. diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index 490812c96d83791cdc20c56f16c968f1a1851af8..b06fa26a935b42709575f8e400e0bda951ffbbc7 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -87,7 +87,11 @@ class LocalCLIDebuggerWrapperSessionForTest( def _prep_cli_for_run_start(self): pass - def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter): + def _prep_debug_cli_for_run_end(self, + debug_dump, + tf_error, + passed_filter, + passed_filter_exclude_op_names): self.observers["debug_dumps"].append(debug_dump) self.observers["tf_errors"].append(tf_error) @@ -451,6 +455,36 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"])) self.assertEqual([None, None], wrapped_sess.observers["tf_errors"]) + def testRunTillFilterPassesWithExcludeOpNames(self): + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run", "-f", "greater_than_twelve", + "--filter_exclude_node_names", "inc_v.*"], + ["run"], ["run"]], + self.sess, + dump_root=self._tmp_dir) + + def greater_than_twelve(datum, tensor): + del datum # Unused. + return tensor > 12.0 + + # Verify that adding the same tensor filter more than once is tolerated + # (i.e., as if it were added only once). + wrapped_sess.add_tensor_filter("greater_than_twelve", greater_than_twelve) + + # run five times. + wrapped_sess.run(self.inc_v) + wrapped_sess.run(self.inc_v) + wrapped_sess.run(self.inc_v) + wrapped_sess.run(self.inc_v) + + self.assertAllClose(14.0, self.sess.run(self.v)) + + self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"]) + + # Due to the --filter_exclude_op_names flag, the run-end CLI should show up + # not after run 3, but after run 4. + self.assertEqual([4], wrapped_sess.observers["run_end_cli_run_numbers"]) + def testRunTillFilterPassesWorksInConjunctionWithOtherNodeNameFilter(self): """Test that --.*_filter flags work in conjunction with -f. diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index ab81d40148476735492890f608315b19eaa0a33f..0e089a26eb88061ece54008a68c51de41b7b362b 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -42,7 +42,6 @@ py_library( ":backprop", ":context", ":core", - ":custom_gradient", ":execute", ":function", ":graph_callable", @@ -103,10 +102,10 @@ cuda_py_test( additional_deps = [ ":backprop", ":context", - ":custom_gradient", ":test", "//tensorflow/python:embedding_ops", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:resource_variable_ops", @@ -206,21 +205,6 @@ cc_library( ], ) -py_library( - name = "custom_gradient", - srcs = ["custom_gradient.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - ":context", - ":tape", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:util", - ], -) - py_library( name = "graph_only_ops", srcs = ["graph_only_ops.py"], @@ -364,7 +348,6 @@ py_test( deps = [ ":backprop", ":context", - ":custom_gradient", ":test", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 14bcc60006228eeaabea241ee18d960174a9dbea..c54a5a1445df73e16688e776eddd4edf9d026535 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import functools import operator import threading @@ -41,26 +40,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect - - -class _TensorCache(object): - """Simple cache which evicts items based on length in a FIFO manner.""" - - def __init__(self, max_items=256): - self._data = collections.OrderedDict() - self._max_items = max_items if max_items else 256 - - def put(self, key, value): - self._data[key] = value - - if len(self._data) > self._max_items: - self._data.popitem(last=False) - - def get(self, key): - return self._data.get(key, None) - - def flush(self): - self._data = {} +from tensorflow.python.util.tf_export import tf_export _op_attr_type_cache = {} @@ -106,6 +86,14 @@ class _MockOp(object): return make_attr(typ, self.attrs[i + 1]) raise KeyError(attr) + def _get_control_flow_context(self): + raise NotImplementedError( + "tf.GradientTape.gradients() does not support graph control flow " + "operations like tf.cond or tf.while at this time. Use tf.gradients() " + "instead. If you need this feature, please file a feature request at " + "https://github.com/tensorflow/tensorflow/issues/new" + ) + def _magic_gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads): @@ -183,8 +171,8 @@ def implicit_val_and_grad(f): """Returns a function which differentiates f with respect to variables. The wrapped function returns the value and the gradient of f when called with - the same arguments. The gradient is with respect to all TFE variables which - have `variable.watch()` called on them by f. + the same arguments. The gradient is with respect to all trainable TFE + variables accessed by `f`. This function is useful when the exact set of variables to differentiate with is not known ahead of time. @@ -261,8 +249,8 @@ def implicit_grad(f): """Returns a function which differentiates f with respect to variables. The wrapped function returns the gradient of f when called with the same - arguments. The gradient is with respect to all TFE variables which have - `variable.watch()` called on them by f. + arguments. The gradient is with respect to all trainable TFE variables + accessed by `f`. This function is useful when the exact set of variables to differentiate with is not known ahead of time. @@ -622,7 +610,7 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_zeros_cache = _TensorCache() +_zeros_cache = context._TensorCache() # pylint: disable=protected-access def _fast_fill(value, shape, dtype): @@ -658,64 +646,55 @@ _default_vspace = imperative_grad.VSpace( ones=_ones) +@tf_export("GradientTape") class GradientTape(object): - """Records operations to use to compute gradients. + """Record operations for automatic differentiation. - Operations are recorded if: - - they happen in code marked by this context manager - - at least one of their inputs is being watched + Operations are recorded if they are executed within this context manager and + at least one of their inputs is being "watched". - Outputs of recorded operations are watched. Variables are automatically - watched and tensors can be manually watched by calling the watch method on the - context manager. + Trainable variables (created by `tf.contrib.eager.Variable` or + @{tf.get_variable}, trainable=True is default in both cases) are automatically + watched. Tensors can be manually watched by invoking the `watch` method on + this context manager. - Example usage: + For example, consider the function `y = x * x`. The gradient at `x = 3.0` can + be computed as: ```python + x = tf.constant(3.) with tfe.GradientTape() as g: - x = tf.constant(3.0) g.watch(x) y = x * x - grad = g.gradient(y, [x])[0] - assert grad.numpy() == 6.0 + grad = g.gradient(y, [x])[0] # Will compute to 6.0 ``` - It is possible to use GradientTapes to compute higher-order derivatives as - follows: + GradientTapes can be nested to compute higher-order derivatives. For example, ```python + x = tf.constant(3.0) with tfe.GradientTape() as g: - x = tf.constant(3.0) - g.watch(x) - y = x * x with tfe.GradientTape() as gg: - gg.watch(y) - z = 2 * y - inner_grad = gg.gradient(z, [y])[0] - assert inner_grad.numpy() == 2 - y = y + inner_grad - grad = g.gradient(y, [x])[0] - assert grad.numpy() == 6.0 + gg.watch(x) + y = x * x + dy_dx = gg.gradient(y, [x])[0] # Will compute to 6.0 + d2y_dx2 = g.gradient(dy_dx, [x])[0] # Will compute to 2.0 ``` By default, the resources held by a GradientTape are released as soon as - GradientTape.gradient() method is called. However, if one need to compute - multiple gradients over the same computation, she can create a persistent - GradientTape. Persistent tapes allow multiple calls to the gradient() method - and release resources when the tape object is destructed. - - Example usage: + GradientTape.gradient() method is called. To compute multiple gradients over + the same computation, create a persistent gradient tape. This allows multiple + calls to the gradient() method as resources are released when the tape object + is garbage collected. For example: ```python + x = tf.constant(3.0) with tfe.GradientTape(persistent=True) as g: - x = tf.constant(3.0) g.watch(x) y = x * x z = y * y - dz_dx = g.gradient(z, [x])[0] - assert dz_dx.numpy() == 108.0 # 4*x^3 at x = 3 - dy_dx = g.gradient(y, [x])[0] - assert dy_dx.numpy() == 6.0 + dy_dx = g.gradient(z, [x])[0] # 6.0 + dz_dx = g.gradient(y, [x])[0] # 108.0 (4*x^3 at x = 3) del g # Drop the reference to the tape """ @@ -724,8 +703,8 @@ class GradientTape(object): Args: persistent: Boolean controlling whether a persistent gradient tape - is created. Must be True or False. - + is created. False by default, which means at most one call can + be made to the gradient() method on this object. """ self._tape = None self._persistent = persistent @@ -741,7 +720,7 @@ class GradientTape(object): """Ensures that `tensor` is being traced by this tape. Args: - tensor: a Tensor or Variable a list of Tensors or Variables. + tensor: a Tensor or list of Tensors. """ for t in nest.flatten(tensor): if isinstance(t, resource_variable_ops.ResourceVariable): @@ -756,14 +735,14 @@ class GradientTape(object): key=lambda v: v.handle._id)) # pylint: disable=protected-access def gradient(self, target, sources, output_gradients=None): - """Computes the gradient using information traced by the tape. + """Computes the gradient using operations recorded in context of this tape. Args: - target: the tensor to be differentiated. - sources: a list of Tensors or Variables, the target will be - differentiated with respect to the sources. + target: Tensor to be differentiated. + sources: a list of Tensors or Variables. `target` will be differentiated + against elements in `sources`. output_gradients: a list of gradients, one for each element of - target. Defaults to None. + target. Defaults to None. Returns: a list of Tensors (or IndexedSlices, or None), one for each element in @@ -771,7 +750,7 @@ class GradientTape(object): Raises: RuntimeError: if called inside the context of the tape, or if called more - than once. + than once on a non-persistent tape. """ if self._tape is None: raise RuntimeError("GradientTape.gradient can only be called once " diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 734558dee2b35813810341480eb38a4bace2936b..f04d89a6d976d1c1f71b385322032e74d42949b5 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import custom_gradient from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -32,6 +31,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops @@ -115,6 +116,19 @@ class BackpropTest(test.TestCase): with self.assertRaises(RuntimeError): backprop.gradients_function(f)(constant_op.constant(1.0)) + def testGradientsFunctionInCustomGradient(self): + + @custom_gradient.custom_gradient + def f(x): + (y,) = backprop.gradients_function(lambda x: x * x)(x) + + def grad(dy): + return [2 * dy] + + return y, grad + + self.assertAllEqual(f(1.0), 2.0) + def testImplicitGradOverEmbeddingLookup(self): batch_size = 8 embedding_size = 512 @@ -182,6 +196,19 @@ class BackpropTest(test.TestCase): g, = backprop.gradients_function(loss, [0])(logits, labels) self.assertAllEqual(g.numpy(), [[-0.5, 0.5]]) + @test_util.run_in_graph_and_eager_modes() + def testGradientWithinTapeBlock(self): + v1 = resource_variable_ops.ResourceVariable(1.) + self.evaluate(v1.initializer) + with backprop.GradientTape() as t: + loss = 2 * v1 + with self.assertRaises(RuntimeError): + t.gradient(loss, [v1]) + with backprop.GradientTape(persistent=True) as t: + loss = 2 * v1 + grad = t.gradient(loss, [v1]) + self.assertAllEqual(self.evaluate(grad[0]), 2.0) + @test_util.assert_no_new_tensors def testSecondGrad(self): @@ -343,6 +370,7 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) @test_util.assert_no_new_tensors + @test_util.run_in_graph_and_eager_modes() def testGradientTape(self): with backprop.GradientTape() as g: x = constant_op.constant(3.0) @@ -352,10 +380,53 @@ class BackpropTest(test.TestCase): gg.watch(y) z = 2 * y inner_grad = gg.gradient(z, [y])[0] - self.assertEqual(inner_grad.numpy(), 2.0) + self.assertEqual(self.evaluate(inner_grad), 2.0) y += inner_grad grad = g.gradient(y, [x])[0] - self.assertEqual(grad.numpy(), 6.0) + self.assertEqual(self.evaluate(grad), 6.0) + + @test_util.run_in_graph_and_eager_modes() + def testGradientTapeWithCond(self): + x = constant_op.constant(3.0) + + def true_fn(): + return x + + def false_fn(): + return x * x + + with backprop.GradientTape() as g: + g.watch(x) + y = control_flow_ops.cond(x < x, true_fn, false_fn) + + if not context.executing_eagerly(): + with self.assertRaisesRegexp(NotImplementedError, 'tf.gradients'): + dy = g.gradient(y, [x])[0] + else: + dy = g.gradient(y, [x])[0] + self.assertEqual(self.evaluate(dy), 6.0) + + @test_util.run_in_graph_and_eager_modes() + def testGradientTapeWithWhileLoop(self): + i = constant_op.constant(1) + x = constant_op.constant(2.) + + def cond(i, _): + return i < 3 + + def body(i, x): + return i + 1, x * 2 + + with backprop.GradientTape() as g: + g.watch([x]) + _, y = control_flow_ops.while_loop(cond, body, [i, x]) + + if not context.executing_eagerly(): + with self.assertRaisesRegexp(NotImplementedError, 'tf.gradients'): + dy = g.gradient(y, [x])[0] + else: + dy = g.gradient(y, [x])[0] + self.assertEqual(self.evaluate(dy), 4.0) @test_util.assert_no_new_tensors def testGradientTapeGradientCalledMultipleTimes(self): @@ -370,6 +441,7 @@ class BackpropTest(test.TestCase): g.gradient(y, [x]) @test_util.assert_no_new_tensors + @test_util.run_in_graph_and_eager_modes() def testPersistentTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -377,12 +449,13 @@ class BackpropTest(test.TestCase): y = x * x z = y * y dz_dx = g.gradient(z, [x])[0] - self.assertEqual(dz_dx.numpy(), 4*3*3*3) + self.assertEqual(self.evaluate(dz_dx), 4 * 3 * 3 * 3) dy_dx = g.gradient(y, [x])[0] - self.assertEqual(dy_dx.numpy(), 2*3) + self.assertEqual(self.evaluate(dy_dx), 2 * 3) del g @test_util.assert_no_new_tensors + @test_util.run_in_graph_and_eager_modes() def testPersistentNestedTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -393,22 +466,24 @@ class BackpropTest(test.TestCase): z = 2 * y for _ in range(2): inner_grad = gg.gradient(z, [y])[0] - self.assertEqual(inner_grad.numpy(), 2.0) + self.assertEqual(self.evaluate(inner_grad), 2.0) y += inner_grad del gg grad = g.gradient(y, [x])[0] - self.assertEqual(grad.numpy(), 6.0) + self.assertEqual(self.evaluate(grad), 6.0) grad = g.gradient(z, [x])[0] - self.assertEqual(grad.numpy(), 12.0) + self.assertEqual(self.evaluate(grad), 12.0) del g @test_util.assert_no_new_tensors + @test_util.run_in_graph_and_eager_modes() def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') + self.evaluate(v.initializer) with backprop.GradientTape() as g: y = v * v grad = g.gradient(y, [v])[0] - self.assertAllEqual(grad, 2.0) + self.assertAllEqual(self.evaluate(grad), 2.0) @test_util.assert_no_new_tensors def testEmptyParamsForValueAndGradFunction(self): diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index b56cbe80a7ab6b90d715187b0f0a44847038fc37..9ca5041c38ed07b39fd73b9f110ab06e8e903251 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -35,7 +35,6 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop # pylint: disable=unused-import from tensorflow.python.eager import context from tensorflow.python.eager import core -from tensorflow.python.eager import execute from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes @@ -56,11 +55,11 @@ def c_tfe_py_fastpath_execute(a, transpose_b=False, name=None): ctx = context.context() - assert not ctx.in_graph_mode( + assert ctx.executing_eagerly( ), "The prototype doesn't contain C code for graph construction" try: return pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, name, + ctx._handle, ctx.device_name, "MatMul", name, ctx._post_execution_callbacks, a, b, "transpose_a", transpose_a, "transpose_b", transpose_b) except core._NotOkStatusException as e: @@ -83,16 +82,24 @@ class MicroBenchmarks(test.Benchmark): self._num_iters_2_by_2 = 30000 self._num_iters_100_by_784 = 1000 - def _run(self, func, num_iters): + def _run(self, func, num_iters, execution_mode=None): # call func to maybe warm up the GPU - func() - start = time.time() - for _ in xrange(num_iters): + ctx = context.context() + with ctx.execution_mode(execution_mode): func() - end = time.time() - mean_us = (end - start) * 1e6 / num_iters - self.report_benchmark(iters=num_iters, wall_time=mean_us, - extras={"examples_per_sec": num_iters/(end-start)}) + if execution_mode == context.ASYNC: + ctx.async_wait() + start = time.time() + for _ in xrange(num_iters): + func() + if execution_mode == context.ASYNC: + ctx.async_wait() + end = time.time() + mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + iters=num_iters, + wall_time=mean_us, + extras={"examples_per_sec": num_iters / (end - start)}) def benchmark_create_np_array(self): func = lambda: np.array([3.0]) @@ -237,13 +244,15 @@ class MicroBenchmarks(test.Benchmark): func = lambda: np.dot(a, b) self._run(func, num_iters) - def _benchmark_tf_matmul(self, m, transpose_b, num_iters): + def _benchmark_tf_matmul(self, m, transpose_b, num_iters, + execution_mode=None): func = lambda: math_ops.matmul(m, m, transpose_b=transpose_b) - self._run(func, num_iters) + self._run(func, num_iters, execution_mode=execution_mode) def _benchmark_gen_math_ops_matmul(self, m, transpose_b, num_iters): def func(): - gen_math_ops._mat_mul(m, m, transpose_b=transpose_b) + gen_math_ops.mat_mul(m, m, transpose_b=transpose_b) + self._run(func, num_iters) def _benchmark_tfe_py_fastpath_execute_matmul(self, m, transpose_b, @@ -267,14 +276,28 @@ class MicroBenchmarks(test.Benchmark): self._run(func, num_iters) - def _benchmark_defun_matmul(self, m, transpose_b, num_iters): + def _benchmark_defun_matmul(self, + m, + transpose_b, + num_iters, + execution_mode=None): f = function.defun(math_ops.matmul) func = lambda: f(m, m, transpose_b) - self._run(func, num_iters) + self._run(func, num_iters, execution_mode=execution_mode) def _benchmark_read_variable(self, m, num_iters): self._run(m.value, num_iters) + def _benchmark_matmul_read_variable(self, m, num_iters): + self._benchmark_gen_math_ops_matmul( + m, transpose_b=False, num_iters=num_iters) + + def _benchmark_matmul_read_variable_with_tape(self, m, num_iters): + with backprop.GradientTape() as tape: + tape.watch(m) + self._benchmark_gen_math_ops_matmul( + m, transpose_b=False, num_iters=num_iters) + def _benchmark_read_variable_with_tape(self, m, num_iters): with backprop.GradientTape() as tape: tape.watch(m) @@ -291,6 +314,15 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_tf_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_tf_matmul_2_by_2_CPU_async(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_tf_matmul( + m, + transpose_b=False, + num_iters=self._num_iters_2_by_2, + execution_mode=context.ASYNC) + def benchmark_gen_math_ops_matmul_2_by_2_CPU(self): with context.device(CPU): m = self._m_2_by_2.cpu() @@ -315,6 +347,15 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_defun_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_defun_matmul_2_by_2_CPU_async(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_defun_matmul( + m, + transpose_b=False, + num_iters=self._num_iters_2_by_2, + execution_mode=context.ASYNC) + def benchmark_tf_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -323,6 +364,17 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_tf_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_tf_matmul_2_by_2_GPU_async(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = self._m_2_by_2.gpu() + self._benchmark_tf_matmul( + m, + transpose_b=False, + num_iters=self._num_iters_2_by_2, + execution_mode=context.ASYNC) + def benchmark_gen_math_ops_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -347,6 +399,17 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_defun_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_defun_matmul_2_by_2_GPU_async(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = self._m_2_by_2.gpu() + self._benchmark_defun_matmul( + m, + transpose_b=False, + num_iters=self._num_iters_2_by_2, + execution_mode=context.ASYNC) + # Benchmarks for AA.T, A of dimension 100 by 784. def benchmark_np_matmul_100_by_784(self): self._benchmark_np_matmul( @@ -360,6 +423,15 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_tf_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_tf_matmul_100_by_784_CPU_async(self): + with context.device(CPU): + m = self._m_100_by_784.cpu() + self._benchmark_tf_matmul( + m, + transpose_b=True, + num_iters=self._num_iters_100_by_784, + execution_mode=context.ASYNC) + def benchmark_gen_math_ops_matmul_100_by_784_CPU(self): with context.device(CPU): m = self._m_100_by_784.cpu() @@ -392,6 +464,17 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_tf_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_tf_matmul_100_by_784_GPU_async(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = self._m_100_by_784.gpu() + self._benchmark_tf_matmul( + m, + transpose_b=True, + num_iters=self._num_iters_100_by_784, + execution_mode=context.ASYNC) + def benchmark_gen_math_ops_matmul_100_by_784_GPU(self): if not context.num_gpus(): return @@ -416,6 +499,17 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_defun_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_matmul_read_variable_op_2_by_2_CPU(self): + with context.device(CPU): + m = resource_variable_ops.ResourceVariable(self._m_2_by_2) + self._benchmark_matmul_read_variable(m, num_iters=self._num_iters_2_by_2) + + def benchmark_matmul_read_variable_op_with_tape_2_by_2_CPU(self): + with context.device(CPU): + m = resource_variable_ops.ResourceVariable(self._m_2_by_2) + self._benchmark_matmul_read_variable_with_tape( + m, num_iters=self._num_iters_2_by_2) + def benchmark_read_variable_op_2_by_2_CPU(self): with context.device(CPU): m = resource_variable_ops.ResourceVariable(self._m_2_by_2) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 07652d3e02b6364e23b6579a64dcadf02dc5eb99..6c9a14730c0db4bdf23fc10b23d63b758349bdc1 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export GRAPH_MODE = 0 EAGER_MODE = 1 @@ -52,6 +53,28 @@ DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) +SYNC = 0 +ASYNC = 1 + + +class _TensorCache(object): + """Simple cache which evicts items based on length in a FIFO manner.""" + + def __init__(self, max_items=256): + self._data = collections.OrderedDict() + self._max_items = max_items if max_items else 256 + + def put(self, key, value): + self._data[key] = value + + if len(self._data) > self._max_items: + self._data.popitem(last=False) + + def get(self, key): + return self._data.get(key, None) + + def flush(self): + self._data = {} # TODO(agarwal): better name ? @@ -60,32 +83,43 @@ class _EagerContext(threading.local): def __init__(self): super(_EagerContext, self).__init__() - self.device_spec = pydev.DeviceSpec.from_string( - "/job:localhost/replica:0/task:0/device:CPU:0") + self.device_spec = pydev.DeviceSpec.from_string("") self.device_name = self.device_spec.to_string() self.mode = _default_mode self.scope_name = "" self.recording_summaries = False self.summary_writer_resource = None self.scalar_cache = {} + self.ones_rank_cache = _TensorCache() + self.execution_mode = None -ContextStackEntry = collections.namedtuple( - "ContextStackEntry", ["is_building_function", "enter_context_fn"]) +ContextSwitch = collections.namedtuple( + "ContextSwitch", ["is_building_function", "enter_context_fn"]) -class ContextStack(threading.local): +# `_ContextSwitchStack` is a `threading.local` to match the semantics of +# ``DefaultGraphStack`, which is also a `threading.local`. +class _ContextSwitchStack(threading.local): """A thread-local stack of context switches.""" - def __init__(self): - super(ContextStack, self).__init__() + def __init__(self, eager): + super(_ContextSwitchStack, self).__init__() self.stack = [] + if eager: + # Initialize the stack with a pointer to enter the eager context; this + # ensures that the fact that eager execution was enabled is propagated + # across threads, since (1) `enable_eager_execution` modifies a + # process-level flag (`_default_mode`) and (2) `__init__` is called each + # time a threading.local object is used in a separate thread. + self.push(is_building_function=False, enter_context_fn=eager_mode) def push(self, is_building_function, enter_context_fn): """Push metadata about a context switch onto the stack. A context switch can take one of two forms: installing a graph as the - default graph, or entering the eager context. + default graph, or entering the eager context. For each context switch, + we record whether or not the entered context is building a function. Args: is_building_function: (bool.) Whether the context is building a function. @@ -94,7 +128,7 @@ class ContextStack(threading.local): """ self.stack.append( - ContextStackEntry(is_building_function, enter_context_fn)) + ContextSwitch(is_building_function, enter_context_fn)) def pop(self): """Pop the stack.""" @@ -102,34 +136,49 @@ class ContextStack(threading.local): self.stack.pop() -context_stack = ContextStack() - - # TODO(agarwal): rename to EagerContext / EagerRuntime ? # TODO(agarwal): consider keeping the corresponding Graph here. class Context(object): """Environment in which eager operations execute.""" - def __init__(self, config=None, device_policy=None): + # TODO(agarwal): create and link in some documentation for `execution_mode`. + # pylint: disable=redefined-outer-name + def __init__(self, config=None, device_policy=None, execution_mode=None): """Creates a new Context. Args: config: (Optional.) A `ConfigProto` protocol buffer with configuration - options for the Context. Note that a lot of these options may be - currently unimplemented or irrelevant when eager execution is enabled. + options for the Context. Note that a lot of these options may be + currently unimplemented or irrelevant when eager execution is enabled. device_policy: (Optional.) What policy to use when trying to run an - operation on a device with inputs which are not on that device. - Valid values: - tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not - correct. - tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the + operation on a device with inputs which are not on that device. + When set to None, an appropriate value will be picked automatically. + The value picked may change between TensorFlow releases. + + Defaults to tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32. + Valid values: + - tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is + not correct. + - tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right device but raises a warning. - tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might + - tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide performance problems. - tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, + - tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, raising errors on the other ones. + execution_mode: (Optional.) Policy controlling how operations dispatched + are actually executed. When set to None, an appropriate value will be + picked automatically. The value picked may change between TensorFlow + releases. + Valid values: + - tf.contrib.eager.SYNC: executes each operation synchronously. + - tf.contrib.eager.ASYNC: executes each operation asynchronously. These + operations may return "non-ready" handles. + + Raises: + ValueError: If execution_mode is not valid. """ self._eager_context = _EagerContext() + self._context_switches = _ContextSwitchStack(self.executing_eagerly()) self._context_handle = None self._context_devices = None self._post_execution_callbacks = [] @@ -137,6 +186,14 @@ class Context(object): self._seed = None self._initialize_lock = threading.Lock() self._device_policy = device_policy + if execution_mode not in (None, SYNC, ASYNC): + raise ValueError( + "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode) + if execution_mode is None: + execution_mode = SYNC + self._execution_mode = execution_mode + + # pylint: enable=redefined-outer-name def _set_global_seed(self, seed): """Set a global eager mode seed for random ops.""" @@ -174,6 +231,8 @@ class Context(object): if self._device_policy is not None: pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( opts, self._device_policy) + if self._execution_mode == ASYNC: + pywrap_tensorflow.TFE_ContextOptionsSetAsync(True) self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) @@ -232,26 +291,29 @@ class Context(object): old_mode = ctx.mode ctx.mode = mode if mode == EAGER_MODE: - context_stack.push(False, eager_mode) + # Entering graph mode does not provide us with sufficient information to + # record a context switch; graph-based context switches are only logged + # when a graph is registered as the default graph. + self.context_switches.push(False, eager_mode) try: yield finally: ctx.mode = old_mode if mode == EAGER_MODE: - context_stack.pop() + self.context_switches.pop() - def in_graph_mode(self): - """Returns True if current thread is in GRAPH mode.""" - return self._eager_context.mode == GRAPH_MODE - - def in_eager_mode(self): - """Returns True if current thread is in EAGER mode.""" + def executing_eagerly(self): + """Returns True if current thread has eager executing enabled.""" return self._eager_context.mode == EAGER_MODE def scalar_cache(self): """Per-device cache for scalars.""" return self._eager_context.scalar_cache + def ones_rank_cache(self): + """Per-device cache for scalars.""" + return self._eager_context.ones_rank_cache + @property def scope_name(self): """Returns scope name for the current thread.""" @@ -335,6 +397,43 @@ class Context(object): """List of the names of devices available to execute operations.""" return self._devices + def get_execution_mode(self): + mode = self._eager_context.execution_mode + if mode is None: + mode = self._execution_mode + return mode + + def set_execution_mode(self, mode): + """Sets execution mode for current thread.""" + if mode not in (None, SYNC, ASYNC): + raise ValueError( + "Execution mode should be None/SYNC/ASYNC. Got %s" % mode) + if mode is None: + mode = SYNC + self._eager_context.execution_mode = mode + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, + mode == ASYNC, status) + + @tf_contextlib.contextmanager + def execution_mode(self, mode): + """Context manager for setting execution mode for current thread.""" + old_mode = self.get_execution_mode() + try: + self.set_execution_mode(mode) + yield + finally: + self.set_execution_mode(old_mode) + + def async_wait(self): + """Waits for ops dispatched in ASYNC mode to finish.""" + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextAsyncWait(self._handle, status) + + def async_clear_error(self): + """Clears errors raised during ASYNC execution.""" + pywrap_tensorflow.TFE_ContextAsyncClearError(self._handle) + def num_gpus(self): """The number of GPUs available to execute operations.""" self._initialize_handle_and_devices() @@ -457,6 +556,11 @@ class Context(object): run_metadata.ParseFromString(compat.as_bytes(proto_data)) return run_metadata + @property + def context_switches(self): + """Returns a stack of context switches.""" + return self._context_switches + _context = None _context_lock = threading.Lock() @@ -498,23 +602,29 @@ def internal_operation_seed(): return context()._internal_operation_seed() # pylint: disable=protected-access -def in_graph_mode(): - """Returns True if current thread is in GRAPH mode for default context.""" - return context().in_graph_mode() +@tf_export("executing_eagerly") +def executing_eagerly(): + """Returns True if the current thread has eager execution enabled. + + Eager execution is typically enabled via @{tf.enable_eager_execution}, + but may also be enabled within the context of a Python function via + tf.contrib.eager.py_func. + """ + return context().executing_eagerly() def in_eager_mode(): - """Returns True if current thread is in EAGER mode for default context.""" - return context().in_eager_mode() + """Use executing_eagerly() instead. This function will be removed.""" + return executing_eagerly() def graph_mode(): - """Context-manager to enable GRAPH mode for current thread.""" + """Context-manager to disable eager execution for the current thread.""" return context()._mode(GRAPH_MODE) # pylint: disable=protected-access def eager_mode(): - """Context-manager to enable EAGER mode for current thread.""" + """Context-manager to enable eager execution for the current thread.""" return context()._mode(EAGER_MODE) # pylint: disable=protected-access @@ -568,6 +678,26 @@ def list_devices(): return context().devices() +def set_execution_mode(mode): + """Sets execution mode for the current thread.""" + context().set_execution_mode(mode) + + +def execution_mode(mode): + """Context manager for setting execution mode for current thread.""" + return context().execution_mode(mode) + + +def async_wait(): + """Waits for ops dispatched in ASYNC mode to finish.""" + return context().async_wait() + + +def async_clear_error(): + """Clears errors raised during ASYNC execution mode.""" + return context().async_clear_error() + + def num_gpus(): """Get the number of available GPU devices. @@ -607,4 +737,8 @@ def export_run_metadata(): # (for example, enable_eager_execution in python/framework/ops.py), # but they do all import this file. Note that IS_IN_GRAPH_MODE and # in_graph_mode are both parameterless functions. -is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode +def _tmp_in_graph_mode(): + return not executing_eagerly() + + +is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index c68e2f422eb81a915d4f941ffb920f221d9be250..6ebf5b24819d48ba4a17d6059510eee7affe40ea 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -33,7 +33,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import resource_variable_ops def execute(op_name, num_outputs, inputs, attrs=None): @@ -54,19 +57,27 @@ class TFETest(test_util.TensorFlowTestCase): def testContext(self): ctx = context.Context() - self.assertFalse(ctx.in_graph_mode()) - self.assertTrue(ctx.in_eager_mode()) + self.assertTrue(ctx.executing_eagerly()) self.assertEqual('', ctx.scope_name) ctx.scope_name = 'foo' self.assertEqual('foo', ctx.scope_name) + self.assertEqual(context.SYNC, ctx.get_execution_mode()) + ctx.set_execution_mode(context.ASYNC) + self.assertEqual(context.ASYNC, ctx.get_execution_mode()) + ctx.set_execution_mode(context.SYNC) + self.assertEqual(context.SYNC, ctx.get_execution_mode()) + with ctx.execution_mode(context.ASYNC): + self.assertEqual(context.ASYNC, ctx.get_execution_mode()) + ctx.set_execution_mode(context.SYNC) + self.assertEqual(context.SYNC, ctx.get_execution_mode()) + self.assertIsNone(ctx.summary_writer_resource) ctx.summary_writer_resource = 'mock' self.assertEqual('mock', ctx.summary_writer_resource) - self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', - ctx.device_name) + self.assertEqual('', ctx.device_name) self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) with ctx.device('GPU:0'): self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0', @@ -100,19 +111,30 @@ class TFETest(test_util.TensorFlowTestCase): self.assertEqual(len(cpu_stats.node_stats), 1) self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') - def testContextStackContainsEagerMode(self): - # Eager execution has been enabled, and no other context - # switch has occurred, so `context_stack` should contain - # exactly one entry. - self.assertEqual(len(context.context_stack.stack), 1) - stack_entry = context.context_stack.stack[0] + def testShouldCopy(self): + if not context.context().num_gpus(): + self.skipTest('No devices other than CPUs found') + with ops.device('gpu:0'): + x = constant_op.constant(1.0) + y = array_ops.identity(x) + # The value we're testing y.device against will depend on what the behavior + # of not explicitly specifying a device in the context is. This behavior is + # subject to change (for example, in the future we may want to use GPUs, if + # available, when no device is explicitly provided) + self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0') + + def testContextSwitchStackContainsEagerMode(self): + # Eager execution has been enabled, and no other context switch has + # occurred, so `context_switches` should contain exactly one entry. + self.assertEqual(len(context.context().context_switches.stack), 1) + switch = context.context().context_switches.stack[0] # The entry should log that eager mode was entered. - self.assertIs(stack_entry.enter_context_fn, context.eager_mode) + self.assertIs(switch.enter_context_fn, context.eager_mode) # It is not possible to build a graph function when eager execution # is enabled; the stack entry should reflect this fact. - self.assertFalse(stack_entry.is_building_function) + self.assertFalse(switch.is_building_function) def testInt32GPU(self): if not context.context().num_gpus(): @@ -136,9 +158,9 @@ class TFETest(test_util.TensorFlowTestCase): def get_context_values(ctx): return [ - ctx.in_graph_mode(), - ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource, - ctx.device_name, ctx.num_gpus() + ctx.executing_eagerly(), ctx.scope_name, ctx.summary_writer_resource, + ctx.device_name, + ctx.num_gpus() ] def get_values(ctx, values): @@ -169,6 +191,18 @@ class TFETest(test_util.TensorFlowTestCase): attrs=('T', x.dtype.as_datatype_enum))[0].cpu().numpy() self.assertEqual(3, result) + def testResourceTensorPlacement(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found') + + with context.device('gpu:0'): + v = resource_variable_ops.ResourceVariable(1.0) + with context.device('cpu:0'): + # Check that even though we specified the cpu device we'll run the read op + # in the device where the handle is. + self.assertAllEqual( + gen_resource_variable_ops.read_variable_op(v.handle, v.dtype), 1.0) + def testCopyBetweenDevices(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') @@ -183,6 +217,23 @@ class TFETest(test_util.TensorFlowTestCase): with self.assertRaises(RuntimeError): x.gpu(context.context().num_gpus() + 1) + def testCopyBetweenDevicesAsync(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found') + with context.execution_mode(context.ASYNC): + x = constant_op.constant([[1., 2.], [3., 4.]]) + x = x.cpu() + x = x.gpu() + x = x.gpu() + x = x.cpu() + context.async_wait() + + # Invalid device + with self.assertRaises(RuntimeError): + x.gpu(context.context().num_gpus() + 1) + context.async_wait() + context.async_clear_error() + def testCopyScope(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') @@ -223,16 +274,49 @@ class TFETest(test_util.TensorFlowTestCase): attrs=('T', three.dtype.as_datatype_enum))[0] self.assertAllEqual(15, product) + def testExecuteBasicAsync(self): + with context.execution_mode(context.ASYNC): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = execute( + b'Mul', + num_outputs=1, + inputs=[three, five], + attrs=('T', three.dtype.as_datatype_enum))[0] + self.assertAllEqual(15, product) + # Error: Invalid arguments + context.set_execution_mode(context.ASYNC) + with self.assertRaises(errors.InvalidArgumentError): + execute( + b'MatMul', + num_outputs=1, + inputs=[three, five], + attrs=('transpose_a', False, 'transpose_b', False, 'T', + three.dtype.as_datatype_enum)) + context.async_wait() + context.async_clear_error() + context.set_execution_mode(context.SYNC) + def testExecuteTooManyNumOutputs(self): # num_outputs provided is 50, but only one output is produced. - # That should be okay. product = execute( b'Mul', num_outputs=50, - inputs=[constant_op.constant(3), constant_op.constant(5)], + inputs=[constant_op.constant(3), + constant_op.constant(5)], attrs=('T', dtypes.int32.as_datatype_enum))[0] self.assertAllEqual(15, product) + def testExecuteTooFewNumOutputs(self): + # num_outputs provided is 0, but one output is produced. + with self.assertRaises(errors.InvalidArgumentError): + _ = execute( + b'Mul', + num_outputs=0, + inputs=[constant_op.constant(3), + constant_op.constant(5)], + attrs=('T', dtypes.int32.as_datatype_enum))[0] + def testMatMulGPU(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') @@ -520,5 +604,61 @@ class TFETest(test_util.TensorFlowTestCase): self.assertIsInstance(t, ops.EagerTensor) +class SendRecvTest(test_util.TensorFlowTestCase): + + cpu_device = '/job:localhost/replica:0/task:0/device:CPU:0' + + def _send(self, tensor, tensor_name, to_device): + return execute( + b'_Send', num_outputs=0, inputs=[tensor], + attrs=('T', tensor.dtype.as_datatype_enum, + 'tensor_name', tensor_name, + 'send_device', tensor.device, + 'send_device_incarnation', 0, + 'recv_device', to_device, + 'client_terminated', True)) + + def _recv(self, dtype, tensor_name, from_device): + device_name = context.context().device_name + if not device_name: + device_name = self.cpu_device + return execute( + b'_Recv', num_outputs=1, inputs=[], + attrs=('tensor_type', dtype.as_datatype_enum, + 'tensor_name', tensor_name, + 'send_device', from_device, + 'send_device_incarnation', 0, + 'recv_device', device_name, + 'client_terminated', False))[0] + + def testBasic(self): + t0 = constant_op.constant(1.0) + t1 = constant_op.constant(2.0) + self._send(t0, 't0', self.cpu_device) + self._send(t1, 't1', self.cpu_device) + self.assertAllEqual( + self._recv(dtypes.float32, 't0', self.cpu_device), + 1.0) + self.assertAllEqual( + self._recv(dtypes.float32, 't1', self.cpu_device), + 2.0) + + def testLocalCrossDevice(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found') + gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0' + with ops.device('GPU:0'): + t0 = constant_op.constant(1.0) + self._send(t0, 't0', self.cpu_device) + self.assertAllEqual( + self._recv(dtypes.float32, 't0', gpu_device_name), + 1.0) + self._send(constant_op.constant(2.0), 't1', gpu_device_name) + with ops.device('GPU:0'): + self.assertAllEqual( + self._recv(dtypes.float32, 't1', self.cpu_device), + 2.0) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py deleted file mode 100644 index 05460ff9968312528d87f5fc2ad0495b4da2ad1a..0000000000000000000000000000000000000000 --- a/tensorflow/python/eager/custom_gradient.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Decorator to overrides the gradient for a function.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.util import nest -from tensorflow.python.util import tf_decorator - - -def custom_gradient(f): - """Decorator to define a function with a custom gradient. - - The input function is expected to return the tuple - (results, gradient_function). - - The output function will return results while possibly recording the - gradient_function and inputs in the tape. - - Args: - f: function to be decorated. - - Returns: - decorated function. - """ - - def decorated(*args, **kwargs): - """Decorated function with custom gradient.""" - if context.in_graph_mode(): - if kwargs: - raise ValueError( - "custom_gradient in graph mode doesn't support keyword arguments.") - name = "CustomGradient-%s" % tf_ops.uid() - args = [tf_ops.convert_to_tensor(x) for x in args] - result, grad_fn = f(*args) - flat_result = nest.flatten(result) - all_tensors = flat_result + args - - @tf_ops.RegisterGradient(name) - def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable - gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)])) - # Need to return one value per input to the IdentityN, so pad the - # gradients of the inputs of the custom_gradient function with the - # gradients of the outputs as well. - return ([None] * len(flat_result)) + gradients - - with tf_ops.get_default_graph().gradient_override_map( - {"IdentityN": name}): - all_tensors = array_ops.identity_n(all_tensors) - return nest.pack_sequence_as( - structure=result, flat_sequence=all_tensors[:len(flat_result)]) - - input_tensors = [tf_ops.convert_to_tensor(x) for x in args] - - with tape.stop_recording(): - result, grad_fn = f(*args, **kwargs) - flat_result = nest.flatten(result) - # TODO(apassos) consider removing the identity below. - flat_result = [gen_array_ops.identity(x) for x in flat_result] - - def actual_grad_fn(*outputs): - return nest.flatten(grad_fn(*outputs)) - - tape.record_operation( - f.__name__, - flat_result, - input_tensors, - actual_grad_fn) - flat_result = list(flat_result) - return nest.pack_sequence_as(result, flat_result) - - return tf_decorator.make_decorator(f, decorated) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index b3317bd3235f432220d9d5d135f1af18a6f43310..343012e552592a6f8bb1255118add3e938aa443c 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import compat @@ -111,7 +112,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): """ del as_ref # Unused. - if context.in_eager_mode(): + if context.executing_eagerly(): return value default_graph = ops.get_default_graph() @@ -162,31 +163,15 @@ class CapturingGraph(ops.Graph): op_def=None, compute_shapes=True, compute_device=True): - # TODO(apassos) probably control flow has to be handled delicately here as - # in if a resource is accessed inside a control flow context we need the - # control dependency to point to something outside the context which is - # guaranteed to happen after the access. - # # TODO(apassos) this should do some form of alias analysis as ops which # forward the resources such as Identity and Switch can cause serialization # to fail. - resource_inputs = set() - control_inputs = set() for i, inp in enumerate(inputs): if inp.graph is not self: inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name) - inp = inputs[i] - if inp.dtype == dtypes_module.resource: - if inp.name in self._last_op_using_resource_tensor: - control_inputs.add(self._last_op_using_resource_tensor[inp.name]) - resource_inputs.add(inp.name) - with self.control_dependencies(list(control_inputs)): - op = super(CapturingGraph, self).create_op( - op_type, inputs, dtypes, input_types, name, attrs, op_def, - compute_shapes, compute_device) - for name in resource_inputs: - self._last_op_using_resource_tensor[name] = op - return op + return super(CapturingGraph, self).create_op( + op_type, inputs, dtypes, input_types, name, attrs, op_def, + compute_shapes, compute_device) # TODO(apassos): it'd be really nice if we could scope this registration. @@ -310,7 +295,7 @@ class _EagerDefinedFunction(object): proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) function_def = function_pb2.FunctionDef() function_def.ParseFromString(compat.as_bytes(proto_data)) - if context.in_eager_mode(): + if context.executing_eagerly(): _register(fn) self.definition = function_def self.name = function_def.signature.name @@ -453,7 +438,14 @@ class GraphModeFunction(object): all_args = args + self._extra_inputs signature = self._forward_fdef.signature ctx = context.context() - if ctx.in_graph_mode(): + if ctx.executing_eagerly(): + outputs = execute.execute( + str(signature.name), + num_outputs=len(signature.output_arg), + inputs=all_args, + attrs=None, + ctx=ctx) + else: g = ops.get_default_graph() g._add_function(self._forward_fdef) # pylint: disable=protected-access op = g.create_op( @@ -468,13 +460,6 @@ class GraphModeFunction(object): outputs, (ops.Tensor, type(None))) else list(outputs) for i, s in enumerate(self._output_shapes): outputs[i].set_shape(s) - else: - outputs = execute.execute( - str(signature.name), - num_outputs=len(signature.output_arg), - inputs=all_args, - attrs=None, - ctx=ctx) real_outputs = outputs[:len(self._returns)] side_outputs = outputs[len(self._returns):] @@ -545,7 +530,14 @@ class GraphModeFunction(object): return self._backprop_call(tensor_inputs) ctx = context.context() - if ctx.in_graph_mode(): + if ctx.executing_eagerly(): + result = execute.execute( + str(self._func_name), + num_outputs=self._num_outputs, + inputs=tensor_inputs + self._extra_inputs, + attrs=None, + ctx=ctx) + else: g = ops.get_default_graph() self.add_to_graph(g) signature = self._function_def.definition.signature @@ -562,13 +554,6 @@ class GraphModeFunction(object): return op for i, s in enumerate(self._output_shapes): result[i].set_shape(s) - else: - result = execute.execute( - str(self._func_name), - num_outputs=self._num_outputs, - inputs=tensor_inputs + self._extra_inputs, - attrs=None, - ctx=ctx) return self._build_call_outputs(result) @@ -636,13 +621,15 @@ def _defun_internal(name, func, args, kwds): for collection in curr_graph.collections: tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( collection) - with tmp_graph.as_default(): + with tmp_graph.as_default(), AutomaticControlDependencies() as a: func_inputs = _get_defun_inputs(args) def convert(x): if x is None: return None - return ops.convert_to_tensor_or_indexed_slices(x) + x = ops.convert_to_tensor_or_indexed_slices(x) + x = a.mark_as_return(x) + return x with capture_tensors(captures): this_tape = tape.push_new_tape() @@ -679,7 +666,7 @@ def _defun_internal(name, func, args, kwds): if x not in all_ignored_ops) # Register any other functions defined in the graph # TODO(ashankar): Oh lord, forgive me for this lint travesty. - if context.in_eager_mode(): + if context.executing_eagerly(): for f in tmp_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? _register(f._c_func) # pylint: disable=protected-access @@ -887,10 +874,39 @@ class AutomaticControlDependencies(object): self._returned_tensors = set() def mark_as_return(self, tensor): + """Acts like identity but marks the `Tensor` as a return value. + + This will possibly return a copy of the `Tensor`. Usage: + + ``` + with AutomaticControlDependencies() as a: + ... + t = a.mark_as_return(t) + _ = ...(t...) # i.e. it's safe to use t here + ``` + + Args: + tensor: the `Tensor` to be marked + + Returns: + a copy of the `Tensor`. + """ + if isinstance(tensor, ops.IndexedSlices): + values = array_ops.identity(tensor.values) + indices = array_ops.identity(tensor.indices) + self._returned_tensors.add(indices) + self._returned_tensors.add(values) + return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape) + # We want to make the return values depend on the stateful operations, but + # we don't want to introduce a cycle, so we make the return value the result + # of a new identity operation that the stateful operations definitely don't + # depend on. + tensor = array_ops.identity(tensor) self._returned_tensors.add(tensor) + return tensor def __enter__(self): - if context.in_eager_mode(): + if context.executing_eagerly(): return self # This code assumes no other thread is adding ops to the graph while # we're adding ops to the graph. @@ -961,7 +977,7 @@ class AutomaticControlDependencies(object): merge_for_resource[o] = new_merge[0].op def __exit__(self, unused_type, unused_value, unused_traceback): - if context.in_eager_mode(): + if context.executing_eagerly(): return if self._graph is not ops.get_default_graph(): @@ -1008,7 +1024,8 @@ class AutomaticControlDependencies(object): for op in new_operations: control_inputs = set() # Ensure stateful ops run - if self._graph._registered_ops[op.type].is_stateful: # pylint: disable=protected-access + if (op.type not in self._graph._registered_ops # pylint: disable=protected-access + or self._graph._registered_ops[op.type].is_stateful): # pylint: disable=protected-access ops_which_must_run.add(op) # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: @@ -1044,9 +1061,10 @@ class AutomaticControlDependencies(object): # Ensure all ops which must run do run for r in self._returned_tensors: - r.op._add_control_inputs( # pylint: disable=protected-access - [o for o in ops_which_must_run - if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access + if ops_which_must_run: + r.op._add_control_inputs( # pylint: disable=protected-access + [o for o in ops_which_must_run + if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access def automatic_control_dependencies(f): @@ -1066,8 +1084,7 @@ def automatic_control_dependencies(f): def wrapper(*args, **kwds): with AutomaticControlDependencies() as a: result = f(*args, **kwds) - for t in nest.flatten(result): - a.mark_as_return(t) - return result + result_flat = [a.mark_as_return(t) for t in nest.flatten(result)] + return nest.pack_sequence_as(result, result_flat) return tf_decorator.make_decorator(f, wrapper) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 431d9388c0ee97eda197142ec97b9448d985b04b..fd1d2c25ffe50cb7afcae29b3d0b15635b6a57dd 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import gradient_descent class FunctionTest(test.TestCase): @@ -606,7 +607,7 @@ class AutomaticControlDependenciesTest(test.TestCase): v.assign(v + 1) v.assign(2 * v) val = v.read_value() - c.mark_as_return(val) + val = c.mark_as_return(val) self.assertAllEqual(val.eval(), 4.0) def testCondMustRun(self): @@ -626,7 +627,7 @@ class AutomaticControlDependenciesTest(test.TestCase): control_flow_ops.cond(p, true_fn, false_fn) val = v.read_value() - c.mark_as_return(val) + val = c.mark_as_return(val) self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) @@ -647,7 +648,7 @@ class AutomaticControlDependenciesTest(test.TestCase): control_flow_ops.cond(p, true_fn, false_fn) one = constant_op.constant(1.0) - c.mark_as_return(one) + one = c.mark_as_return(one) one.eval(feed_dict={p: False}) self.assertAllEqual(v.read_value().eval(), 5.0) one.eval(feed_dict={p: True}) @@ -681,7 +682,7 @@ class AutomaticControlDependenciesTest(test.TestCase): control_flow_ops.cond(p, true_fn, false_fn) with ops.name_scope('final'): val = v.read_value() - c.mark_as_return(val) + val = c.mark_as_return(val) self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0) self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0) self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0) @@ -703,7 +704,7 @@ class AutomaticControlDependenciesTest(test.TestCase): control_flow_ops.cond(p, true_fn, false_fn) val = v.read_value() - c.mark_as_return(val) + val = c.mark_as_return(val) self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) @@ -724,7 +725,7 @@ class AutomaticControlDependenciesTest(test.TestCase): control_flow_ops.cond(p, true_fn, false_fn) val = v.read_value() - c.mark_as_return(val) + val = c.mark_as_return(val) self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0) self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) @@ -745,7 +746,7 @@ class AutomaticControlDependenciesTest(test.TestCase): control_flow_ops.cond(p, true_fn, false_fn) v.assign(v * 2) val = v.read_value() - c.mark_as_return(val) + val = c.mark_as_return(val) self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0) self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0) @@ -762,6 +763,37 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(f().eval(), 4.0) + def testOptimizerInDefun(self): + def loss(v): + return v**2 + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + @function.defun + def train(): + v = resource_variable_ops.ResourceVariable(1.0) + grad = backprop.implicit_grad(loss)(v) + optimizer.apply_gradients(grad) + return v.read_value() + + value = train() + self.assertEqual(value.numpy(), -1.0) + + def testOptimizerInDefunWithCapturedVariable(self): + v = resource_variable_ops.ResourceVariable(1.0) + def loss(): + return v**2 + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + @function.defun + def train(): + grad = backprop.implicit_grad(loss)() + optimizer.apply_gradients(grad) + + train() + self.assertEqual(v.numpy(), -1.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 62106bf0e2809e3c056e4a357f3d05251b7dca68..ee5d87f0835a8e70e0ce14537a51ea5418db41b9 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -279,9 +279,12 @@ def _graph_callable_internal(func, shape_and_dtypes): # scope's view of which variables exist. variable_captures = _VariableCapturingScope() with variable_captures.initializing_scope(), function.capture_tensors( - captures): + captures), function.AutomaticControlDependencies() as a: func_outputs = func(*func_inputs) - outputs_list = nest.flatten(func_outputs) + outputs_list = nest.flatten(func_outputs) + for i, x in enumerate(outputs_list): + if x is not None: + outputs_list[i] = a.mark_as_return(x) if len(outputs_list) == 1 and outputs_list[0] is None: outputs_list = [] output_shapes = [x.shape for x in outputs_list] @@ -294,9 +297,12 @@ def _graph_callable_internal(func, shape_and_dtypes): # knows about all variables. tmp_graph.clear_resource_control_flow_state() with variable_captures.capturing_scope(), function.capture_tensors( - captures): + captures), function.AutomaticControlDependencies() as a: captured_outputs = func(*func_inputs) captured_outlist = nest.flatten(captured_outputs) + for i, x in enumerate(captured_outlist): + if x is not None: + captured_outlist[i] = a.mark_as_return(x) capturing_operations = tmp_graph.get_operations()[ len(initializing_operations):] @@ -400,7 +406,7 @@ def graph_callable(shape_and_dtypes): A callable graph object. """ # TODO(alive,apassos): support initialized_value and friends from tf.Variable. - assert context.in_eager_mode(), ( + assert context.executing_eagerly(), ( "graph_callable can only be used when Eager execution is enabled.") def decorator(func): return tf_decorator.make_decorator(func, diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index f2e70341d975fb06bce7f2ce6cba7d8c3bc9826c..fc76ede4c502ae8b554c925a921e419bf003c40c 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -17,8 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading import numpy as np +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import test @@ -130,8 +132,12 @@ class OpsTest(test_util.TensorFlowTestCase): dtype=dtypes.int64) values = constant_op.constant([2, 3, 5, 7, 11]) shape = constant_op.constant([2, 7], dtype=dtypes.int64) - result = sparse_ops.gen_sparse_ops._sparse_split( # pylint: disable=protected-access - split_dim, indices, values, shape, num_split=2) + result = sparse_ops.gen_sparse_ops.sparse_split( + split_dim, + indices, + values, + shape, + num_split=2) output_indices, output_values, output_shape = result self.assertEqual(2, len(output_indices)) self.assertEqual(2, len(output_values)) @@ -277,6 +283,25 @@ class OpsTest(test_util.TensorFlowTestCase): context._context = context.Context() # pylint: enable=protected-access + def testSoftPlacement(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found') + # Temporarily replace the context + # pylint: disable=protected-access + del context._context + try: + context._context = context.Context( + device_policy=context.DEVICE_PLACEMENT_SILENT, + config=config_pb2.ConfigProto(allow_soft_placement=True)) + cpu_tensor = constant_op.constant(1.0) + result = cpu_tensor + cpu_tensor + self.assertEqual(result.device, + '/job:localhost/replica:0/task:0/device:GPU:0') + finally: + del context._context + context._context = context.Context() + # pylint: enable=protected-access + def testRandomUniform(self): scalar_shape = constant_op.constant([], dtype=dtypes.int32) @@ -352,6 +377,22 @@ class OpsTest(test_util.TensorFlowTestCase): def testNoOpIsNone(self): self.assertTrue(control_flow_ops.no_op() is None) + def testEagerContextPreservedAcrossThreads(self): + def init_fn(): + self.assertTrue(context.executing_eagerly()) + with ops.init_scope(): + self.assertTrue(context.executing_eagerly()) + context_switches = context.context().context_switches + self.assertEqual(len(context_switches.stack), 1) + self.assertFalse(context_switches.stack[0].is_building_function) + self.assertEqual(context_switches.stack[0].enter_context_fn, + context.eager_mode) + + self.assertTrue(context.executing_eagerly()) + t1 = threading.Thread(target=init_fn) + t1.start() + t1.join() + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index e6d03297e0b85856ff165af310149c79e494ab36..c2ce8efd7f70c6ba93b6d444f88ddbb9aa51ccdb 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -367,7 +367,7 @@ void GenEagerPythonOp::HandleGraphMode(const string& function_setup) { // Handle graph-mode case strings::StrAppend(&result_, " _ctx = _context.context()\n" - " if _ctx.in_graph_mode():\n", + " if not _ctx.executing_eagerly():\n", function_setup, " _, _, _op = _op_def_lib._apply_op_helper(\n"); AddBodyNoReturn(" "); @@ -712,9 +712,9 @@ bool GenEagerPythonOp::AddEagerFallbackCode( } void GenEagerPythonOp::AddEagerFastPathExecute() { - string fastpath_execute_params = strings::StrCat( - "_ctx._handle, _ctx.device_name, \"", op_def_.name(), "\", ", - "_execute.record_gradient, name, _ctx._post_execution_callbacks"); + string fastpath_execute_params = + strings::StrCat("_ctx._handle, _ctx.device_name, \"", op_def_.name(), + "\", ", "name, _ctx._post_execution_callbacks"); string fallback_params; for (int i = 0; i < api_def_.in_arg_size(); i++) { @@ -955,10 +955,10 @@ from tensorflow.python.util.tf_export import tf_export if (api_def->visibility() == ApiDef::SKIP) { continue; } - // An op is hidden if either its ApiDef visibility is HIDDEN // or it is in the hidden_ops list. bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + bool hidden_by_api_def = is_hidden; if (!is_hidden) { for (const string& hidden : hidden_ops) { if (op_def.name() == hidden) { @@ -971,13 +971,22 @@ from tensorflow.python.util.tf_export import tf_export string function_name; python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), &function_name); - if (is_hidden) function_name = strings::StrCat("_", function_name); - - // When users create custom python wrappers, they may link in the - // default op registry by accident, and because they can't - // enumerate all 'hidden' symbols, this guard is to prevent - // instantiating a python reserved word in their wrapper. - if (python_op_gen_internal::IsPythonReserved(function_name)) { + bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name); + + // Prefix an op with underscore if the op is listed in hidden_ops or + // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix. + // Do not add underscores to ops set to HIDDEN in ApiDef otherwise. + // TODO(annarev): don't prefix with underscores even if op is in hidden_ops. + if (is_hidden) { + if (!hidden_by_api_def || is_reserved || + python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) { + function_name = strings::StrCat("_", function_name); + } + } else if (is_reserved) { + // When users create custom python wrappers, they may link in the + // default op registry by accident, and because they can't + // enumerate all 'hidden' symbols, this guard is to prevent + // instantiating a python reserved word in their wrapper. continue; } diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 6fa076507d11ab9c88891cbeb0a4fb3959e4e99d..519814b979e00dd7c9df41eacbe1edc02c9d88e8 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -163,7 +163,7 @@ PyObject* PyIntFromDataType(TF_DataType l) { extern "C" { -static const int kMaxEagerTensorParentSize = 32; +static const int kMaxEagerTensorParentSize = 64; // TODO(agarwal): store context handle in EagerTensor. typedef struct EagerTensor { @@ -185,6 +185,16 @@ typedef struct EagerTensor { // This stores `_keras_mask` object and is set by Tensorflow layers. PyObject* keras_mask; + + // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the + // first time that `_EagerTensorBase`'s `shape` property is called. + PyObject* tensor_shape; + + // We store a status object here as an optimization to avoid allocating a new + // Status objects on different functions that operate on EagerTensor and need + // to use a TF_Status object. However note that accesses to `status` are not + // thread-safe. + TF_Status* status; } EagerTensor; // tp_init for EagerTensor. @@ -195,6 +205,9 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { self->handle_data = Py_None; Py_INCREF(Py_None); self->keras_mask = Py_None; + Py_INCREF(Py_None); + self->tensor_shape = Py_None; + self->status = TF_NewStatus(); PyObject* value; PyObject* context = nullptr; PyObject* device = nullptr; @@ -269,17 +282,17 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { } TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get()); if (desired_dtype >= 0 && desired_dtype != handle_dtype) { - auto out_status = tensorflow::make_safe(TF_NewStatus()); handle = tensorflow::make_safe( EagerCast(GetContext(context), handle.get(), handle_dtype, - static_cast(desired_dtype), out_status.get())); - if (TF_GetCode(out_status.get()) != TF_OK) { - PyErr_SetString( - PyExc_ValueError, - tensorflow::strings::StrCat("Error while casting from DataType ", - handle_dtype, " to ", desired_dtype, ". ", - TF_Message(out_status.get())) - .c_str()); + static_cast(desired_dtype), self->status)); + if (TF_GetCode(self->status) != TF_OK) { + PyErr_SetString(PyExc_ValueError, + tensorflow::strings::StrCat( + "Error while casting from DataType ", handle_dtype, + " to ", desired_dtype, ". ", TF_Message(self->status)) + .c_str()); + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); return -1; } handle_dtype = TFE_TensorHandleDataType(handle.get()); @@ -323,10 +336,14 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { // tp_dealloc for EagerTensor. void EagerTensor_dealloc(EagerTensor* self) { + TF_DeleteStatus(self->status); Py_DECREF(self->handle_data); Py_DECREF(self->keras_mask); - TFE_DeleteTensorHandle(self->handle); - self->handle = nullptr; + Py_DECREF(self->tensor_shape); + if (self->handle != nullptr) { + TFE_DeleteTensorHandle(self->handle); + self->handle = nullptr; + } // We have the global interpreter lock, so use this chance to perform delayed // refcount decrements. tensorflow::ClearDecrefCache(); @@ -348,12 +365,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) { // Getter for `_shape_tuple`. static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { auto handle = self->handle; - int n = TFE_TensorHandleNumDims(handle); + int n = TFE_TensorHandleNumDims(handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } PyObject* shape = PyTuple_New(n); if (PyErr_Occurred()) return nullptr; for (int i = 0; i < n; ++i) { - PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i)); - if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { + PyObject* dim = + PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status)); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) || + dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); Py_DECREF(shape); if (dim != nullptr) Py_DECREF(dim); PyErr_SetString(PyExc_RuntimeError, "Error while creating shape"); @@ -365,10 +391,16 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { // Getter for `_rank`. static PyObject* EagerTensor_rank(EagerTensor* self) { + int num_dims = TFE_TensorHandleNumDims(self->handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } #if PY_MAJOR_VERSION < 3 - return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle)); + return PyInt_FromLong(num_dims); #else - return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle)); + return PyLong_FromLong(num_dims); #endif } @@ -397,6 +429,19 @@ static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value, self->keras_mask = value; return 0; } + +static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) { + Py_INCREF(self->tensor_shape); + return self->tensor_shape; +} + +static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value, + void* unused) { + Py_DECREF(self->tensor_shape); + Py_INCREF(value); + self->tensor_shape = value; + return 0; +} // Function `_copy_to_device`. static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, PyObject* kwds) { @@ -437,10 +482,16 @@ static PyObject* EagerTensor_numpy(EagerTensor* self) { // Getter `device`. static PyObject* EagerTensor_device(EagerTensor* self) { + const char* device = TFE_TensorHandleDeviceName(self->handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } #if PY_MAJOR_VERSION >= 3 - return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle)); + return PyUnicode_FromString(device); #else - return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle)); + return PyBytes_FromString(device); #endif } @@ -455,6 +506,9 @@ static PyGetSetDef EagerTensor_getseters[] = { {const_cast("_keras_mask"), (getter)EagerTensor_keras_mask, (setter)EagerTensor_setkeras_mask, const_cast("_keras_mask"), nullptr}, + {const_cast("_tensor_shape"), (getter)EagerTensor_tensor_shape, + (setter)EagerTensor_settensor_shape, const_cast("_tensor_shape"), + nullptr}, {nullptr} /* Sentinel */ }; @@ -491,16 +545,11 @@ PyTypeObject* EagerTensorType = nullptr; #if PY_MAJOR_VERSION >= 3 static PyType_Slot EagerTensor_Type_slots[] = { - Py_tp_dealloc, - reinterpret_cast(EagerTensor_dealloc), - Py_tp_methods, - reinterpret_cast(EagerTensor_methods), - Py_tp_getset, - reinterpret_cast(EagerTensor_getseters), - Py_tp_init, - reinterpret_cast(EagerTensor_init), - 0, - nullptr, + {Py_tp_dealloc, reinterpret_cast(EagerTensor_dealloc)}, + {Py_tp_methods, reinterpret_cast(EagerTensor_methods)}, + {Py_tp_getset, reinterpret_cast(EagerTensor_getseters)}, + {Py_tp_init, reinterpret_cast(EagerTensor_init)}, + {0, nullptr}, }; PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0, @@ -575,7 +624,10 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { t->handle_data = Py_None; Py_INCREF(Py_None); t->keras_mask = Py_None; + Py_INCREF(Py_None); + t->tensor_shape = Py_None; t->handle = handle; + t->status = TF_NewStatus(); } return reinterpret_cast(t); } @@ -673,6 +725,7 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) { auto tensor = tensorflow::make_safe(TF_AllocateTensor( TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int)); int32_t* data = reinterpret_cast(TF_TensorData(tensor.get())); + auto status = tensorflow::make_safe(TF_NewStatus()); for (Py_ssize_t i = 0; i < num_tensors; ++i) { PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i); if (!EagerTensor_CheckExact(tensor_obj)) { @@ -687,21 +740,27 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) { EagerTensor* t = reinterpret_cast(tensor_obj); TFE_TensorHandle* handle = t->handle; - if (slice_dim >= TFE_TensorHandleNumDims(handle)) { - PyErr_SetString(PyExc_IndexError, - tensorflow::strings::StrCat( - "Slice dimension (", slice_dim, - ") must be smaller than rank of all " - "tensors, but tensor at index ", - i, " has rank ", TFE_TensorHandleNumDims(handle)) - .c_str()); + int num_dims = TFE_TensorHandleNumDims(handle, status.get()); + if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) { + return nullptr; + } + if (slice_dim >= num_dims) { + PyErr_SetString( + PyExc_IndexError, + tensorflow::strings::StrCat("Slice dimension (", slice_dim, + ") must be smaller than rank of all " + "tensors, but tensor at index ", + i, " has rank ", num_dims) + .c_str()); + return nullptr; + } + int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get()); + if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) { return nullptr; } - int64_t dim = TFE_TensorHandleDim(handle, slice_dim); data[i] = dim; } - auto status = tensorflow::make_safe(TF_NewStatus()); TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get()); if (TF_GetCode(status.get()) != TF_OK) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index f9692a8910aa6354c7ed81c7e88aed882058f276..32d731d0f68910b8e41a57cb32ae60c3ea6742f7 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -51,6 +51,13 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, // This function is not thread-safe. PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); +// Registers e as the type of the ResourceVariable class. +// Returns Py_None if registration succeeds, else throws a TypeError and returns +// NULL. +// +// This function is not thread-safe. +PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e); + // Registers e as the Exception to be raised when the conditions of // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it // is a signal to the calling code that it should fall back to the safer (and @@ -160,13 +167,10 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, // Item 2: device_name: Name of the device on which to execute the operation, // or NULL for automatic selection. // Item 3: op_name: Name of the TensorFlow op to execute. -// Item 4: record_gradient_callback: Callback that records the gradient of the -// result. The callback takes (op_name, inputs, attrs, result, name) -// - all sequences and records the gradient. -// Item 5: name: An optional name for the operation. -// Item 6: List representing all callbacks to execute after successful +// Item 4: name: An optional name for the operation. +// Item 5: List representing all callbacks to execute after successful // op execute. -// Item 7 onwards: inputs - This is a list of inputs followed by a list of +// Item 6 onwards: inputs - This is a list of inputs followed by a list of // attrs. It is not necessary for type attrs to be present. // // This is named _C since there doesn't seem to be any way to make it visible diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 30e08c8e6531739e3db66a94308e4ce2aff61f11..55ba509065ba44ccafbd209201a250205553e261 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -31,12 +31,30 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/eager/pywrap_tensor.h" +#include "tensorflow/python/lib/core/safe_ptr.h" using tensorflow::string; using tensorflow::strings::Printf; namespace { +struct FastPathOpExecInfo { + TFE_Context* ctx; + const char* device_name; + // The op def of the main op being executed. + const tensorflow::OpDef* op_def; + + bool run_callbacks; + bool run_post_exec_callbacks; + bool run_gradient_callback; + + // The op name of the main op being executed. + PyObject* name; + // The op type name of the main op being executed. + PyObject* op_name; + PyObject* callbacks; +}; + #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \ bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \ type* value) { \ @@ -75,6 +93,34 @@ Py_ssize_t TensorShapeNumDims(PyObject* value) { return size; } +bool IsInteger(PyObject* py_value) { +#if PY_MAJOR_VERSION >= 3 + return PyLong_Check(py_value); +#else + return PyInt_Check(py_value); +#endif +} + +bool ParseDimensionValue(const string& key, PyObject* py_value, + TF_Status* status, int64_t* value) { + if (IsInteger(py_value)) { + return ParseInt64Value(key, py_value, status, value); + } + + tensorflow::Safe_PyObjectPtr dimension_value( + PyObject_GetAttrString(py_value, "_value")); + if (dimension_value == nullptr) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat("Expecting a Dimension for attr ", key, + ", got ", py_value->ob_type->tp_name) + .c_str()); + return false; + } + + return ParseInt64Value(key, dimension_value.get(), status, value); +} + bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, const char** value) { if (PyBytes_Check(py_value)) { @@ -101,14 +147,6 @@ bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status, return true; } -bool IsInteger(PyObject* py_value) { -#if PY_MAJOR_VERSION >= 3 - return PyLong_Check(py_value); -#else - return PyInt_Check(py_value); -#endif -} - // The passed in py_value is expected to be an object of the python type // dtypes.DType or an int. bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, @@ -117,18 +155,18 @@ bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, return ParseIntValue(key, py_value, status, value); } - PyObject* py_type_enum = PyObject_GetAttrString(py_value, "_type_enum"); + tensorflow::Safe_PyObjectPtr py_type_enum( + PyObject_GetAttrString(py_value, "_type_enum")); if (py_type_enum == nullptr) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key, + ", got ", py_value->ob_type->tp_name) + .c_str()); return false; } - if (!ParseIntValue(key, py_type_enum, status, value)) { - Py_DECREF(py_type_enum); - return false; - } - - Py_DECREF(py_type_enum); - return true; + return ParseIntValue(key, py_type_enum.get(), status, value); } bool SetOpAttrList( @@ -146,11 +184,11 @@ bool SetOpAttrList( const int num_values = PySequence_Size(py_list); if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values; -#define PARSE_LIST(c_type, parse_fn) \ - std::unique_ptr values(new c_type[num_values]); \ - for (int i = 0; i < num_values; ++i) { \ - auto py_value = PySequence_ITEM(py_list, i); \ - if (!parse_fn(key, py_value, status, &values[i])) return false; \ +#define PARSE_LIST(c_type, parse_fn) \ + std::unique_ptr values(new c_type[num_values]); \ + for (int i = 0; i < num_values; ++i) { \ + tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \ + if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \ } if (type == TF_ATTR_STRING) { @@ -175,9 +213,9 @@ bool SetOpAttrList( // dims across all the input lists. int total_dims = 0; for (int i = 0; i < num_values; ++i) { - auto py_value = PySequence_ITEM(py_list, i); - if (py_value != Py_None) { - if (!PySequence_Check(py_value)) { + tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); + if (py_value.get() != Py_None) { + if (!PySequence_Check(py_value.get())) { TF_SetStatus( status, TF_INVALID_ARGUMENT, tensorflow::strings::StrCat( @@ -186,7 +224,7 @@ bool SetOpAttrList( .c_str()); return false; } - const auto size = TensorShapeNumDims(py_value); + const auto size = TensorShapeNumDims(py_value.get()); if (size >= 0) { total_dims += size; } @@ -200,12 +238,12 @@ bool SetOpAttrList( std::unique_ptr num_dims(new int[num_values]); int64_t* offset = buffer.get(); for (int i = 0; i < num_values; ++i) { - auto py_value = PySequence_ITEM(py_list, i); - if (py_value == Py_None) { + tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); + if (py_value.get() == Py_None) { dims[i] = nullptr; num_dims[i] = -1; } else { - const auto size = TensorShapeNumDims(py_value); + const auto size = TensorShapeNumDims(py_value.get()); if (size == -1) { dims[i] = nullptr; num_dims[i] = -1; @@ -214,10 +252,12 @@ bool SetOpAttrList( dims[i] = offset; num_dims[i] = size; for (int j = 0; j < size; ++j) { - auto inner_py_value = PySequence_ITEM(py_value, j); - if (inner_py_value == Py_None) { + tensorflow::Safe_PyObjectPtr inner_py_value( + PySequence_ITEM(py_value.get(), j)); + if (inner_py_value.get() == Py_None) { *offset = -1; - } else if (!ParseInt64Value(key, inner_py_value, status, offset)) { + } else if (!ParseDimensionValue(key, inner_py_value.get(), status, + offset)) { return false; } ++offset; @@ -238,21 +278,12 @@ bool SetOpAttrList( return true; } -// This is only declared here since GetFunc makes a recursive call to -// SetOpAttrScalarDefault. -void SetOpAttrScalarDefault( - TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, - const char* attr_name, - tensorflow::gtl::FlatMap* attr_list_sizes, - TF_Status* status); - TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, TF_Status* status) { TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); for (const auto& attr : func.attr()) { if (TF_GetCode(status) != TF_OK) return nullptr; - SetOpAttrScalarDefault(ctx, func_op, attr.second, attr.first.data(), - nullptr, status); + SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); if (TF_GetCode(status) != TF_OK) return nullptr; } return func_op; @@ -398,10 +429,12 @@ bool SetOpAttrScalar( } std::unique_ptr dims(new int64_t[num_dims]); for (int i = 0; i < num_dims; ++i) { - auto inner_py_value = PySequence_ITEM(py_value, i); - if (inner_py_value == Py_None) { + tensorflow::Safe_PyObjectPtr inner_py_value( + PySequence_ITEM(py_value, i)); + if (inner_py_value.get() == Py_None) { dims[i] = -1; - } else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) { + } else if (!ParseDimensionValue(key, inner_py_value.get(), status, + &dims[i])) { return false; } } @@ -452,57 +485,9 @@ void SetOpAttrScalarDefault( const char* attr_name, tensorflow::gtl::FlatMap* attr_list_sizes, TF_Status* status) { - switch (default_value.value_case()) { - case tensorflow::AttrValue::kS: - TFE_OpSetAttrString(op, attr_name, default_value.s().data()); - break; - case tensorflow::AttrValue::kI: - TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); - (*attr_list_sizes)[attr_name] = default_value.i(); - break; - case tensorflow::AttrValue::kF: - TFE_OpSetAttrFloat(op, attr_name, default_value.f()); - break; - case tensorflow::AttrValue::kB: - TFE_OpSetAttrBool(op, attr_name, default_value.b()); - break; - case tensorflow::AttrValue::kType: - TFE_OpSetAttrType(op, attr_name, - static_cast(default_value.type())); - break; - case tensorflow::AttrValue::kShape: { - const auto& tensor_shape = default_value.shape(); - if (tensor_shape.unknown_rank()) { - TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); - } else { - const auto num_dims = tensor_shape.dim_size(); - std::unique_ptr dims(new int64_t[num_dims]); - for (int i = 0; i < num_dims; ++i) { - dims[i] = tensor_shape.dim(i).size(); - } - TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); - } - } break; - case tensorflow::AttrValue::kFunc: { - const auto func_op = GetFunc(ctx, default_value.func(), status); - if (TF_GetCode(status) != TF_OK) return; - // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList - // require TFE_Op* and just convert it internally a NameAttrValue, so - // consider adding an overload to the C API to make this case easier. - TFE_OpSetAttrFunction(op, attr_name, func_op); - } break; - case tensorflow::AttrValue::kList: - TF_FALLTHROUGH_INTENDED; - case tensorflow::AttrValue::kTensor: - TF_FALLTHROUGH_INTENDED; - case tensorflow::AttrValue::kPlaceholder: - TF_FALLTHROUGH_INTENDED; - case tensorflow::AttrValue::VALUE_NOT_SET: - TF_SetStatus( - status, TF_UNIMPLEMENTED, - tensorflow::strings::StrCat("Unable to get setfor default value: ", - default_value.DebugString()) - .data()); + SetOpAttrValueScalar(ctx, op, default_value, attr_name, status); + if (default_value.value_case() == tensorflow::AttrValue::kI) { + (*attr_list_sizes)[attr_name] = default_value.i(); } } @@ -579,6 +564,8 @@ PyObject* fallback_exception_class = nullptr; // Python function that returns a backward_function. PyObject* backward_function_getter = nullptr; +PyTypeObject* resource_variable_type = nullptr; + tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; @@ -627,11 +614,28 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) { "TFE_Py_RegisterExceptionClass: " "Registered class should be subclass of Exception."); return nullptr; - } else { - Py_INCREF(e); - exception_class = e; - Py_RETURN_NONE; } + + Py_INCREF(e); + exception_class = e; + Py_RETURN_NONE; +} + +PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e) { + if (!PyType_Check(e)) { + PyErr_SetString( + PyExc_TypeError, + "TFE_Py_RegisterResourceVariableType: Need to register a type."); + return nullptr; + } + + if (resource_variable_type != nullptr) { + Py_DECREF(resource_variable_type); + } + + Py_INCREF(e); + resource_variable_type = reinterpret_cast(e); + Py_RETURN_NONE; } PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { @@ -1008,7 +1012,15 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); - return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()}; + const tensorflow::Tensor* tensor = nullptr; + const tensorflow::Status status = t->handle->Tensor(&tensor); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + return tensorflow::eager::TapeTensor{id, t->handle->dtype, + tensorflow::TensorShape({})}; + } else { + return tensorflow::eager::TapeTensor{id, t->handle->dtype, + tensor->shape()}; + } } tensorflow::int64 id = FastTensorId(tensor); if (PyErr_Occurred()) { @@ -1312,6 +1324,16 @@ std::vector MakeTensorList(PyObject* tensors) { PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, PyObject* target, PyObject* sources, PyObject* output_gradients, TF_Status* status) { + TFE_Py_Tape* tape_obj = reinterpret_cast(tape); + if (!tape_obj->tape->IsPersistent()) { + auto* tape_set = GetTapeSet(); + if (tape_set->find(tape_obj) != tape_set->end()) { + PyErr_SetString(PyExc_RuntimeError, + "Trying to call tape.gradient on a non-persistent tape " + "while it is still active."); + return nullptr; + } + } PyVSpace c_vspace(vspace); if (!c_vspace.Initialize().ok()) { return nullptr; @@ -1337,7 +1359,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, Py_INCREF(tensor); } } - TFE_Py_Tape* tape_obj = reinterpret_cast(tape); std::vector result; status->status = tape_obj->tape->ComputeGradient( c_vspace, target_vec, sources_vec, outgrad_vec, &result); @@ -1364,7 +1385,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, } namespace { -static const int kFastPathExecuteInputStartIndex = 6; +static const int kFastPathExecuteInputStartIndex = 5; PyObject* GetPythonObjectFromString(const char* s) { #if PY_MAJOR_VERSION >= 3 @@ -1374,8 +1395,12 @@ PyObject* GetPythonObjectFromString(const char* s) { #endif } -bool CheckEagerTensors(PyObject* seq, int start_index, - const tensorflow::OpDef& op_def) { +bool CheckResourceVariable(PyObject* item) { + return PyObject_TypeCheck(item, resource_variable_type); +} + +bool CheckInputsOk(PyObject* seq, int start_index, + const tensorflow::OpDef& op_def) { for (int i = 0; i < op_def.input_arg_size(); i++) { PyObject* item = PyTuple_GET_ITEM(seq, i + start_index); if (!op_def.input_arg(i).number_attr().empty() || @@ -1383,9 +1408,13 @@ bool CheckEagerTensors(PyObject* seq, int start_index, // This item should be a list input. if (!PyList_Check(item)) return false; for (Py_ssize_t j = 0; j < PyList_Size(item); j++) { - if (!EagerTensor_CheckExact(PyList_GET_ITEM(item, j))) return false; + PyObject* inner_item = PyList_GET_ITEM(item, j); + if (!EagerTensor_CheckExact(inner_item) && + !CheckResourceVariable(inner_item)) { + return false; + } } - } else if (!EagerTensor_CheckExact(item)) { + } else if (!EagerTensor_CheckExact(item) && !CheckResourceVariable(item)) { return false; } } @@ -1393,71 +1422,6 @@ bool CheckEagerTensors(PyObject* seq, int start_index, return true; } -// Adds input and type attr to the op, and to the list of flattened -// inputs/attrs. -bool AddInputToOp(PyObject* input, const tensorflow::OpDef::ArgDef* input_arg, - std::vector* flattened_attrs, - std::vector* flattened_inputs, TFE_Op* op, - TF_Status* status) { - TFE_TensorHandle* input_handle = EagerTensor_Handle(input); - if (input_arg != nullptr && !input_arg->type_attr().empty()) { - auto dtype = TFE_TensorHandleDataType(input_handle); - TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype); - if (flattened_attrs != nullptr) { - flattened_attrs->push_back( - GetPythonObjectFromString(input_arg->type_attr().data())); - flattened_attrs->push_back(PyLong_FromLong(dtype)); - } - } - - if (flattened_inputs != nullptr) { - flattened_inputs->push_back(input); - } - TFE_OpAddInput(op, input_handle, status); - if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { - return false; - } - return true; -} - -const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) { - const char* op_name = TFE_GetPythonString(py_op_name); - if (op_name == nullptr) { - PyErr_SetString(PyExc_TypeError, - Printf("expected a string for op_name, got %s instead", - py_op_name->ob_type->tp_name) - .c_str()); - return nullptr; - } - - const tensorflow::OpRegistrationData* op_reg_data = nullptr; - const tensorflow::Status lookup_status = - tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); - if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) { - return nullptr; - } - return &op_reg_data->op_def; -} - -const char* GetDeviceName(PyObject* py_device_name) { - if (py_device_name != Py_None) { - return TFE_GetPythonString(py_device_name); - } - return nullptr; -} - -bool RaiseIfNotPyList(PyObject* list, const string& attr_name) { - if (!PyList_Check(list)) { - PyErr_SetString(PyExc_TypeError, - Printf("expected a list for attr %s, got %s instead", - attr_name.data(), list->ob_type->tp_name) - .data()); - - return false; - } - return true; -} - bool OpDoesntRequireOutput(const string& op_name) { static tensorflow::gtl::FlatSet* ops_that_dont_require_outputs = new tensorflow::gtl::FlatSet({ @@ -1582,7 +1546,6 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, break; } } - if (!should_record) Py_RETURN_NONE; string c_op_name = TFE_GetPythonString(op_name); @@ -1616,53 +1579,212 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, Py_RETURN_NONE; } -bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, - const tensorflow::OpDef* op_def, PyObject* args, - const std::vector& flattened_inputs, - const std::vector& flattened_attrs, - PyObject* flattened_result, PyObject* op_name, PyObject* name, - PyObject* record_gradient_callback, PyObject* callbacks) { - PyObject* inputs = PyTuple_New(flattened_inputs.size()); +void MaybeWatchVariable(PyObject* input) { + DCHECK(CheckResourceVariable(input)); + DCHECK(PyObject_HasAttrString(input, "_trainable")); + + tensorflow::Safe_PyObjectPtr trainable( + PyObject_GetAttrString(input, "_trainable")); + if (trainable.get() == Py_False) return; + TFE_Py_TapeSetWatchVariable(input); +} + +bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, + PyObject* input, tensorflow::Safe_PyObjectPtr* output, + TF_Status* status) { + MaybeWatchVariable(input); + + TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status); + auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); }); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; + + // Set dtype + DCHECK(PyObject_HasAttrString(input, "_dtype")); + tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype")); + int value; + if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) { + return false; + } + TFE_OpSetAttrType(op, "dtype", static_cast(value)); + + TFE_OpSetDevice(op, parent_op_exec_info.device_name, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; + + // Get handle + tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle")); + if (!EagerTensor_CheckExact(handle.get())) return false; + TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; + + int num_retvals = 1; + TFE_TensorHandle* output_handle; + TFE_Execute(op, &output_handle, &num_retvals, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; + + // Always create the py object (and correctly DECREF it) from the returned + // value, else the data will leak. + output->reset(EagerTensorFromHandle(output_handle)); + + // TODO(nareshmodi): Should we run post exec callbacks here? + if (parent_op_exec_info.run_gradient_callback) { + tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1)); + PyTuple_SET_ITEM(inputs.get(), 0, handle.release()); + + tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1)); + Py_INCREF(output->get()); // stay alive after since tuple steals. + PyTuple_SET_ITEM(outputs.get(), 0, output->get()); + + if (!RecordGradient(GetPythonObjectFromString("ReadVariableOp"), + inputs.get(), Py_None, outputs.get(), Py_None)) { + return false; + } + } + + return true; +} + +// Supports only 2 cases at the moment: +// i) input is an EagerTensor +// ii) input is a ResourceVariable - in this case, the is_variable param is set +// to true. +bool ConvertToTensor(const FastPathOpExecInfo& op_exec_info, PyObject* input, + tensorflow::Safe_PyObjectPtr* output_handle, + TF_Status* status) { + if (CheckResourceVariable(input)) { + return ReadVariableOp(op_exec_info, input, output_handle, status); + } + + Py_INCREF(input); + output_handle->reset(input); + + return true; +} + +// Adds input and type attr to the op, and to the list of flattened +// inputs/attrs. +bool AddInputToOp(const FastPathOpExecInfo& op_exec_info, PyObject* input, + const tensorflow::OpDef::ArgDef* input_arg, + std::vector* flattened_attrs, + std::vector* flattened_inputs, + TFE_Op* op, TF_Status* status) { + // py_eager_tensor's ownership is transferred to flattened_inputs if it is + // required, else the object is destroyed and DECREF'd when the object goes + // out of scope in this function. + tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr; + + if (!ConvertToTensor(op_exec_info, input, &py_eager_tensor, status)) { + return false; + } + + TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get()); + + if (input_arg != nullptr && !input_arg->type_attr().empty()) { + auto dtype = TFE_TensorHandleDataType(input_handle); + TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype); + if (flattened_attrs != nullptr) { + flattened_attrs->emplace_back( + GetPythonObjectFromString(input_arg->type_attr().data())); + flattened_attrs->emplace_back(PyLong_FromLong(dtype)); + } + } + + if (flattened_inputs != nullptr) { + flattened_inputs->emplace_back(std::move(py_eager_tensor)); + } + + TFE_OpAddInput(op, input_handle, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return false; + } + + return true; +} + +const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) { + const char* op_name = TFE_GetPythonString(py_op_name); + if (op_name == nullptr) { + PyErr_SetString(PyExc_TypeError, + Printf("expected a string for op_name, got %s instead", + py_op_name->ob_type->tp_name) + .c_str()); + return nullptr; + } + + const tensorflow::OpRegistrationData* op_reg_data = nullptr; + const tensorflow::Status lookup_status = + tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) { + return nullptr; + } + return &op_reg_data->op_def; +} + +const char* GetDeviceName(PyObject* py_device_name) { + if (py_device_name != Py_None) { + return TFE_GetPythonString(py_device_name); + } + return nullptr; +} + +bool RaiseIfNotPyList(PyObject* list, const string& attr_name) { + if (!PyList_Check(list)) { + PyErr_SetString(PyExc_TypeError, + Printf("expected a list for attr %s, got %s instead", + attr_name.data(), list->ob_type->tp_name) + .data()); + + return false; + } + return true; +} + +bool RunCallbacks( + const FastPathOpExecInfo& op_exec_info, PyObject* args, + const std::vector& flattened_inputs, + const std::vector& flattened_attrs, + PyObject* flattened_result) { + if (!op_exec_info.run_callbacks) return true; + + tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size())); for (int i = 0; i < flattened_inputs.size(); i++) { - PyObject* input = flattened_inputs[i]; + PyObject* input = flattened_inputs[i].get(); Py_INCREF(input); - PyTuple_SET_ITEM(inputs, i, input); + PyTuple_SET_ITEM(inputs.get(), i, input); } int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - - op_def->input_arg_size() - + op_exec_info.op_def->input_arg_size() - kFastPathExecuteInputStartIndex; int num_attrs = flattened_attrs.size() + num_non_inferred_attrs; - PyObject* attrs = PyTuple_New(num_attrs); + tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs)); for (int i = 0; i < num_non_inferred_attrs; i++) { - auto* attr = PyTuple_GET_ITEM( - args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i); + auto* attr = + PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + + op_exec_info.op_def->input_arg_size() + i); Py_INCREF(attr); - PyTuple_SET_ITEM(attrs, i, attr); + PyTuple_SET_ITEM(attrs.get(), i, attr); } for (int i = num_non_inferred_attrs; i < num_attrs; i++) { - // Not INCREFing anything in flattened_attrs as each of those is a new - // reference, so allow the attrs tuple to steal the reference. - PyTuple_SET_ITEM(attrs, i, flattened_attrs.at(i - num_non_inferred_attrs)); + PyObject* attr_or_name = + flattened_attrs.at(i - num_non_inferred_attrs).get(); + Py_INCREF(attr_or_name); + PyTuple_SET_ITEM(attrs.get(), i, attr_or_name); } - PyObject* callback_args = - Py_BuildValue("OOOOO", op_name, inputs, attrs, flattened_result, name); - - auto cleaner = tensorflow::gtl::MakeCleanup([inputs, attrs, callback_args] { - Py_DECREF(inputs); - Py_DECREF(attrs); - Py_DECREF(callback_args); - }); - - if (run_gradient_callback) { - RecordGradient(op_name, inputs, attrs, flattened_result, name); + if (op_exec_info.run_gradient_callback) { + if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(), + flattened_result, op_exec_info.name)) { + return false; + } } - if (run_post_exec_callbacks) { - for (Py_ssize_t i = 0; i < PyList_Size(callbacks); i++) { - PyObject* callback_fn = PyList_GET_ITEM(callbacks, i); + if (op_exec_info.run_post_exec_callbacks) { + tensorflow::Safe_PyObjectPtr callback_args( + Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(), + flattened_result, op_exec_info.name)); + for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) { + PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i); if (!PyCallable_Check(callback_fn)) { PyErr_SetString( PyExc_TypeError, @@ -1673,7 +1795,7 @@ bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, return false; } PyObject* callback_result = - PyObject_CallObject(callback_fn, callback_args); + PyObject_CallObject(callback_fn, callback_args.get()); if (!callback_result) { return false; } @@ -1697,15 +1819,30 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } - TFE_Context* ctx = reinterpret_cast( + FastPathOpExecInfo op_exec_info; + + op_exec_info.ctx = reinterpret_cast( PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); - const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); - PyObject* op_name = PyTuple_GET_ITEM(args, 2); - const tensorflow::OpDef* op_def = GetOpDef(op_name); - if (op_def == nullptr) return nullptr; - PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3); - PyObject* name = PyTuple_GET_ITEM(args, 4); - PyObject* callbacks = PyTuple_GET_ITEM(args, 5); + op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); + op_exec_info.op_name = PyTuple_GET_ITEM(args, 2); + op_exec_info.op_def = GetOpDef(op_exec_info.op_name); + if (op_exec_info.op_def == nullptr) return nullptr; + op_exec_info.name = PyTuple_GET_ITEM(args, 3); + op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4); + + const tensorflow::OpDef* op_def = op_exec_info.op_def; + + // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks + // (similar to benchmark_tf_gradient_function_*). Also consider using an + // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks + // point out problems with heap allocs. + op_exec_info.run_gradient_callback = + !*ThreadTapeIsStopped() && !GetTapeSet()->empty(); + op_exec_info.run_post_exec_callbacks = + op_exec_info.callbacks != Py_None && + PyList_Size(op_exec_info.callbacks) > 0; + op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || + op_exec_info.run_post_exec_callbacks; if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { PyErr_SetString( @@ -1718,7 +1855,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } - if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, *op_def)) { + if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) { RaiseFallbackException( "This function does not handle the case of the path where " "all inputs are not already EagerTensors."); @@ -1726,7 +1863,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { } TF_Status* status = TF_NewStatus(); - TFE_Op* op = TFE_NewOp(ctx, op_def->name().c_str(), status); + TFE_Op* op = TFE_NewOp(op_exec_info.ctx, op_def->name().c_str(), status); auto cleaner = tensorflow::gtl::MakeCleanup([status, op] { TF_DeleteStatus(status); TFE_DeleteOp(op); @@ -1753,8 +1890,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { // OpRegistrationData. for (const auto& attr : op_def->attr()) { if (attr_name == attr.name()) { - SetOpAttrWithDefaults(ctx, op, attr, attr_name.data(), py_attr_value, - &attr_list_sizes, status); + SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr_name.data(), + py_attr_value, &attr_list_sizes, status); if (TF_GetCode(status) != TF_OK) { RaiseFallbackException(TF_Message(status)); @@ -1766,34 +1903,28 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { } } - TFE_OpSetDevice(op, device_name, status); + TFE_OpSetDevice(op, op_exec_info.device_name, status); if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { return nullptr; } - // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks - // (similar to benchmark_tf_gradient_function_*). Also consider using an - // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks - // point out problems with heap allocs. - bool run_gradient_callback = !*ThreadTapeIsStopped() && - !GetTapeSet()->empty() && - record_gradient_callback != Py_None; - bool run_post_exec_callbacks = - callbacks != Py_None && PyList_Size(callbacks) > 0; - bool run_callbacks = run_gradient_callback || run_post_exec_callbacks; // Flat attrs and inputs as required by the record_gradient call. The attrs // here only contain inferred attrs (non-inferred attrs are added directly // from the input args). - // All items in flattened_attrs contain new references. - // All items in flattened_inputs contain borrowed references. + // All items in flattened_attrs and flattened_inputs contain + // Safe_PyObjectPtr - any time something steals a reference to this, it must + // INCREF. // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work // directly. - std::unique_ptr> flattened_attrs = nullptr; - std::unique_ptr> flattened_inputs = nullptr; + std::unique_ptr> flattened_attrs = + nullptr; + std::unique_ptr> flattened_inputs = + nullptr; - if (run_callbacks) { - flattened_attrs.reset(new std::vector); - flattened_inputs.reset(new std::vector); + // TODO(nareshmodi): Encapsulate callbacks information into a struct. + if (op_exec_info.run_callbacks) { + flattened_attrs.reset(new std::vector); + flattened_inputs.reset(new std::vector); } // Add inferred attrs and inputs. @@ -1813,16 +1944,16 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { Py_ssize_t len = PyList_Size(input); TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len); - if (run_callbacks) { - flattened_attrs->push_back( + if (op_exec_info.run_callbacks) { + flattened_attrs->emplace_back( GetPythonObjectFromString(input_arg.number_attr().data())); - flattened_attrs->push_back(PyLong_FromLong(len)); + flattened_attrs->emplace_back(PyLong_FromLong(len)); } attr_list_sizes[input_arg.number_attr()] = len; if (len > 0) { // First item adds the type attr. - if (!AddInputToOp(PyList_GET_ITEM(input, 0), &input_arg, + if (!AddInputToOp(op_exec_info, PyList_GET_ITEM(input, 0), &input_arg, flattened_attrs.get(), flattened_inputs.get(), op, status)) { return nullptr; @@ -1830,7 +1961,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { for (Py_ssize_t j = 1; j < len; j++) { // Since the list is homogeneous, we don't need to re-add the attr. - if (!AddInputToOp(PyList_GET_ITEM(input, j), nullptr /* input_arg */, + if (!AddInputToOp(op_exec_info, PyList_GET_ITEM(input, j), + nullptr /* input_arg */, nullptr /* flattened_attrs */, flattened_inputs.get(), op, status)) { return nullptr; @@ -1844,12 +1976,20 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { Py_ssize_t len = PyList_Size(input); tensorflow::gtl::InlinedVector attr_value(len); PyObject* py_attr_value = nullptr; - if (run_callbacks) { + if (op_exec_info.run_callbacks) { py_attr_value = PyTuple_New(len); } for (Py_ssize_t j = 0; j < len; j++) { PyObject* py_input = PyList_GET_ITEM(input, j); - TFE_TensorHandle* input_handle = EagerTensor_Handle(py_input); + tensorflow::Safe_PyObjectPtr py_eager_tensor; + if (!ConvertToTensor(op_exec_info, py_input, &py_eager_tensor, + status)) { + return nullptr; + } + + TFE_TensorHandle* input_handle = + EagerTensor_Handle(py_eager_tensor.get()); + attr_value[j] = TFE_TensorHandleDataType(input_handle); TFE_OpAddInput(op, input_handle, status); @@ -1857,22 +1997,23 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } - if (run_callbacks) { - flattened_inputs->push_back(py_input); + if (op_exec_info.run_callbacks) { + flattened_inputs->emplace_back(std::move(py_eager_tensor)); PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j])); } } - if (run_callbacks) { - flattened_attrs->push_back(GetPythonObjectFromString(attr_name.data())); - flattened_attrs->push_back(py_attr_value); + if (op_exec_info.run_callbacks) { + flattened_attrs->emplace_back( + GetPythonObjectFromString(attr_name.data())); + flattened_attrs->emplace_back(py_attr_value); } TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(), attr_value.size()); attr_list_sizes[attr_name] = len; } else { // The item is a single item. - if (!AddInputToOp(input, &input_arg, flattened_attrs.get(), + if (!AddInputToOp(op_exec_info, input, &input_arg, flattened_attrs.get(), flattened_inputs.get(), op, status)) { return nullptr; } @@ -1896,27 +2037,27 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { Py_BEGIN_ALLOW_THREADS; TFE_Execute(op, retvals.data(), &num_retvals, status); Py_END_ALLOW_THREADS; + if (TF_GetCode(status) != TF_OK) { // Augment the status with the op_name for easier debugging similar to // TFE_Py_Execute. TF_SetStatus(status, TF_GetCode(status), - tensorflow::strings::StrCat(TF_Message(status), " [Op:", - TFE_GetPythonString(op_name), "]") + tensorflow::strings::StrCat( + TF_Message(status), + " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]") .c_str()); MaybeRaiseExceptionFromTFStatus(status, nullptr); return nullptr; } - PyObject* flat_result = PyList_New(num_retvals); + tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals)); for (int i = 0; i < num_retvals; ++i) { - PyList_SET_ITEM(flat_result, i, EagerTensorFromHandle(retvals[i])); + PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i])); } - if (run_callbacks && - !RunCallbacks(run_gradient_callback, run_post_exec_callbacks, op_def, - args, *flattened_inputs, *flattened_attrs, flat_result, - op_name, name, record_gradient_callback, callbacks)) { + if (!RunCallbacks(op_exec_info, args, *flattened_inputs, *flattened_attrs, + flat_result.get())) { return nullptr; } @@ -1928,11 +2069,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { if (op_def->output_arg_size() == 1) { if (!op_def->output_arg(0).number_attr().empty() || !op_def->output_arg(0).type_list_attr().empty()) { - return flat_result; + return flat_result.release(); } else { - auto* result = PyList_GET_ITEM(flat_result, 0); + auto* result = PyList_GET_ITEM(flat_result.get(), 0); Py_INCREF(result); - Py_DECREF(flat_result); return result; } } @@ -1945,7 +2085,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()]; PyObject* inner_list = PyList_New(list_length); for (int j = 0; j < list_length; j++) { - PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); + PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++); Py_INCREF(obj); PyList_SET_ITEM(inner_list, j, obj); } @@ -1954,18 +2094,17 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()]; PyObject* inner_list = PyList_New(list_length); for (int j = 0; j < list_length; j++) { - PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); + PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++); Py_INCREF(obj); PyList_SET_ITEM(inner_list, j, obj); } PyList_SET_ITEM(result, i, inner_list); } else { - PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); + PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++); Py_INCREF(obj); PyList_SET_ITEM(result, i, obj); } } - Py_DECREF(flat_result); return result; } diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index 49323e6640e664ef5f98b227964f9dd4e248ca39..faaae40b3f1ef02984a7a75c23ae4acae65ac335 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -21,13 +21,13 @@ from __future__ import print_function from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import execute from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops class Tests(test.TestCase): @@ -46,15 +46,28 @@ class Tests(test.TestCase): self.assertAllClose( math_ops.matmul(a_2_by_2, b_2_by_2), pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, - None, None, a_2_by_2, b_2_by_2, "transpose_a", False, "transpose_b", - False)) + ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, + b_2_by_2, "transpose_a", False, "transpose_b", False)) self.assertAllClose( math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True), pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, - None, None, a_100_by_784, b_100_by_784, "transpose_a", False, - "transpose_b", True)) + ctx._handle, ctx.device_name, "MatMul", None, None, a_100_by_784, + b_100_by_784, "transpose_a", False, "transpose_b", True)) + + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_ResourceVariableMatMulCorrectResponse(self): + ctx = context.context() + a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) + m = resource_variable_ops.ResourceVariable(a_2_by_2) + x = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a", + False, "transpose_b", False) + y = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, a_2_by_2, + "transpose_a", False, "transpose_b", False) + + self.assertAllEqual(x, y) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created @@ -64,12 +77,27 @@ class Tests(test.TestCase): a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) tape.watch(a_2_by_2) z = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, None, - None, a_2_by_2, a_2_by_2, "transpose_a", False, "transpose_b", False) + ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2, + a_2_by_2, "transpose_a", False, "transpose_b", False) dz_dy = tape.gradient(z, [a_2_by_2])[0] self.assertAllEqual(dz_dy.numpy(), constant_op.constant(4.0, shape=[2, 2]).numpy()) + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_ResourceVariableTapeWrite(self): + ctx = context.context() + with backprop.GradientTape(persistent=True) as tape: + a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) + m = resource_variable_ops.ResourceVariable(a_2_by_2) + tape.watch(m) + z = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", None, None, m, m, + "transpose_a", False, "transpose_b", False) + dz_dy = tape.gradient(z, [m])[0] + self.assertAllEqual(dz_dy.numpy(), + constant_op.constant(4.0, shape=[2, 2]).numpy()) + # Tests homogeneous list op @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created @@ -80,9 +108,9 @@ class Tests(test.TestCase): self.assertAllClose( math_ops.add_n([a_2_by_2, b_2_by_2]), - pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "AddN", execute.record_gradient, None, - None, [a_2_by_2, b_2_by_2])) + pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "AddN", None, None, + [a_2_by_2, b_2_by_2])) # Tests homogeneous list op @test_util.assert_no_new_tensors @@ -96,8 +124,8 @@ class Tests(test.TestCase): tape.watch(a_2_by_2) tape.watch(b_2_by_2) z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "AddN", execute.record_gradient, None, - None, [a_2_by_2, b_2_by_2]) + ctx._handle, ctx.device_name, "AddN", None, None, + [a_2_by_2, b_2_by_2]) z2 = math_ops.add_n([a_2_by_2, b_2_by_2]) dz1_dy = tape.gradient(z1, [a_2_by_2])[0] dz2_dy = tape.gradient(z2, [a_2_by_2])[0] @@ -113,9 +141,9 @@ class Tests(test.TestCase): self.assertAllClose( array_ops.identity_n([a_2_by_2, b_2_by_2]), - pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "IdentityN", execute.record_gradient, - None, None, [a_2_by_2, b_2_by_2])) + pywrap_tensorflow.TFE_Py_FastPathExecute(ctx._handle, ctx.device_name, + "IdentityN", None, None, + [a_2_by_2, b_2_by_2])) # Tests heterogeneous list op @test_util.assert_no_new_tensors @@ -129,8 +157,8 @@ class Tests(test.TestCase): tape.watch(a_2_by_2) tape.watch(b_2_by_2) z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx._handle, ctx.device_name, "IdentityN", execute.record_gradient, - None, None, [a_2_by_2, b_2_by_2]) + ctx._handle, ctx.device_name, "IdentityN", None, None, + [a_2_by_2, b_2_by_2]) z2 = array_ops.identity_n([a_2_by_2, b_2_by_2]) dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0] dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0] @@ -141,28 +169,26 @@ class Tests(test.TestCase): def testFastpathExecute_InvalidInputs(self): a_2_by_2 = random_ops.random_uniform((2, 2)) ctx = context.context() - assert not ctx.in_graph_mode( + assert ctx.executing_eagerly( ), "The prototype doesn't contain C code for graph construction" ctx_handle = ctx._handle # pylint: disable=protected-access # Not enough base params with self.assertRaisesRegexp(ValueError, - "at least 6 items in the input tuple"): + "at least 5 items in the input tuple"): pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity") # Not enough inputs with self.assertRaisesRegexp(ValueError, - "Expected to be at least 7, was 6"): - pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx_handle, ctx_handle, "Identity", backprop._record_gradient, None, - []) + "Expected to be at least 6, was 5"): + pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx_handle, + "Identity", None, []) # Bad type with self.assertRaisesRegexp(TypeError, "expected a string for op_name"): - pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx_handle, ctx.device_name, ctx_handle, backprop._record_gradient, - None, [], a_2_by_2) + pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, + ctx_handle, None, [], a_2_by_2) if __name__ == "__main__": diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py index b490bac66db03b0a61a8852f45f1f558cccaf121..4326d5efa3d362e883815eb2d3dafb27df25afd4 100644 --- a/tensorflow/python/eager/tape_test.py +++ b/tensorflow/python/eager/tape_test.py @@ -21,11 +21,11 @@ from __future__ import print_function from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import custom_gradient from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops # Importing nn_grad for the registration functions. @@ -165,21 +165,6 @@ class TapeTest(test.TestCase): g, = backprop.gradients_function(fn, [0])(t) self.assertAllEqual(g, 1.0) - def testCustomGradientGraphMode(self): - with context.graph_mode(), self.test_session(): - - @custom_gradient.custom_gradient - def f(x): - - def grad(dresult): - return dresult * 10.0 - - return x, grad - - inp = constant_op.constant(1.0) - grad = gradients_impl.gradients(f(inp), inp) - self.assertAllEqual(grad[0].eval(), 10.0) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index c519fd557a9319d6ef5522b26198e5b4202917fc..5afb5a7dd5d88715768fda985fcea34bc798e37f 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -7,6 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") filegroup( name = "all_files", @@ -35,9 +36,9 @@ py_library( ":linear", ":model_fn", ":parsing_utils", + ":replicate_model_fn", ":run_config", ":training", - ":warm_starting_util", "//tensorflow/python:util", ], ) @@ -264,7 +265,6 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:partitioned_variables", "//tensorflow/python:summary", - "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", @@ -278,12 +278,12 @@ py_library( srcs = ["canned/dnn_testing_utils.py"], srcs_version = "PY2AND3", deps = [ + ":estimator", ":head", ":metric_keys", ":model_fn", ":numpy_io", ":prediction_keys", - ":warm_starting_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", @@ -427,7 +427,6 @@ py_library( ":model_fn", ":run_config", ":util", - ":warm_starting_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", "//tensorflow/python:control_flow_ops", @@ -617,6 +616,7 @@ py_library( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python:weights_broadcast_ops", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", @@ -870,37 +870,65 @@ py_test( ) py_library( - name = "warm_starting_util", - srcs = ["warm_starting_util.py"], + name = "replicate_model_fn", + srcs = [ + "replicate_model_fn.py", + ], srcs_version = "PY2AND3", deps = [ + ":export_output", + ":model_fn", + ":util", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_lib", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "@six_archive//:six", ], ) -py_test( - name = "warm_starting_util_test", - size = "small", - srcs = ["warm_starting_util_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":warm_starting_util", +cuda_py_test( + name = "replicate_model_fn_test", + size = "medium", + srcs = ["replicate_model_fn_test.py"], + additional_deps = [ + "//tensorflow/python/estimator", + ":dnn", + ":export_export", + ":export_output", + ":model_fn", + ":numpy_io", + ":optimizers", + ":prediction_keys", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//third_party/py/numpy", + ":replicate_model_fn", + ], + tags = [ + "multi_gpu", + "noasan", # flaky time outs ], ) diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py index 18c955f5a0e998de983b31fc4cc595895e6bbcbd..7833df2052657114c9799417e1b9d96035b4c5ef 100644 --- a/tensorflow/python/estimator/canned/baseline_test.py +++ b/tensorflow/python/estimator/canned/baseline_test.py @@ -1071,11 +1071,13 @@ class BaselineClassifierEvaluationTest(test.TestCase): ops.GraphKeys.GLOBAL_STEP: 100, metric_keys.MetricKeys.LOSS_MEAN: 1.3133, metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PRECISION: 0., + metric_keys.MetricKeys.RECALL: 0., metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689, metric_keys.MetricKeys.LABEL_MEAN: 1., metric_keys.MetricKeys.ACCURACY_BASELINE: 1, metric_keys.MetricKeys.AUC: 0., - metric_keys.MetricKeys.AUC_PR: 0.5, + metric_keys.MetricKeys.AUC_PR: 1., } else: # Multi classes: loss = 1 * -log ( softmax(logits)[label] ) @@ -1132,11 +1134,13 @@ class BaselineClassifierEvaluationTest(test.TestCase): ops.GraphKeys.GLOBAL_STEP: 100, metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, metric_keys.MetricKeys.ACCURACY: 0.5, + metric_keys.MetricKeys.PRECISION: 0., + metric_keys.MetricKeys.RECALL: 0., metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689, metric_keys.MetricKeys.LABEL_MEAN: 0.5, metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.25, + metric_keys.MetricKeys.AUC_PR: 0.75, } else: # Expand logits since batch_size=2 @@ -1207,12 +1211,14 @@ class BaselineClassifierEvaluationTest(test.TestCase): ops.GraphKeys.GLOBAL_STEP: 100, metric_keys.MetricKeys.LOSS_MEAN: loss_mean, metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.), + metric_keys.MetricKeys.PRECISION: 0., + metric_keys.MetricKeys.RECALL: 0., metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean, metric_keys.MetricKeys.LABEL_MEAN: label_mean, metric_keys.MetricKeys.ACCURACY_BASELINE: ( max(label_mean, 1-label_mean)), metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.16666645, + metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.), } else: # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] ) @@ -1542,4 +1548,3 @@ class BaselineLogitFnTest(test.TestCase): if __name__ == '__main__': test.main() - diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 7043da8de036e5be27d223271c37e065d9ffbcdd..6382622e0b5c72e5d3fcd9b9c6863968a425b86f 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -32,7 +32,6 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary -from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import tf_export # The default learning rate of 0.05 is a historical artifact of the initial @@ -183,17 +182,11 @@ def _dnn_model_fn(features, input_layer_partitioner=input_layer_partitioner) logits = logit_fn(features=features, mode=mode) - def _train_op_fn(loss): - """Returns the op to optimize the loss.""" - return optimizer.minimize( - loss, - global_step=training_util.get_global_step()) - return head.create_estimator_spec( features=features, mode=mode, labels=labels, - train_op_fn=_train_op_fn, + optimizer=optimizer, logits=logits) diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py index 84675bf2a4a1655026bbba37c5d7a63d2f788c46..d275695eb319117cf94aefd7038ab5ee685e05a9 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py @@ -26,7 +26,7 @@ import six from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 -from tensorflow.python.estimator import warm_starting_util +from tensorflow.python.estimator import estimator from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import dnn_testing_utils from tensorflow.python.estimator.canned import linear_testing_utils @@ -866,7 +866,7 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase): learning_rate=0.0), # The provided regular expression will only warm-start the deep # portion of the model. - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=dnn_lc_classifier.model_dir, vars_to_warm_start='.*(dnn).*'))) diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index cbae43e4f7fef0271de20a4ec54449989455d4bd..44545c058c673d00f16c4276dc42cdbf4ca188e4 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -27,8 +27,8 @@ import six from tensorflow.core.framework import summary_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn -from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys @@ -53,7 +53,7 @@ from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import optimizer +from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -134,7 +134,8 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits): hidden_weights_names + hidden_biases_names + [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0']) - def _create_estimator_spec(features, mode, logits, labels, train_op_fn): + def _create_estimator_spec( + features, mode, logits, labels, train_op_fn=None, optimizer=None): del features, labels # Not used. trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) testcase.assertItemsEqual(expected_var_names, @@ -144,8 +145,12 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits): expected_logits, logits, message='Failed for mode={}. '.format(mode)) with ops.control_dependencies([assert_logits]): if mode == model_fn.ModeKeys.TRAIN: + if train_op_fn is not None: + train_op = train_op_fn(loss) + elif optimizer is not None: + train_op = optimizer.minimize(loss, global_step=None) return model_fn.EstimatorSpec( - mode=mode, loss=loss, train_op=train_op_fn(loss)) + mode=mode, loss=loss, train_op=train_op) elif mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss)) elif mode == model_fn.ModeKeys.PREDICT: @@ -203,8 +208,8 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None): return control_flow_ops.no_op() optimizer_mock = test.mock.NonCallableMagicMock( - spec=optimizer.Optimizer, - wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + spec=optimizer_lib.Optimizer, + wraps=optimizer_lib.Optimizer(use_locking=False, name='my_optimizer')) optimizer_mock.minimize = test.mock.MagicMock(wraps=_minimize) return optimizer_mock @@ -828,7 +833,7 @@ class BaseDNNWarmStartingTest(object): optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0), # The provided regular expression will only warm-start the city # embedding, not the kernels and biases of the hidden weights. - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=dnn_classifier.model_dir, vars_to_warm_start='.*(city).*')) @@ -892,7 +897,7 @@ class BaseDNNWarmStartingTest(object): dimension=2) # We can create our VocabInfo object from the new and old occupation # FeatureColumn's. - occupation_vocab_info = warm_starting_util.VocabInfo( + occupation_vocab_info = estimator.VocabInfo( new_vocab=new_occupation.categorical_column.vocabulary_file, new_vocab_size=new_occupation.categorical_column.vocabulary_size, num_oov_buckets=new_occupation.categorical_column.num_oov_buckets, @@ -907,7 +912,7 @@ class BaseDNNWarmStartingTest(object): feature_columns=[occupation], n_classes=4, optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0), - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=dnn_classifier.model_dir, var_name_to_vocab_info={ OCCUPATION_EMBEDDING_NAME: occupation_vocab_info @@ -978,7 +983,7 @@ class BaseDNNWarmStartingTest(object): optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0), # The 'city' variable correspond to the 'locality' variable in the # previous model. - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=dnn_classifier.model_dir, var_name_to_prev_var_name={ CITY_EMBEDDING_NAME: @@ -1035,13 +1040,16 @@ class BaseDNNClassifierEvaluateTest(object): metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2., metric_keys.MetricKeys.ACCURACY: 0.5, + metric_keys.MetricKeys.PRECISION: 0.0, + metric_keys.MetricKeys.RECALL: 0.0, metric_keys.MetricKeys.PREDICTION_MEAN: 0.11105597, metric_keys.MetricKeys.LABEL_MEAN: 0.5, metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, # There is no good way to calculate AUC for only two data points. But # that is what the algorithm returns. metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.25, + metric_keys.MetricKeys.AUC_PR: 0.75, + ops.GraphKeys.GLOBAL_STEP: global_step }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 8d742a2c6147e86619d4c0aad59b69459384bd4d..bb033d349534e044b2b92d064051ee5fa07f4d62 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -85,40 +86,39 @@ class _Head(object): ```python def _my_dnn_model_fn(features, labels, mode, params, config=None): # Optionally your callers can pass head to model_fn as a param. - head = tf.contrib.learn.regression_head(...) - input = tf.contrib.layers.input_from_feature_columns(features, ...) - last_hidden_layer_out = tf.contrib.layers.stack( - input, tf.contrib.layers.fully_connected, [1000, 500]) - logits = tf.contrib.layers.fully_connected( - last_hidden_layer_out, head.logits_dimension, activation_fn=None) - - def _train_op_fn(loss): - return optimizer.minimize(loss) + head = tf.contrib.estimator.regression_head(...) + inputs = tf.feature_column.input_layer(features, ...) + hidden_layer0 = tf.layers.dense( + inputs, units=1000, activation=tf.nn.relu) + hidden_layer1 = tf.layers.dense( + hidden_layer0, units=500, activation=tf.nn.relu) + logits = tf.layers.dense( + hidden_layer1, units=head.logits_dimension, activation=None) return head.create_estimator_spec( features=features, labels=labels, mode=mode, logits=logits, - train_op_fn=_train_op_fn) + optimizer=optimizer) ``` There are cases where computing and applying gradients can not be meaningfully - captured with train_op_fn we support (for example, with sync optimizer). In - such case, you can take the responsibility on your own. Here is a common - use case, + captured with optimizer or train_op_fn we support (for example, with sync + optimizer). In such case, you can take the responsibility on your own. Here is + a common use case, ```python estimator_spec = head.create_estimator_spec( features=features, labels=labels, mode=mode, logits=logits, - train_op_fn=tf.contrib.learn.no_op_train_fn) + train_op_fn=lambda _: tf.no_op()) if mode == model_fn.ModeKeys.TRAIN: optimizer = ... sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...) - update_op = tf.contrib.layers.optimize_loss(optimizer=sync, - loss=estimator_spec.loss, ...) + update_op = sync.minimize( + estimator_spec.loss, global_step=tf.get_global_step()) hooks = [sync.make_session_run_hook(is_chief)] ... update train_op and hooks in EstimatorSpec and return ``` @@ -172,10 +172,12 @@ class _Head(object): """ raise NotImplementedError('Calling an abstract method.') + # TODO(b/65403806): By default, collect regularization_losses from + # GraphKeys.REGULARIZATION_LOSSES collection. @abc.abstractmethod def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns `EstimatorSpec` that a model_fn can return. Please note that, @@ -186,10 +188,14 @@ class _Head(object): mode: Estimator's `ModeKeys`. logits: logits `Tensor` to be used by the head. labels: Labels `Tensor`, or `dict` of same. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns an op - to optimize the model with the loss. This is used in TRAIN mode and - must not be None. None is allowed in other modes. If you want to - optimize loss yourself you can pass `no_op_train_fn` and then use + to optimize the model with the loss in TRAIN mode. Used if `optimizer` + is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in + TRAIN mode. None is allowed in other modes. If you want to optimize loss + yourself you can pass `lambda _: tf.no_op()` and then use EstimatorSpec.loss to compute and apply gradients. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. @@ -694,8 +700,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): processed_labels=label_ids) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -706,8 +712,11 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): labels: Labels integer or string `Tensor` with shape matching `logits`, namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -717,7 +726,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = _check_logits_final_dim(logits, self.logits_dimension) @@ -780,8 +790,16 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn cannot be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -807,7 +825,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _binary_logistic_head_with_sigmoid_cross_entropy_loss( @@ -869,11 +887,12 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( Raises: ValueError: If `thresholds` contains a value outside of `(0, 1)`. ValueError: If `loss_reduction` is invalid. + TypeError: if `label_vocabulary` has invalid type. """ thresholds = tuple(thresholds) if thresholds else tuple() if label_vocabulary is not None and not isinstance(label_vocabulary, (list, tuple)): - raise ValueError( + raise TypeError( 'label_vocabulary should be a list or tuple. Given type: {}'.format( type(label_vocabulary))) @@ -940,6 +959,18 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): predictions=class_ids, weights=weights, name=keys.ACCURACY), + _summary_key(self._name, keys.PRECISION): + metrics_lib.precision( + labels=labels, + predictions=class_ids, + weights=weights, + name=keys.PRECISION), + _summary_key(self._name, keys.RECALL): + metrics_lib.recall( + labels=labels, + predictions=class_ids, + weights=weights, + name=keys.RECALL), _summary_key(self._name, keys.PREDICTION_MEAN): _predictions_mean( predictions=logistic, @@ -1027,8 +1058,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): processed_labels=labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -1039,8 +1070,11 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): labels: Labels integer or string `Tensor` with shape matching `logits`, namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -1050,7 +1084,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ # Predict. with ops.name_scope(self._name, 'head'): @@ -1122,8 +1157,16 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1148,7 +1191,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _regression_head_with_mean_squared_error_loss( @@ -1277,8 +1320,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): processed_labels=labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -1290,8 +1333,11 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape `[D0, D1, ... DN]` is also supported. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -1301,7 +1347,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ # Predict. with ops.name_scope(self._name, 'head'): @@ -1361,8 +1408,16 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): eval_metric_ops=eval_metric_ops) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1387,7 +1442,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _assert_range(labels, n_classes, message=None): diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index a300f315c18f60e77f262a3b961c5ef6306bc235..fe6ee07529bc0314618a7cc85926dbb39660a352 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -300,7 +300,12 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): features = {'x': values_2x3} # Static shape. - with self.assertRaisesRegexp(ValueError, 'Dimensions must be equal'): + with self.assertRaisesRegexp( + ValueError, + r'Shape mismatch: The shape of labels \(received \(3,\)\) should equal ' + r'the shape of logits except for the last dimension ' + r'\(received \(2, 3\)\)\.' + ): head.create_loss( features=features, mode=model_fn.ModeKeys.EVAL, @@ -837,6 +842,41 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, }, summary_str, tol) + def test_train_with_optimizer(self): + n_classes = 3 + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes) + + logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32) + labels = np.array(((1,), (1,)), dtype=np.int64) + features = {'x': np.array(((42,),), dtype=np.int32)} + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=2)]) + + # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10. + expected_loss = 10. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-2 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_summaries_with_head_name(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( @@ -1554,11 +1594,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): # loss_mean = loss/2 = 41./2 = 20.5 keys.LOSS_MEAN: 20.5, keys.ACCURACY: 1./2, + keys.PRECISION: 1., + keys.RECALL: 1./2, keys.PREDICTION_MEAN: 1./2, keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 0.74999905, + keys.AUC_PR: 1., } # Assert spec contains expected tensors. @@ -1597,11 +1639,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): expected_metric_keys = [ '{}/some_binary_head'.format(metric_keys.MetricKeys.LOSS_MEAN), '{}/some_binary_head'.format(metric_keys.MetricKeys.ACCURACY), + '{}/some_binary_head'.format(metric_keys.MetricKeys.PRECISION), + '{}/some_binary_head'.format(metric_keys.MetricKeys.RECALL), '{}/some_binary_head'.format(metric_keys.MetricKeys.PREDICTION_MEAN), '{}/some_binary_head'.format(metric_keys.MetricKeys.LABEL_MEAN), '{}/some_binary_head'.format(metric_keys.MetricKeys.ACCURACY_BASELINE), '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC), - '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC_PR) + '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC_PR), ] self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) @@ -1632,11 +1676,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LOSS_MEAN: expected_unregularized_loss, keys.LOSS_REGULARIZATION: expected_regularization_loss, keys.ACCURACY: 1./2, + keys.PRECISION: 1., + keys.RECALL: 1./2, keys.PREDICTION_MEAN: 1./2, keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 0.75, + keys.AUC_PR: 1., } # Assert predictions, loss, and metrics. @@ -1737,11 +1783,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): expected_metrics = { keys.LOSS_MEAN: 1.62652338 / 2., keys.ACCURACY: 1./2, + keys.PRECISION: 1., + keys.RECALL: .5, keys.PREDICTION_MEAN: 1./2, keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 0.74999905, + keys.AUC_PR: 1., keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1., keys.RECALL_AT_THRESHOLD % thresholds[0]: 1., @@ -1929,6 +1977,39 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: 20.5, }, summary_str) + def test_train_with_optimizer(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() + + logits = np.array(((45,), (-41,),), dtype=np.float32) + labels = np.array(((1,), (1,),), dtype=np.float64) + expected_train_result = b'my_train_op' + features = {'x': np.array(((42,),), dtype=np.float32)} + # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41 + expected_loss = 41. + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + with ops.control_dependencies((check_ops.assert_equal( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss) + self.assertEqual(expected_train_result, train_result) + def test_train_summaries_with_head_name(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( name='some_binary_head') @@ -2182,13 +2263,15 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LOSS_MEAN: 26.9615384615, # accuracy = (1*1 + .1*0 + 1.5*0)/(1 + .1 + 1.5) = 1/2.6 = .38461538461 keys.ACCURACY: .38461538461, + keys.PRECISION: 1./2.5, + keys.RECALL: 1./1.1, # prediction_mean = (1*1 + .1*0 + 1.5*1)/(1 + .1 + 1.5) = 2.5/2.6 # = .96153846153 keys.PREDICTION_MEAN: .96153846153, keys.LABEL_MEAN: expected_label_mean, keys.ACCURACY_BASELINE: 1 - expected_label_mean, keys.AUC: .45454565, - keys.AUC_PR: .21923049, + keys.AUC_PR: .6737757325172424, } # Assert spec contains expected tensors. @@ -2481,13 +2564,15 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): expected_metrics = { keys.LOSS_MEAN: expected_loss / np.sum(weights), keys.ACCURACY: (1.*0. + 1.5*1. + 2.*1. + 2.5*0.) / np.sum(weights), + keys.PRECISION: 2.0/3.0, + keys.RECALL: 2.0/4.5, keys.PREDICTION_MEAN: (1.*1 + 1.5*0 + 2.*1 + 2.5*0) / np.sum(weights), keys.LABEL_MEAN: (1.*0 + 1.5*0 + 2.*1 + 2.5*1) / np.sum(weights), keys.ACCURACY_BASELINE: (1.*0 + 1.5*0 + 2.*1 + 2.5*1) / np.sum(weights), # We cannot reliably calculate AUC with only 4 data points, but the # values should not change because of backwards-compatibility. keys.AUC: 0.5222, - keys.AUC_PR: 0.5119, + keys.AUC_PR: 0.7341, } tol = 1e-2 @@ -3059,6 +3144,40 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: 6.5, }, summary_str) + def test_train_with_optimizer(self): + head = head_lib._regression_head_with_mean_squared_error_loss() + self.assertEqual(1, head.logits_dimension) + + # Create estimator spec. + logits = np.array(((45,), (41,),), dtype=np.float32) + labels = np.array(((43.,), (44.,),), dtype=np.float64) + expected_train_result = b'my_train_op' + features = {'x': np.array(((42.,),), dtype=np.float32)} + # loss = (43-45)^2 + (44-41)^2 = 4 + 9 = 13 + expected_loss = 13 + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + with ops.control_dependencies((check_ops.assert_equal( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss) + self.assertEqual(expected_train_result, train_result) + def test_train_summaries_with_head_name(self): head = head_lib._regression_head_with_mean_squared_error_loss( name='some_regression_head') diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index a2f24ef27044680fe93b176b5207593165d0d109..e7ec4179917a88703444f8aa835ed0359ff58a46 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import ftrl -from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import tf_export @@ -157,17 +156,11 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, units=head.logits_dimension, feature_columns=feature_columns) logits = logit_fn(features=features) - def _train_op_fn(loss): - """Returns the op to optimize the loss.""" - return optimizer.minimize( - loss, - global_step=training_util.get_global_step()) - return head.create_estimator_spec( features=features, mode=mode, labels=labels, - train_op_fn=_train_op_fn, + optimizer=optimizer, logits=logits) diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index e88fcbbd2e0e3617dde428662e58b1d86c4eddd0..da3ce86999b32e081eb8f12bbd9f7a4599ed4eaa 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -31,7 +31,6 @@ from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.estimator import estimator from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.canned import linear from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export @@ -1338,11 +1337,13 @@ class BaseLinearClassifierEvaluationTest(object): ops.GraphKeys.GLOBAL_STEP: 100, metric_keys.MetricKeys.LOSS_MEAN: 41., metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PRECISION: 0., + metric_keys.MetricKeys.RECALL: 0., metric_keys.MetricKeys.PREDICTION_MEAN: 0., metric_keys.MetricKeys.LABEL_MEAN: 1., metric_keys.MetricKeys.ACCURACY_BASELINE: 1, metric_keys.MetricKeys.AUC: 0., - metric_keys.MetricKeys.AUC_PR: 0.5, + metric_keys.MetricKeys.AUC_PR: 1., } else: # Multi classes: loss = 1 * -log ( soft_max(logits)[label] ) @@ -1407,6 +1408,8 @@ class BaseLinearClassifierEvaluationTest(object): ops.GraphKeys.GLOBAL_STEP: 100, metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PRECISION: 0., + metric_keys.MetricKeys.RECALL: 0., metric_keys.MetricKeys.PREDICTION_MEAN: 0.5, metric_keys.MetricKeys.LABEL_MEAN: 0.5, metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, @@ -1488,6 +1491,8 @@ class BaseLinearClassifierEvaluationTest(object): ops.GraphKeys.GLOBAL_STEP: 100, metric_keys.MetricKeys.LOSS_MEAN: loss_mean, metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PRECISION: 0., + metric_keys.MetricKeys.RECALL: 0., metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean, metric_keys.MetricKeys.LABEL_MEAN: label_mean, metric_keys.MetricKeys.ACCURACY_BASELINE: ( @@ -1968,7 +1973,7 @@ class BaseLinearWarmStartingTest(object): optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0), # The provided regular expression will only warm-start the age variable # and not the bias. - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=linear_classifier.model_dir, vars_to_warm_start='.*(age).*')) @@ -2016,7 +2021,7 @@ class BaseLinearWarmStartingTest(object): vocabulary_size=len(new_vocab_list)) # We can create our VocabInfo object from the new and old occupation # FeatureColumn's. - occupation_vocab_info = warm_starting_util.VocabInfo( + occupation_vocab_info = estimator.VocabInfo( new_vocab=new_occupation.vocabulary_file, new_vocab_size=new_occupation.vocabulary_size, num_oov_buckets=new_occupation.num_oov_buckets, @@ -2030,7 +2035,7 @@ class BaseLinearWarmStartingTest(object): feature_columns=[occupation], n_classes=4, optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0), - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=linear_classifier.model_dir, var_name_to_vocab_info={ OCCUPATION_WEIGHT_NAME: occupation_vocab_info @@ -2082,7 +2087,7 @@ class BaseLinearWarmStartingTest(object): optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0), # The 'age' variable correspond to the 'age_in_years' variable in the # previous model. - warm_start_from=warm_starting_util.WarmStartSettings( + warm_start_from=estimator.WarmStartSettings( ckpt_to_initialize_from=linear_classifier.model_dir, var_name_to_prev_var_name={ AGE_WEIGHT_NAME: AGE_WEIGHT_NAME.replace('age', 'age_in_years') diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py index 44eb680939203fea67e3391326a6f1013f022ad5..f374d3154982e3b7cdc637e9e3606b3a2947cbf3 100644 --- a/tensorflow/python/estimator/canned/metric_keys.py +++ b/tensorflow/python/estimator/canned/metric_keys.py @@ -28,6 +28,8 @@ class MetricKeys(object): LOSS_REGULARIZATION = 'regularization_loss' ACCURACY = 'accuracy' + PRECISION = 'precision' + RECALL = 'recall' # This is the best the model could do by always predicting one class. # Should be < ACCURACY in a trained model. ACCURACY_BASELINE = 'accuracy_baseline' diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 1167b3834eb6a79abf670f629ec2cbc37957d191..6a4132bca2cb9f14984b39462d00cf68e4e4ae62 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy import os import tempfile @@ -35,7 +36,6 @@ from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util -from tensorflow.python.estimator import warm_starting_util from tensorflow.python.estimator.export.export import build_all_signature_defs from tensorflow.python.estimator.export.export import get_temp_export_dir from tensorflow.python.estimator.export.export import get_timestamped_export_dir @@ -49,11 +49,13 @@ from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import device_setter from tensorflow.python.training import evaluation from tensorflow.python.training import monitored_session from tensorflow.python.training import saver from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.training import warm_starting_util from tensorflow.python.util import compat from tensorflow.python.util import compat_internal from tensorflow.python.util import nest @@ -137,8 +139,8 @@ class Estimator(object): to configure Estimators from hyper parameter tuning. * `config`: Optional configuration object. Will receive what is passed to Estimator in `config` parameter, or the default `config`. - Allows updating things in your model_fn based on configuration - such as `num_ps_replicas`, or `model_dir`. + Allows updating things in your `model_fn` based on + configuration such as `num_ps_replicas`, or `model_dir`. * Returns: `EstimatorSpec` @@ -165,7 +167,7 @@ class Estimator(object): ValueError: if this is called via a subclass and if that class overrides a member of `Estimator`. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( 'Estimators are not supported when eager execution is enabled.') @@ -216,8 +218,8 @@ class Estimator(object): self._params = copy.deepcopy(params or {}) # pylint: disable=protected-access - self._warm_start_settings = ( - warm_starting_util._get_default_warm_start_settings(warm_start_from)) + self._warm_start_settings = _get_default_warm_start_settings( + warm_start_from) # pylint: enable=protected-access @property @@ -299,11 +301,11 @@ class Estimator(object): * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a tuple (features, labels) with same constraints as below. - * A tuple (features, labels): Where features is a `Tensor` or a - dictionary of string feature name to `Tensor` and labels is a + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a `Tensor` or a dictionary of string label name to `Tensor`. Both - features and labels are consumed by `model_fn`. They should satisfy - the expectation of `model_fn` from inputs. + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. hooks: List of `SessionRunHook` subclass instances. Used for callbacks inside the training loop. @@ -379,11 +381,11 @@ class Estimator(object): * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a tuple (features, labels) with same constraints as below. - * A tuple (features, labels): Where features is a `Tensor` or a - dictionary of string feature name to `Tensor` and labels is a + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a `Tensor` or a dictionary of string label name to `Tensor`. Both - features and labels are consumed by `model_fn`. They should satisfy - the expectation of `model_fn` from inputs. + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. steps: Number of steps for which to evaluate model. If `None`, evaluates until `input_fn` raises an end-of-input exception. @@ -455,17 +457,17 @@ class Estimator(object): checkpoint_path: Path of a specific checkpoint to predict. If `None`, the latest checkpoint in `model_dir` is used. yield_single_examples: If False, yield the whole batch as returned by the - model_fn instead of decomposing the batch into individual elements. This - is useful if model_fn return some tensor with first dimension not - equal to the batch size + `model_fn` instead of decomposing the batch into individual elements. + This is useful if `model_fn` returns some tensors whose first dimension + is not equal to the batch size. Yields: Evaluated values of `predictions` tensors. Raises: - ValueError: Could not find a trained model in model_dir. - ValueError: if batch length of predictions are not same and - yield_single_examples is True. + ValueError: Could not find a trained model in `model_dir`. + ValueError: If batch length of predictions is not the same and + `yield_single_examples` is True. ValueError: If there is a conflict between `predict_keys` and `predictions`. For example if `predict_keys` is not `None` but `EstimatorSpec.predictions` is not a `dict`. @@ -515,7 +517,7 @@ class Estimator(object): allowed_overrides = set([ '_call_input_fn', '_create_global_step', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names' + '_tf_api_names', '_validate_features_in_predict_input' ]) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) @@ -570,7 +572,7 @@ class Estimator(object): export_dir_base: A string containing a directory in which to create timestamped subdirectories containing exported SavedModels. serving_input_receiver_fn: A function that takes no argument and - returns a `ServingInputReceiver`. + returns a `ServingInputReceiver` or `TensorServingInputReceiver`. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel, or `None` if no extra assets are needed. as_text: whether to write the SavedModel proto in text format. @@ -668,11 +670,14 @@ class Estimator(object): # Unconditionally drop the label (the second element of result). result = result[0] + self._validate_features_in_predict_input(result) + return result, input_hooks + + def _validate_features_in_predict_input(self, result): if not _has_dataset_or_queue_runner(result): logging.warning('Input graph does not use tf.data.Dataset or contain a ' 'QueueRunner. That means predict yields forever. ' 'This is probably a mistake.') - return result, input_hooks def _get_features_and_labels_from_input_fn(self, input_fn, mode): """Extracts the `features` and labels from return values of `input_fn`.""" @@ -720,7 +725,7 @@ class Estimator(object): """Creates the global step tensor in graph. The global step tensor must be an integer type with name 'global_step' and - be added to the collection ${tf.GraphKeys.GLOBAL_STEP}. + be added to the collection @{tf.GraphKeys.GLOBAL_STEP}. Args: graph: The graph in which to create the global step tensor. @@ -826,7 +831,7 @@ class Estimator(object): logging.info('Warm-starting with WarmStartSettings: %s' % (self._warm_start_settings,)) # pylint: disable=protected-access - warm_starting_util._warm_start(self._warm_start_settings) + warm_starting_util.warm_start(*self._warm_start_settings) # pylint: enable=protected-access # Check if the user created a loss summary, and add one if they didn't. # We assume here that the summary is called 'loss'. If it is not, we will @@ -844,7 +849,7 @@ class Estimator(object): 'loss': estimator_spec.loss, 'step': global_step_tensor }, - every_n_iter=100) + every_n_iter=self._config.log_step_count_steps) ]) worker_hooks.extend(estimator_spec.training_hooks) @@ -1007,13 +1012,6 @@ def _get_replica_device_setter(config): Returns: A replica device setter, or None. """ - ps_ops = [ - 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', - 'MutableHashTableV2', 'MutableHashTableOfTensors', - 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', - 'MutableDenseHashTableV2', 'VarHandleOp' - ] - if config.task_type: worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id) else: @@ -1024,7 +1022,7 @@ def _get_replica_device_setter(config): ps_tasks=config.num_ps_replicas, worker_device=worker_device, merge_devices=True, - ps_ops=ps_ops, + ps_ops=list(device_setter.STANDARD_PS_OPS), cluster=config.cluster_spec) else: return None @@ -1118,7 +1116,7 @@ def _write_dict_to_summary(output_dir, try: summ = summary_pb2.Summary.FromString(dictionary[key]) for i, _ in enumerate(summ.value): - summ.value[i].tag = key + summ.value[i].tag = '%s/%d' % (key, i) summary_proto.value.extend(summ.value) except message.DecodeError: logging.warn('Skipping summary for %s, cannot parse string to Summary.', @@ -1155,3 +1153,187 @@ class _DatasetInitializerHook(training.SessionRunHook): def after_create_session(self, session, coord): del coord session.run(self._initializer) + +VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name + + +@tf_export('estimator.WarmStartSettings') +class WarmStartSettings( + collections.namedtuple('WarmStartSettings', [ + 'ckpt_to_initialize_from', + 'vars_to_warm_start', + 'var_name_to_vocab_info', + 'var_name_to_prev_var_name', + ])): + """Settings for warm-starting in Estimators. + + Example Use with canned `DNNEstimator`: + + ``` + emb_vocab_file = tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_file( + "sc_vocab_file", "new_vocab.txt", vocab_size=100), + dimension=8) + emb_vocab_list = tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_list( + "sc_vocab_list", vocabulary_list=["a", "b"]), + dimension=8) + estimator = tf.estimator.DNNClassifier( + hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list], + warm_start_from=ws) + ``` + + where `ws` could be defined as: + + Warm-start all weights in the model (input layer and hidden weights). + Either the directory or a specific checkpoint can be provided (in the case + of the former, the latest checkpoint will be used): + + ``` + ws = WarmStartSettings(ckpt_to_initialize_from="/tmp") + ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") + ``` + + Warm-start only the embeddings (input layer): + + ``` + ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", + vars_to_warm_start=".*input_layer.*") + ``` + + Warm-start all weights but the embedding parameters corresponding to + `sc_vocab_file` have a different vocab from the one used in the current + model: + + ``` + vocab_info = tf.estimator.VocabInfo( + new_vocab=sc_vocab_file.vocabulary_file, + new_vocab_size=sc_vocab_file.vocabulary_size, + num_oov_buckets=sc_vocab_file.num_oov_buckets, + old_vocab="old_vocab.txt" + ) + ws = WarmStartSettings( + ckpt_to_initialize_from="/tmp", + var_name_to_vocab_info={ + "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info + }) + ``` + + Warm-start only `sc_vocab_file` embeddings (and no other variables), which + have a different vocab from the one used in the current model: + + ``` + vocab_info = tf.estimator.VocabInfo( + new_vocab=sc_vocab_file.vocabulary_file, + new_vocab_size=sc_vocab_file.vocabulary_size, + num_oov_buckets=sc_vocab_file.num_oov_buckets, + old_vocab="old_vocab.txt" + ) + ws = WarmStartSettings( + ckpt_to_initialize_from="/tmp", + vars_to_warm_start=None, + var_name_to_vocab_info={ + "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info + }) + ``` + + Warm-start all weights but the parameters corresponding to `sc_vocab_file` + have a different vocab from the one used in current checkpoint, and only + 100 of those entries were used: + + ``` + vocab_info = tf.estimator.VocabInfo( + new_vocab=sc_vocab_file.vocabulary_file, + new_vocab_size=sc_vocab_file.vocabulary_size, + num_oov_buckets=sc_vocab_file.num_oov_buckets, + old_vocab="old_vocab.txt", + old_vocab_size=100 + ) + ws = WarmStartSettings( + ckpt_to_initialize_from="/tmp", + var_name_to_vocab_info={ + "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info + }) + ``` + + Warm-start all weights but the parameters corresponding to `sc_vocab_file` + have a different vocab from the one used in current checkpoint and the + parameters corresponding to `sc_vocab_list` have a different name from the + current checkpoint: + + ``` + vocab_info = tf.estimator.VocabInfo( + new_vocab=sc_vocab_file.vocabulary_file, + new_vocab_size=sc_vocab_file.vocabulary_size, + num_oov_buckets=sc_vocab_file.num_oov_buckets, + old_vocab="old_vocab.txt", + old_vocab_size=100 + ) + ws = WarmStartSettings( + ckpt_to_initialize_from="/tmp", + var_name_to_vocab_info={ + "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info + }, + var_name_to_prev_var_name={ + "input_layer/sc_vocab_list_embedding/embedding_weights": + "old_tensor_name" + }) + ``` + + Attributes: + ckpt_to_initialize_from: [Required] A string specifying the directory with + checkpoint file(s) or path to checkpoint from which to warm-start the + model parameters. + vars_to_warm_start: [Optional] A regular expression that captures which + variables to warm-start (see tf.get_collection). Defaults to `'.*'`, + which warm-starts all variables. If `None` is explicitly given, only + variables specified in `var_name_to_vocab_info` will be warm-started. + var_name_to_vocab_info: [Optional] Dict of variable names (strings) to + VocabInfo. The variable names should be "full" variables, not the names + of the partitions. If not explicitly provided, the variable is assumed to + have no vocabulary. + var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to + name of the previously-trained variable in `ckpt_to_initialize_from`. If + not explicitly provided, the name of the variable is assumed to be same + between previous checkpoint and current model. + """ + + def __new__(cls, + ckpt_to_initialize_from, + vars_to_warm_start='.*', + var_name_to_vocab_info=None, + var_name_to_prev_var_name=None): + if not ckpt_to_initialize_from: + raise ValueError( + '`ckpt_to_initialize_from` MUST be set in WarmStartSettings') + return super(WarmStartSettings, cls).__new__( + cls, + ckpt_to_initialize_from, + vars_to_warm_start, + var_name_to_vocab_info or {}, + var_name_to_prev_var_name or {}, + ) + + +def _get_default_warm_start_settings(warm_start_from): + """Returns default WarmStartSettings. + + Args: + warm_start_from: Either a string representing the filepath of a checkpoint + to initialize from, or an instance of WarmStartSettings. + + Returns: + Either None or an instance of WarmStartSettings. + + Raises: + ValueError: If warm_start_from is not None but is neither a string nor an + instance of WarmStartSettings. + """ + if warm_start_from is None: + return None + if isinstance(warm_start_from, six.string_types): + return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) + elif isinstance(warm_start_from, WarmStartSettings): + return warm_start_from + else: + raise ValueError('warm_start_from must be a string or a WarmStartSettings') diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py index 01699e7399c4089281e9ece76e534e1f82692257..be8930b3cbcd89dbb31dffde0a7a5ecfb64fcd8b 100644 --- a/tensorflow/python/estimator/estimator_lib.py +++ b/tensorflow/python/estimator/estimator_lib.py @@ -30,6 +30,8 @@ from tensorflow.python.estimator.canned.linear import LinearRegressor from tensorflow.python.estimator.canned.parsing_utils import classifier_parse_example_spec from tensorflow.python.estimator.canned.parsing_utils import regressor_parse_example_spec from tensorflow.python.estimator.estimator import Estimator +from tensorflow.python.estimator.estimator import VocabInfo +from tensorflow.python.estimator.estimator import WarmStartSettings from tensorflow.python.estimator.export import export_lib as export from tensorflow.python.estimator.exporter import Exporter from tensorflow.python.estimator.exporter import FinalExporter @@ -41,8 +43,6 @@ from tensorflow.python.estimator.run_config import RunConfig from tensorflow.python.estimator.training import EvalSpec from tensorflow.python.estimator.training import train_and_evaluate from tensorflow.python.estimator.training import TrainSpec -from tensorflow.python.estimator.warm_starting_util import VocabInfo -from tensorflow.python.estimator.warm_starting_util import WarmStartSettings from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 7a0745b1d0d5ae932fa59be56a4952e82922a584..f4255091bf6c44916789a182e60e583171ad5e6b 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -48,6 +48,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops @@ -1268,10 +1269,10 @@ class EstimatorEvaluateTest(test.TestCase): _, _ = features, labels global_step = training.get_global_step() - image = array_ops.zeros([1, 3, 3, 1]) + image = array_ops.zeros([5, 3, 3, 1]) eval_metric_ops = { - 'image': (summary.image('image', image, max_outputs=1), - constant_op.constant(1)) + 'foo': (summary.image('image', image, max_outputs=3), + constant_op.constant(1)) } return model_fn_lib.EstimatorSpec( mode, @@ -1291,10 +1292,10 @@ class EstimatorEvaluateTest(test.TestCase): writer_cache.FileWriterCache.clear() # Get last evaluation Event written. - if check_eventfile_for_keyword('image', os.path.join(est.model_dir, - 'eval')): - return - self.fail('{} should be part of reported summaries.'.format('image')) + for key in ['foo/0', 'foo/1', 'foo/2']: + self.assertTrue( + check_eventfile_for_keyword(key, os.path.join(est.model_dir, 'eval')), + '{} should be part of reported summaries.'.format(key)) class EstimatorPredictTest(test.TestCase): @@ -1936,6 +1937,60 @@ class EstimatorExportTest(test.TestCase): # cleanup gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_tensor_features(self): + """Test that models accepting a single raw Tensor can be exported. + + See https://github.com/tensorflow/tensorflow/issues/11674 + + If the model_fn and receiver_fn accept raw tensors rather than dictionaries + as input, export_savedmodel should be okay with that, too. + + """ + + tmpdir = tempfile.mkdtemp() + + def _input_fn_tensor_features(): + t = array_ops.constant([1, 2, 3], dtype=dtypes.float32, shape=[1, 3]) + return (t, None) + + def _model_fn_tensor_features(features, labels, mode): + _ = labels + prediction = math_ops.matmul(features, features, transpose_b=True) + + return model_fn_lib.EstimatorSpec( + mode, + predictions=prediction, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + export_outputs={ + 'test': export_output.PredictOutput({'prediction': prediction}) + }) + + def _serving_input_receiver_fn(): + feat = array_ops.placeholder(dtype=dtypes.float32) + return export.TensorServingInputReceiver( + features=feat, receiver_tensors=feat) + + est = estimator.Estimator(model_fn=_model_fn_tensor_features) + est.train(input_fn=_input_fn_tensor_features, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est.export_savedmodel( + export_dir_base, _serving_input_receiver_fn) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name.lower() for x in graph.get_operations()] + self.assertTrue('const' in graph_ops) + self.assertTrue('matmul' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + def test_scaffold_is_used_for_saver(self): tmpdir = tempfile.mkdtemp() diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 83251c79fc561e16ebddb638668b92b3c69b8af4..9206a4964b3b7a6e3cc1e0f9e965a197be78c4ba 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -21,17 +21,16 @@ from __future__ import print_function import collections import os -import time import six +from tensorflow.python.estimator import util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils @@ -120,6 +119,62 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver_tensors_alternatives) +@tf_export('estimator.export.TensorServingInputReceiver') +class TensorServingInputReceiver(collections.namedtuple( + 'TensorServingInputReceiver', + ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])): + """A return type for a serving_input_receiver_fn. + + This is for use with models that expect a single `Tensor` or `SparseTensor` + as an input feature, as opposed to a dict of features. + + The normal `ServingInputReceiver` always returns a feature dict, even if it + contains only one entry, and so can be used only with models that accept such + a dict. For models that accept only a single raw feature, the + `serving_input_receiver_fn` provided to `Estimator.export_savedmodel()` should + return this `TensorServingInputReceiver` instead. See: + https://github.com/tensorflow/tensorflow/issues/11674 + + Note that the receiver_tensors and receiver_tensor_alternatives arguments + will be automatically converted to the dict representation in either case, + because the SavedModel format requires each input `Tensor` to have a name + (provided by the dict key). + + The expected return values are: + features: A single `Tensor` or `SparseTensor`, representing the feature + to be passed to the model. + receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying + input nodes where this receiver expects to be fed by default. Typically, + this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors_alternatives: a dict of string to additional + groups of receiver tensors, each of which may be a `Tensor` or a dict of + string to `Tensor`. These named receiver tensor alternatives generate + additional serving signatures, which may be used to feed inputs at + different points within the input receiver subgraph. A typical usage is + to allow feeding raw feature `Tensor`s *downstream* of the + tf.parse_example() op. Defaults to None. + """ + + def __new__(cls, features, receiver_tensors, + receiver_tensors_alternatives=None): + if features is None: + raise ValueError('features must be defined.') + if not (isinstance(features, ops.Tensor) + or isinstance(features, sparse_tensor.SparseTensor)): + raise ValueError('feature must be a Tensor or SparseTensor.') + + receiver = ServingInputReceiver( + features=features, + receiver_tensors=receiver_tensors, + receiver_tensors_alternatives=receiver_tensors_alternatives) + + return super(TensorServingInputReceiver, cls).__new__( + cls, + features=receiver.features[_SINGLE_FEATURE_DEFAULT_NAME], + receiver_tensors=receiver.receiver_tensors, + receiver_tensors_alternatives=receiver.receiver_tensors_alternatives) + + @tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): @@ -273,13 +328,6 @@ def _log_signature_report(signature_def_map, excluded_signatures): logging.warn('Export includes no default signature!') -# When we create a timestamped directory, there is a small chance that the -# directory already exists because another worker is also writing exports. -# In this case we just wait one second to get a new timestamp and try again. -# If this fails several times in a row, then something is seriously wrong. -MAX_DIRECTORY_CREATION_ATTEMPTS = 10 - - def get_timestamped_export_dir(export_dir_base): """Builds a path to a new subdirectory within the base directory. @@ -298,25 +346,7 @@ def get_timestamped_export_dir(export_dir_base): RuntimeError: if repeated attempts fail to obtain a unique timestamped directory name. """ - attempts = 0 - while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS: - export_timestamp = int(time.time()) - - export_dir = os.path.join( - compat.as_bytes(export_dir_base), - compat.as_bytes(str(export_timestamp))) - if not gfile.Exists(export_dir): - # Collisions are still possible (though extremely unlikely): this - # directory is not actually created yet, but it will be almost - # instantly on return from this function. - return export_dir - time.sleep(1) - attempts += 1 - logging.warn( - 'Export directory {} already exists; retrying (attempt {}/{})'.format( - export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) - raise RuntimeError('Failed to obtain a unique export directory name after ' - '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) + return util.get_timestamped_dir(export_dir_base) def get_temp_export_dir(timestamped_export_dir): diff --git a/tensorflow/python/estimator/export/export_lib.py b/tensorflow/python/estimator/export/export_lib.py index 99cd81d678bc04e7ed52de721a1fdf799c116795..226fc97fd3a3aefe61c4b88088873ce7489168c7 100644 --- a/tensorflow/python/estimator/export/export_lib.py +++ b/tensorflow/python/estimator/export/export_lib.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.estimator.export.export import build_parsing_serving_input_receiver_fn from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn from tensorflow.python.estimator.export.export import ServingInputReceiver +from tensorflow.python.estimator.export.export import TensorServingInputReceiver from tensorflow.python.estimator.export.export_output import ClassificationOutput from tensorflow.python.estimator.export.export_output import ExportOutput from tensorflow.python.estimator.export.export_output import PredictOutput @@ -34,6 +35,7 @@ _allowed_symbols = [ 'build_parsing_serving_input_receiver_fn', 'build_raw_serving_input_receiver_fn', 'ServingInputReceiver', + 'TensorServingInputReceiver', 'ClassificationOutput', 'ExportOutput', 'PredictOutput', diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index 8442bf04accbd0bc15f5958069bf3060debd42bc..eb9688bc973666554b6057f5f546b9a2d18461d6 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -385,5 +385,67 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertTrue(int(time_2) < int(time_3)) +class TensorServingReceiverTest(test_util.TensorFlowTestCase): + + def test_tensor_serving_input_receiver_constructor(self): + features = constant_op.constant([0]) + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + r = export.TensorServingInputReceiver(features, receiver_tensors) + self.assertTrue(isinstance(r.features, ops.Tensor)) + self.assertTrue(isinstance(r.receiver_tensors, dict)) + + def test_tensor_serving_input_receiver_sparse(self): + features = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]) + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + r = export.TensorServingInputReceiver(features, receiver_tensors) + self.assertTrue(isinstance(r.features, sparse_tensor.SparseTensor)) + self.assertTrue(isinstance(r.receiver_tensors, dict)) + + def test_serving_input_receiver_features_invalid(self): + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.TensorServingInputReceiver( + features=None, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature must be a Tensor"): + export.TensorServingInputReceiver( + features={"1": constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + def test_serving_input_receiver_receiver_tensors_invalid(self): + features = constant_op.constant([0]) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.TensorServingInputReceiver( + features=features, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.TensorServingInputReceiver( + features=features, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.TensorServingInputReceiver( + features=features, + receiver_tensors={"example1": [1]}) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/replicate_model_fn.py b/tensorflow/python/estimator/replicate_model_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..144d89abf3444062927d9261301fe50f4a63b280 --- /dev/null +++ b/tensorflow/python/estimator/replicate_model_fn.py @@ -0,0 +1,824 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to replicate model_fn's over local GPUs. + +This file contains util that allow to replicate `Estimator.model_fn` over +GPUs. Replicated version of a `model_fn` is returned that can subsequently +be used with `Estimator`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +from contextlib import contextmanager +import copy + +import six + +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import device_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util +from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import device as framework_device +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import device_setter as device_setter_lib +from tensorflow.python.training import optimizer as optimizer_lib + + +def _replicate_model_fn(model_fn, + devices=None): + """Replicate `Estimator.model_fn` over GPUs. + + The given `model_fn` specifies a single forward pass of a model. To replicate + such a model over GPUs, each GPU gets its own instance of the forward pass + (a.k.a. a tower). The input features and labels get sharded into the chunks + that correspond to the number of GPUs. Each tower computes a loss based + on its input. For each such loss, gradients are computed. After that, the + available losses are aggregated to form aggregated loss. Available + gradients are summed. Then, they update weights using the specified + optimizer. + + If `devices` are `None`, then all available GPUs are going to be used for + replication. If no GPUs are available, then the model is going to be + placed on the CPU. + + Two modes of local replication over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto the GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + + Here is an example of how one might use their `model_fn` to run over GPUs: + ```python + ... + def model_fn(...): # See `model_fn` in `Estimator`. + loss = ... + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) + optimizer = tf.contrib.estimator._TowerOptimizer(optimizer) + if mode == tf.estimator.ModeKeys.TRAIN: + # See the section below on `EstimatorSpec.train_op`. + return EstimatorSpec(mode=mode, loss=loss, + train_op=optimizer.minimize(loss)) + + # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. + return EstimatorSpec(...) + ... + classifier = tf.estimator.Estimator( + model_fn=tf.contrib.estimator.replicate_model_fn(model_fn)) + ``` + + Please see `DNNClassifierIntegrationTest` for an example with a canned + Estimator. + + On `EstimatorSpec.train_op`: + `model_fn` returns `EstimatorSpec.train_op` for + `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer. + Towers are expected to populate it in the same way. Gradients from all towers + are reduced and applied in the last tower. To achieve that in the case of + multiple towers, `_TowerOptimizer` needs to be used. See `_TowerOptimizer`. + + On sharding input features and labels: + Input features and labels are split for consumption by each tower. They are + split across the dimension 0. Features and labels need to be batch major. + + On reduction algorithms: + Certain algorithms were chosen for aggregating results of computations on + multiple towers: + - Losses from all towers are reduced according to `loss_reduction` argument + to TowerOptimizer.. + - Gradients from all towers are reduced according to the `loss_reduction` + for each trainable variable. + - `eval_metrics_ops` are reduced per metric using `reduce_mean`. + - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are + reduced using concatenation. + - For all other fields of `EstimatorSpec` the values of the first tower + are taken. + + On distribution of variables: + Variables are not duplicated between towers. Instead, they are placed on a + single device as defined above and shared across towers. + + On overhead: + If only one device is specified, then aggregation of loss and gradients + doesn't happen. Replication consists of placing `model_fn` onto the + specified device. + + On current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. They are required + for `tf.contrib.estimator.add_metrics`. + + Args: + model_fn: `model_fn` as defined in `Estimator`. See the section above about + the train_op argument of `EstimatorSpec`. + devices: Optional list of devices to replicate the model across. This + argument can be used to replice only on the subset of available GPUs. + If `None`, then all available GPUs are going to be used for replication. + If no GPUs are available, then the model is going to be placed on the CPU. + + Returns: + A replicated version of the supplied `model_fn`. Returned function that + conforms to the requirements of `Estimator`'s `model_fn` and can be used + instead of the supplied `model_fn`. + """ + return _replicate_model_fn_with_mode( + model_fn, + devices, + # TODO(isaprykin): Query the system configuration to choose modes other + # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often + # appropriate. + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER) + + +class _VariableDistributionMode(object): + """Modes for variable distribution used for forcing a particular one. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + SHARED_LOCAL_PARAMETER_SERVER = 1 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this distribution over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 2 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. + """ + + +def _replicate_model_fn_with_mode( + model_fn, + devices=None, + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER): + """A version of `replicate_model_fn` that allows to specify a `mode`.""" + if not devices: + devices = _get_local_devices('GPU') or _get_local_devices('CPU') + + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper() + consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0' + + ps_devices = [consolidation_device] + if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN: + ps_devices = devices + + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) + + def single_device_model_fn(features, labels, mode, params=None, config=None): + """`model_fn` on a single device without reduction overhead.""" + return _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=[features], + labels=[labels], + params=params, + config=config, + devices=devices, + local_ps_devices=ps_devices)[0] # One device, so one spec is out. + + def replicated_model_fn(features, labels, mode, params=None, config=None): + """Replicated version of `model_fn` to be used instead.""" + feature_shards, label_shards = _split_batch( + features, labels, len(devices), device=consolidation_device) + tower_specs = _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=feature_shards, + labels=label_shards, + params=params, + config=config, + devices=devices, + local_ps_devices=ps_devices) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_op = _minimize_towers(tower_specs) + return _train_spec( + tower_specs, train_op, aggregation_device=consolidation_device) + elif mode == model_fn_lib.ModeKeys.EVAL: + return _eval_spec(tower_specs, aggregation_device=consolidation_device) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return _predict_spec(tower_specs, aggregation_device=consolidation_device) + + if len(devices) == 1: + return single_device_model_fn + else: + return replicated_model_fn + + +class _TowerOptimizer(optimizer_lib.Optimizer): + """Gathers gradients from all towers and reduces them in the last one.""" + + COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states' + + def __init__(self, optimizer_or_optimizer_fn, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE): + """Wrap an existing optimizer for gathering gradients across towers. + + Each invocation of model_fn has to call the same optimizers in the same + order. + + Multiple optimizers that use the same or different losses are supported. + + If _TowerOptimizer is used but `replicate_model_fn` isn't, then no + aggregation will happen. All calls will simply be forwarded to the + underlying optimizer. The behavior is similar if there is only one tower. + + If _TowerOptimizer is used together with SyncReplicasOptimizer that wraps + the user's optimizer, then it's the SyncReplicasOptimizer that needs to be + wrapped with _TowerOptimizer. + + Args: + optimizer_or_optimizer_fn: an instance of optimizer to wrap. That + instance is going to be used for optimizer-specific logic. This can + also be a no-argument function that returns such an optimizer instance. + loss_reduction: controls whether losses are summed or averaged. + """ + self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn + self._loss_reduction = loss_reduction + + @staticmethod + def has_been_used(): + return _TowerOptimizer._graph_state().has_tower_optimizer_been_used + + def get_slot(self, *args, **kwargs): + return self._get_optimizer().get_slot(*args, **kwargs) + + def get_slot_names(self, *args, **kwargs): + return self._get_optimizer().get_slot_names(*args, **kwargs) + + def get_name(self, *args, **kwargs): + return self._get_optimizer().get_name(*args, **kwargs) + + def variables(self, *args, **kwargs): + return self._get_optimizer().variables(*args, **kwargs) + + def compute_gradients(self, loss, *args, **kwargs): + """Compute gradients, but first, if needed, scale the loss.""" + _TowerOptimizer._graph_state().set_loss_reduction(self._loss_reduction) + loss = _scale_loss(loss, + self._loss_reduction, + self._graph_state().number_of_towers) + return self._get_optimizer().compute_gradients(loss, *args, **kwargs) + + def apply_gradients(self, grads_and_vars, global_step=None, **kwargs): + """Collect gradients updates to apply them with the last tower.""" + if self._graph_state().number_of_towers == 1: + # Avoid the overhead of reduction if there's only one tower. + # + # There assumed to be only one tower if aggregation-related methods were + # not called by `_get_loss_towers`, for example if the model_fn uses + # TowerEstimator, but `replicate_model_fn` isn't used. + return self._get_optimizer().apply_gradients(grads_and_vars, global_step, + **kwargs) + + self._graph_state().collect_gradients(grads_and_vars) + + if not self._graph_state().is_the_last_tower: + with ops_lib.control_dependencies(_extract_tensors(grads_and_vars)): + return self._construct_no_op_train_op() + else: + # Gradients need to be gathered and applied in the scope of the first + # tower, so that the tensors are accessible via names without prefixes. + var_scope, name_scope = self._graph_state().scopes_of_the_first_tower + with variable_scope.variable_scope(var_scope): + with ops_lib.name_scope(name_scope): + return self._apply_gathered_gradients(global_step, **kwargs) + + def _apply_gathered_gradients(self, global_step, **kwargs): + graph_state = self._graph_state() + optimizer = self._get_optimizer() + + grad_lists = {} + for grad, var in graph_state.get_latest_gradients_from_all_towers(): + if grad is not None: + grad_lists.setdefault(var, []).append(grad) + + aggregated_grads = [] + with ops_lib.name_scope('gradient_aggregating'): + for var, grads in six.iteritems(grad_lists): + grad = _compute_sum_on_device(grads, var.device) + aggregated_grads.append((grad, var)) + return optimizer.apply_gradients( + aggregated_grads, global_step=global_step, **kwargs) + + def _get_optimizer(self): + if callable(self._optimizer_or_optimizer_fn): + # If optimizer is given as a function then we need to wait till we are + # under the right graph context before constructing it. That's why the + # optimizer is constructed in _get_optimizer() rather than __init__(). + self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn() + self._graph_state().has_tower_optimizer_been_used = True + return self._optimizer_or_optimizer_fn + + def _construct_no_op_train_op(self): + return control_flow_ops.no_op(name='train_op_placeholder') + + @staticmethod + def _graph_state(): + graph_states = ops_lib.get_default_graph().get_collection_ref( + _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES) + if not graph_states: + graph_states.append(_TowerOptimizer._PerGraphState()) + return graph_states[-1] + + @staticmethod + def _did_towers_have_same_optimizer_calls(): + graph_state = _TowerOptimizer._graph_state() + return graph_state.did_towers_have_same_optimizer_calls() + + @staticmethod + def _clear_graph_state(): + # Clearing the Graph collection will prevent _PerGraphState from being + # serialized. + ops_lib.get_default_graph().clear_collection( + _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES) + + class _PerGraphState(object): + """Gradient reduction related state of a Tensorflow graph.""" + + def __init__(self): + self._collected_grads_and_vars = defaultdict(list) + self._current_tower_index = 0 + self._number_of_towers = 1 + self._loss_reduction = None + # Scopes of the first tower that don't have a prefix: + self._variable_scope = None + self._name_scope = None + # If needed, alert that _TowerOptimizer needs to be used with model_fn. + self._has_tower_optimizer_been_used = False + + def collect_gradients(self, grads_and_vars): + self._collected_grads_and_vars[self._current_tower_index].append( + grads_and_vars) + + def get_latest_gradients_from_all_towers(self): + """Get gradients across towers for the last called optimizer.""" + grads_and_vars = [] + index_of_last_gradients = len( + self._collected_grads_and_vars[self._current_tower_index]) - 1 + for tower_id in range(self._current_tower_index + 1): + grads_and_vars.extend( + self._collected_grads_and_vars[tower_id][index_of_last_gradients]) + return grads_and_vars + + def set_number_of_towers(self, number_of_towers): + self._number_of_towers = number_of_towers + + def set_loss_reduction(self, loss_reduction): + self._loss_reduction = loss_reduction + + @contextmanager + def tower(self, tower_id, var_scope, name_scope): + if tower_id == 0: + self._variable_scope = var_scope + self._name_scope = name_scope + self._current_tower_index = tower_id + yield + + @property + def scopes_of_the_first_tower(self): + return self._variable_scope, self._name_scope + + @property + def is_the_last_tower(self): + return self._current_tower_index == (self._number_of_towers - 1) + + @property + def number_of_towers(self): + return self._number_of_towers + + @property + def loss_reduction(self): + return self._loss_reduction + + @property + def has_tower_optimizer_been_used(self): + return self._has_tower_optimizer_been_used + + @has_tower_optimizer_been_used.setter + def has_tower_optimizer_been_used(self, value): + self._has_tower_optimizer_been_used = value + + def did_towers_have_same_optimizer_calls(self): + total_number_of_grads = sum([ + len(grads) + for _, grads in six.iteritems(self._collected_grads_and_vars) + ]) + return total_number_of_grads % self._number_of_towers == 0 + + +def _get_local_devices(device_type): + local_device_protos = device_lib.list_local_devices() + return [ + device.name + for device in local_device_protos + if device.device_type == device_type + ] + + +def _split_batch(features, labels, number_of_shards, device): + """Split input features and labes into batches.""" + + def ensure_divisible_by_shards(sequence): + batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0] + if batch_size % number_of_shards != 0: + raise ValueError( + 'Batch size {} needs to be divisible by the number of GPUs, which ' + 'is {}.'.format(batch_size, number_of_shards)) + + def split_dictionary(dictionary): + """Split a dictionary into shards.""" + shards = [{} for _ in range(number_of_shards)] + for name, tensor in six.iteritems(dictionary): + if isinstance(tensor, sparse_tensor.SparseTensor): + for i, shard in enumerate( + sparse_ops.sparse_split( + sp_input=tensor, num_split=number_of_shards, axis=0)): + shards[i][name] = shard + else: + ensure_divisible_by_shards(tensor) + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard + return shards + + with ops_lib.name_scope('split_inputs'): + with ops_lib.device(device): + if isinstance(features, dict): + feature_shards = split_dictionary(features) + else: + ensure_divisible_by_shards(features) + feature_shards = array_ops.split(features, number_of_shards) + + if labels is None: + label_shards = None + elif isinstance(labels, dict): + label_shards = split_dictionary(labels) + else: + ensure_divisible_by_shards(labels) + label_shards = array_ops.split(labels, number_of_shards) + return feature_shards, label_shards + + +_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}' + + +def _get_loss_towers(model_fn, + mode, + features, + labels, + params, + config, + devices, + local_ps_devices, + name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): + """Replicate the loss computation across devices.""" + tower_specs = [] + + model_fn_args = util.fn_args(model_fn) + optional_params = {} + if 'params' in model_fn_args: + optional_params['params'] = copy.deepcopy(params) + if 'config' in model_fn_args: + optional_params['config'] = copy.deepcopy(config) + + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + _TowerOptimizer._graph_state().set_number_of_towers(len(devices)) + + for i, device in enumerate(devices): + is_the_first_tower = (i == 0) + + device_setter = _local_device_setter( + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) + + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. + name_scope = name_scope_pattern + if is_the_first_tower: + name_scope = '' + + with variable_scope.variable_scope( + '', reuse=not is_the_first_tower) as var_scope: + with ops_lib.name_scope(name_scope.format(i)) as name_scope: + with _TowerOptimizer._graph_state().tower( + tower_id=i, var_scope=var_scope, name_scope=name_scope): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_spec = model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params) + + if (tower_spec.train_op is not None and len(devices) > 1 and + not _TowerOptimizer.has_been_used()): + raise ValueError('Please wrap optimizers with _TowerOptimizer' + ' in order to use replicate_model_fn with' + ' multiple `devices`.') + + # Scaling the loss here doesn't actually affect gradients. Another + # instance of scaling happens inside the _TowerOptimizer. + tower_spec = _scale_tower_loss( + tower_spec, + _TowerOptimizer._graph_state().loss_reduction, + number_of_towers=len(devices)) + tower_specs.append(tower_spec) + + if not _TowerOptimizer._did_towers_have_same_optimizer_calls(): + raise ValueError('Each invocation of model_fn was supposed to make the same' + ' optimizer calls.') + _TowerOptimizer._clear_graph_state() + # pylint: enable=protected-access + return tower_specs + + +def _local_device_setter(worker_device, ps_devices, ps_strategy): + """A device setter that puts distributes Var/Ops to PS/workers.""" + ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] + + def local_device_chooser(op): + current_device = framework_device.DeviceSpec.from_string(op.device or '') + + node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def + if node_def.op in ps_ops: + ps_device_spec = framework_device.DeviceSpec.from_string( + '{}'.format(ps_devices[ps_strategy(op)])) + + ps_device_spec.merge_from(current_device) + return ps_device_spec.to_string() + else: + worker_device_spec = framework_device.DeviceSpec.from_string( + worker_device or '') + worker_device_spec.merge_from(current_device) + return worker_device_spec.to_string() + + return local_device_chooser + + +def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers): + """Produce an EstimatorSpec with approproriately scaled loss.""" + if tower_spec.loss is None: + return tower_spec + + estimator_spec = _asdict(tower_spec) + estimator_spec['loss'] = _scale_loss( + tower_spec.loss, + loss_reduction, + number_of_towers, + reduced_loss_name='averaged_loss') + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _scale_loss(loss, loss_reduction, number_of_towers, reduced_loss_name=None): + """If needed, scale down the loss for averaging loss by summing.""" + if loss is None: + return None + if number_of_towers == 1: + return loss + + if loss_reduction == losses.Reduction.NONE: + raise ValueError('Tower losses need to be reduced in some way, yet {} ' + 'reduction is specified.'.format(loss_reduction)) + + if loss_reduction != losses.Reduction.SUM: + return math_ops.div(loss, 1.0 * number_of_towers, name=reduced_loss_name) + else: + return loss + + +def _minimize_towers(tower_specs): + """`train_op` of the last tower applies aggregated gradients.""" + return tower_specs[-1].train_op + + +def _compute_sum_on_device(values, device, name=None): + with ops_lib.device(device): + if isinstance(values[0], ops_lib.IndexedSlices): + if name: + raise ValueError('The name {} is not expected to be given to ' + 'IndexedSlices {}'.format(name, values)) + + values_concat = array_ops.concat([v.values for v in values], axis=0) + indices_concat = array_ops.concat([v.indices for v in values], axis=0) + return ops_lib.IndexedSlices(values_concat, indices_concat, + values[0].dense_shape) + else: + return math_ops.add_n(values, name=name) + + +def _train_spec(tower_specs, + train_op, + aggregation_device, + aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" + # Spec of the last tower is used as the template for the final spec, because + # some `EstimatorSpec.training_hooks` rely on calls made in model_fn. For + # example, `SyncReplicasOptimizerHook` validates the + # `SyncReplicasOptimizer.apply_gradients` call. `TowerEstimator` makes that + # call only in the last tower. + estimator_spec = _asdict(tower_specs[-1]) + estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN + estimator_spec['train_op'] = train_op + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" + estimator_spec = _asdict(tower_specs[0]) + estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + + update_ops = [] + for tower_spec in tower_specs: + for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops): + update_ops.append(update_op) + + with ops_lib.control_dependencies(update_ops): + reduced_update_op = _reduce_metric_variables(len(tower_specs)) + + eval_metric_ops = {} + for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops): + eval_metric_ops[name] = (metric_tensor, reduced_update_op) + estimator_spec['eval_metric_ops'] = eval_metric_ops + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _reduce_metric_variables(number_of_towers): + """Aggregate local variables used in metrics into the first tower.""" + if number_of_towers == 1: + return control_flow_ops.no_op(name='no_eval_metric_reduction') + + metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES) + variables_per_tower = len(metric_variables) // number_of_towers + + if len(metric_variables) % number_of_towers != 0: + raise ValueError( + 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.' + ' Expected {} local variables, but got {} instead.'.format( + variables_per_tower * number_of_towers, len(metric_variables))) + + # `metric_variables` has the size of `variables_per_tower` x + # number_of_towers. Each tower is produced by calling the same model_fn. + # First `variables_per_tower` correspond to the first tower. Each such + # variable has an replica at the `(variables_per_tower * i)` position, where + # `i` is `[1.. number_of_towers]`. We are going to add values from replicas + # to each variable of the first tower. We then zero out replica values, so + # that `_reduce_metric_variables` operation is idempotent. If a metric + # is then computed based on local variables from the first tower, then the + # resulting metric is an estimate for all `number_of_towers` towers. + ops = [] + for i in range(0, variables_per_tower): + next_replica_id = i + variables_per_tower + replicas = [ + metric_variables[replica_id] + for replica_id in range(next_replica_id, len(metric_variables), + variables_per_tower) + ] # `replicas` doesn't contain the first-tower variable. + + reduce_op = state_ops.assign_add(metric_variables[i], + math_ops.add_n(replicas)) + + with ops_lib.control_dependencies([reduce_op]): + for replica in replicas: + zeros_for_replica = array_ops.zeros( + array_ops.shape(replica), dtype=replica.dtype) + zero_out_replica_op = state_ops.assign(replica, zeros_for_replica) + ops.append(zero_out_replica_op) + + return control_flow_ops.group(*ops) + + +def _predict_spec(tower_specs, aggregation_device): + """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" + estimator_spec = _asdict(tower_specs[0]) + estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT + + with ops_lib.device(aggregation_device): + estimator_spec['predictions'] = _concat_tensor_dicts( + *[tower_spec.predictions for tower_spec in tower_specs]) + + export_outputs_dict = _dict_concat( + *[tower_spec.export_outputs for tower_spec in tower_specs]) + + export_outputs = {} + for name, export_output_list in six.iteritems(export_outputs_dict): + if isinstance(export_output_list[0], export_output_lib.PredictOutput): + export_outputs[name] = export_output_lib.PredictOutput( + outputs=_concat_tensor_dicts(*[ + export_output.outputs for export_output in export_output_list + ])) + elif isinstance(export_output_list[0], + export_output_lib.RegressionOutput): + export_outputs[name] = export_output_lib.RegressionOutput( + value=array_ops.concat( + [export_output.value for export_output in export_output_list], + axis=0)) + elif isinstance(export_output_list[0], + export_output_lib.ClassificationOutput): + scores = None + if export_output_list[0].scores is not None: + scores = array_ops.concat( + [export_output.scores for export_output in export_output_list], + axis=0) + + classes = None + if export_output_list[0].classes is not None: + classes = array_ops.stack( + [export_output.classes for export_output in export_output_list], + axis=0) + + export_outputs[name] = export_output_lib.ClassificationOutput( + scores=scores, classes=classes) + + estimator_spec['export_outputs'] = export_outputs + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _concat_tensor_dicts(*tensor_dicts): + return { + name: array_ops.concat(tensors, axis=0, name=name) + for name, tensors in six.iteritems(_dict_concat(*tensor_dicts)) + } + + +def _extract_tensors(tensors_and_vars): + tensors = [] + for tensor_and_var in tensors_and_vars: + tensor, _ = tensor_and_var + if isinstance(tensor, ops_lib.IndexedSlices): + tensors.append(tensor.values) + elif tensor is not None: + tensors.append(tensor) + return tensors + + +def _dict_concat(*dicts): + list_dict = {} + for d in dicts: + if d is None: + continue + + for k, v in six.iteritems(d): + list_dict.setdefault(k, []).append(v) + return list_dict + + +def _asdict(namedtuple): + """Returns a namedtuple as a dictionary. + + This is required because `_asdict()` in Python 3.x.x is broken in classes + that inherit from `collections.namedtuple`. See + https://bugs.python.org/issue24931 for more details. + + Args: + namedtuple: An object that inherits from `collections.namedtuple`. + + Returns: + A dictionary version of the tuple. + """ + return {k: getattr(namedtuple, k) for k in namedtuple._fields} diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1f9c02b92d7b1ce929494f4b6fbf636762a7fd --- /dev/null +++ b/tensorflow/python/estimator/replicate_model_fn_test.py @@ -0,0 +1,1739 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities that replicate `Estimator.model_fn` over GPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import shutil +import tempfile +import numpy as np +import six + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import replicate_model_fn +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import losses +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import adam +from tensorflow.python.training import device_setter +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import training + + +# TODO(isaprykin): Parametrize all the tests on +# replicate_model_fn._VariableDistributionMode when it's supported. +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def test_complete_flow_with_public_version(self): + return self._complete_flow_with_mode(mode=None) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): + n_classes = 3 + input_dimension = 2 + batch_size = 12 + + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) + x_data = data.reshape(batch_size, input_dimension) + categorical_data = np.random.random_integers( + 0, len(x_data), size=len(x_data)) + y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data, + 'categories': categorical_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data, + 'categories': categorical_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data, + 'categories': categorical_data}, + batch_size=batch_size, + shuffle=False) + + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)), + feature_column.embedding_column( + feature_column.categorical_column_with_vocabulary_list( + 'categories', + vocabulary_list=np.linspace( + 0., len(x_data), len(x_data), dtype=np.int64)), 1) + ] + + def optimizer_fn(): + return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + + estimator = dnn.DNNClassifier( + hidden_units=(2, 2), + # Adagrad is configured with `get_optimizer_instance`, so the function + # form of `TowerOptimizer.__init__` is used. + optimizer=replicate_model_fn._TowerOptimizer( + optimizer_fn, loss_reduction=losses.Reduction.SUM), + feature_columns=feature_columns, + n_classes=n_classes, + model_dir=self._model_dir) + + if not mode: # Use the public `replicate_model_fn`. + model_fn = replicate_model_fn._replicate_model_fn( + estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2']) + else: + model_fn = replicate_model_fn._replicate_model_fn_with_mode( + estimator.model_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + mode=mode) + + estimator = estimator_lib.Estimator( + model_fn=model_fn, + model_dir=estimator.model_dir, + config=estimator.config, + params=estimator.params) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + # Nothing should be left in the graph so that it doesn't get serialized. + self.assertFalse(ops_lib.get_default_graph().get_collection_ref( + replicate_model_fn._TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)) + + def _as_label(self, data_in_float): + return np.rint(data_in_float).astype(np.int64) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +class ReplicateModelTest(test_util.TensorFlowTestCase): + + def create_model_fn_with_loss_reduction(self, loss_reduction): + + def model_fn(mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = replicate_model_fn._TowerOptimizer( + gradient_descent.GradientDescentOptimizer(params['learning_rate']), + loss_reduction=loss_reduction) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + return model_fn + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # derivative of loss = (1*c - 1) + (2*c - 2) is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_train_with_mean_reduction(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + # Add another trainable variable that doesn't produce a gradient to + # verify that None gradients are supported. + _ = variable_scope.get_variable( + 'another_variable', + initializer=constant_op.constant(1, dtype=dtypes.float64), + dtype=dtypes.float64) + + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN), + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0 + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 is 1.5. + # It's the same computation as without mean reduction, but the + # loss from every tower is scaled by 1/. + # new value of c = 10 - learning rate * 1.5 = 8.5 + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(8.5, session.run(c)) + + def test_train_two_steps_collected_gradients_are_reset_between_steps(self): + with ops_lib.Graph().as_default(): + features = array_ops.placeholder(dtypes.float64) + labels = array_ops.placeholder(dtypes.float64) + + feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]]) + label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]]) + + # loss = feature * c - label + expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0), + (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5)) + # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5 + # for the second. + expected_c = 10.0 - 3.0, 7.0 - 4.0 + + with self.test_session() as session, variable_scope.variable_scope( + '', reuse=variable_scope.AUTO_REUSE): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + for feature_input, label_input, loss, weight in zip( + feature_inputs, label_inputs, expected_losses, expected_c): + feeds = {features: feature_input, labels: label_input} + + self.assertEqual(loss, session.run(estimator_spec.loss, feeds)) + + session.run(estimator_spec.train_op, feeds) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(weight, session.run(c, feeds)) + + def test_eval(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # loss[i] = features[i] * 10 - labels[i]. + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_eval_with_mean_reduction(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN), + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # loss[i] = features[i] * 10 - labels[i]. + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0 + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_eval_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self): + features = np.array([[1.0], [2.0], [3.0]]) + labels = np.array([[1.0], [2.0], [3.0]]) + + with self.assertRaisesRegexp( + ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + def test_unsupported_loss_reduction(self): + features = np.array([[1.0], [2.0], [3.0]]) + labels = np.array([[1.0], [2.0], [3.0]]) + + with self.assertRaisesRegexp(ValueError, + '.+none.+reduction.+is.+specified.+'): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.NONE), + devices=['/gpu:0', '/gpu:1', '/gpu:2']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + def test_places_on_gpu_with_upper_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/GPU:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + + def test_places_on_gpu_with_lower_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + devices=['/gpu:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + + +class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( + test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer( + params['learning_rate']) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + +class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + features = features['features'] + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = replicate_model_fn._TowerOptimizer( + gradient_descent.GradientDescentOptimizer(params['learning_rate'])) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + train_input_fn = numpy_io.numpy_input_fn( + x={'features': features}, y=labels, batch_size=2, shuffle=False) + + with self.test_session(): + estimator = estimator_lib.Estimator( + model_fn=self.model_fn, + model_dir=tempfile.mkdtemp(), + params=self.params) + estimator.train(train_input_fn, steps=1) + + self.assertEqual(7.0, estimator.get_variable_value('c')) + + +class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + features = features['features'] + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer( + params['learning_rate']) + optimizer = training.SyncReplicasOptimizer( + optimizer, replicas_to_aggregate=1) + sync_hook = optimizer.make_session_run_hook(True) + optimizer = replicate_model_fn._TowerOptimizer( + optimizer, loss_reduction=losses.Reduction.SUM) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + training_hooks=[sync_hook], + predictions={'probabilities': predictions}, + train_op=optimizer.minimize( + loss, global_step=training.get_global_step())) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_multiple_towers(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + train_input_fn = numpy_io.numpy_input_fn( + x={'features': features}, y=labels, batch_size=2, shuffle=False) + + model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, + devices=['/gpu:0', '/gpu:1']) + + estimator = estimator_lib.Estimator( + model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params) + estimator.train(train_input_fn, steps=1) + + self.assertEqual(7.0, estimator.get_variable_value('c')) + + +class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + side_effects = variable_scope.get_variable( + 'side_effects', + initializer=constant_op.constant(0, dtype=dtypes.float64), + dtype=dtypes.float64, + use_resource=True, + trainable=False) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + first_optimizer = replicate_model_fn._TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0), + loss_reduction=losses.Reduction.SUM) + second_optimizer = replicate_model_fn._TowerOptimizer( + adam.AdamOptimizer(1.0), loss_reduction=losses.Reduction.SUM) + + with ops_lib.control_dependencies([side_effects.assign_add(1.0)]): + first_grads_and_vars = first_optimizer.compute_gradients(loss) + + train_op = control_flow_ops.group( + [first_optimizer.apply_gradients(first_grads_and_vars), + second_optimizer.minimize(loss)]) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, {}) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + # Adam subtracts another ~1. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertNear(6.0, session.run(c), 0.000001) + + side_effects = variable_scope.get_variable( + 'side_effects', dtype=dtypes.float64) + self.assertNear(2.0, session.run(side_effects), 0.000001) + + +class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase): + + def setUp(self): + self._should_skip_optimizer = False + self._towers_left_before_skipping_optimizer = -1 + + def incorrectly_skip_optimizer_for_tower(self, tower_number): + self._should_skip_optimizer = True + self._towers_left_before_skipping_optimizer = tower_number + + def should_skip_optimizer(self): + if not self._should_skip_optimizer: + return False + if self._towers_left_before_skipping_optimizer == 0: + return True + else: + self._towers_left_before_skipping_optimizer -= 1 + return False + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + d = variable_scope.get_variable( + 'd', + initializer=constant_op.constant(2, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + another_predictions = math_ops.multiply(features, d) + another_loss = losses.absolute_difference( + labels=labels, + predictions=another_predictions, + reduction=losses.Reduction.SUM) + another_loss = math_ops.reduce_sum(another_loss) + + total_loss = math_ops.add(loss, another_loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + train_ops = [] + + optimizer = replicate_model_fn._TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0), + loss_reduction=losses.Reduction.SUM) + train_ops.append(optimizer.minimize(loss, var_list=[c])) + if not self.should_skip_optimizer(): + another_optimizer = replicate_model_fn._TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0), + loss_reduction=losses.Reduction.SUM) + train_ops.append(another_optimizer.minimize(another_loss, var_list=[d])) + + train_op = control_flow_ops.group(train_ops) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=total_loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with ops_lib.Graph().as_default(), self.test_session() as session: + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, {}) + session.run(variables.global_variables_initializer()) + + # For each tower, loss = (feature * c - label) + (feature * d - label). + total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + ( + 2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + session.run(estimator_spec.train_op) + + # loss' of c or loss' of d is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + # new value of d = 2 - learning rate * 3 = -1.0. + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertNear(7.0, session.run(c), 0.000001) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertNear(-1.0, session.run(d), 0.000001) + + def test_different_optimizer_calls_within_towers(self): + self.incorrectly_skip_optimizer_for_tower(1) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session(), ops_lib.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, + {}) + + +class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer(1.0) + train_op = optimizer.minimize(loss) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session(): + with self.assertRaisesRegexp(ValueError, + 'Please.+wrap.+with.+TowerOptimizer'): + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, + {}) + + +class GetLossTowersTest(test_util.TensorFlowTestCase): + + def create_model_fn_with_loss_reduction(self, loss_reduction): + + def model_fn(mode, features, labels, params): + del params + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + + optimizer = replicate_model_fn._TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0), + loss_reduction) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=math_ops.reduce_sum(loss), + train_op=optimizer.minimize(loss)) + + return model_fn + + def test_gradients_are_computed(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.create_model_fn_with_loss_reduction(losses.Reduction.SUM), + mode=None, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('Sum:0', tower_specs[0].loss.name) + self.assertEqual(1.0, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(2.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + def test_gradients_are_computed_with_mean_reduction(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN), + mode=model_fn_lib.ModeKeys.EVAL, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('averaged_loss:0', tower_specs[0].loss.name) + self.assertEqual(0.5, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(1.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + + +class SplitBatchTest(test_util.TensorFlowTestCase): + + def evaluate_shards(self, first_list, second_list): + evaluate_items = lambda x: x.eval() + return list(map(evaluate_items, first_list)), list( + map(evaluate_items, second_list)) + + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def test_simple_half_split(self): + with self.test_session(): + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) + + def test_to_each_their_own(self): + with self.test_session(): + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 4, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards) + self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) + + def test_one_batch(self): + with self.test_session(): + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) + + def test_half_split_in_dictionary(self): + with self.test_session(): + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) + self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + + def test_sparse_tensor_can_be_split_unevenly(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2], [2, 2]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]), + feature_shards[0]['x'].eval()) + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 2]], values=[3.], dense_shape=[1, 4]), + feature_shards[1]['x'].eval()) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + + def test_sparse_tensor_can_be_split_unevenly_repeated_row(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1., 2., 3.], + dense_shape=[2, 4]), feature_shards[0]['x'].eval()) + + second_batch = feature_shards[1]['x'].eval() + self.assertFalse(len(second_batch.indices)) + self.assertFalse(len(second_batch.values)) + self.assertAllEqual([1, 4], second_batch.dense_shape) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + + def test_one_batch_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0, 2.0, 3.0], + feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0, 6.0, 7.0], + feature_shards[0]['second'].eval()) + self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval()) + + def test_feature_and_label_dictionaries(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]} + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0], label_shards[0]['first'].eval()) + self.assertAllEqual([12.0], label_shards[0]['second'].eval()) + self.assertAllEqual([11], label_shards[1]['first'].eval()) + self.assertAllEqual([13.0], label_shards[1]['second'].eval()) + + +class TrainSpecTest(test_util.TensorFlowTestCase): + + expected_predictions = {} + + def create_estimator_spec(self, loss): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=loss, + train_op=loss, # Not used; currently required. + predictions=self.expected_predictions) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def test_example(self): + with self.test_session() as session: + tower_losses = list(map(self.create_constant_loss, [2, 4, 6])) + tower_specs = list(map(self.create_estimator_spec, tower_losses)) + + expected_train_op = tower_losses[1] + + estimator_spec = replicate_model_fn._train_spec( + tower_specs, expected_train_op, aggregation_device='/gpu:0') + + self.assertEqual(expected_train_op, estimator_spec.train_op) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + self.assertEqual(self.expected_predictions, estimator_spec.predictions) + + +class EvalSpecTest(test_util.TensorFlowTestCase): + + def create_estimator_spec(self, loss, metrics): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def create_eval_metrics(self, noise): + predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise]) + labels = np.array([0.1, 0.2, 0.3, 0.6]) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + return metrics + + def test_example(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [2, 4, 6]) + tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy, auc = session.run([accuracy, auc]) + + self.assertNear((12 - 2) / 12, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + + def test_handles_single_tower(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [5]) + tower_metrics = map(self.create_eval_metrics, [0.2]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((4 - 1) / 4, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(5, session.run(estimator_spec.loss)) + + +class PredictSpecTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([features[0], features[0]]), c) + + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions={ + 'probabilities': predictions + }) + + def test_example(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.1], [0.2]], + labels=[[], []], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + ) + session.run(variables.global_variables_initializer()) + + estimator_spec = replicate_model_fn._predict_spec( + tower_specs, aggregation_device='/gpu:0') + + self.assertEqual('/device:GPU:0', + estimator_spec.predictions['probabilities'].device) + self.assertAllClose({ + 'probabilities': np.array([0.35, 0.35, 0.45, 0.45]) + }, session.run(estimator_spec.predictions)) + + +class ReduceMetricVariablesTest(test_util.TensorFlowTestCase): + + def create_metric_variable(self, initial_value, name): + return variable_scope.variable( + initial_value, + trainable=False, + collections=[ops_lib.GraphKeys.METRIC_VARIABLES], + validate_shape=True, + name=name) + + def create_tower_metrics(self, tower_id): + with variable_scope.variable_scope('', reuse=(tower_id != 0)): + self.create_metric_variable(1.3 * (tower_id + 1), 'total') + self.create_metric_variable(2.3 * (tower_id + 1), 'count') + self.create_metric_variable( + np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total') + + def test_example(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7] + # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4] + # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1] + # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2] + # Towers are accumulated in the first tower. + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_reduce_is_idempotent(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + for _ in range(20): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_handles_single_tower(self): + with self.test_session() as session: + self.create_tower_metrics(0) + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=1)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(1.3, local_metrics[0], 0.01) + self.assertNear(2.3, local_metrics[1], 0.01) + self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01) + + def test_doesnt_accept_uneven_number_of_variables(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + self.create_metric_variable(-1.0, 'oddball') + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + with self.assertRaisesRegexp( + ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + +class MergeExportOutputsTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = {'probabilities': math_ops.multiply(features, c)} + loss = losses.absolute_difference( + labels=labels, + predictions=predictions['probabilities'], + reduction=losses.Reduction.SUM) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']), + 'auc': metrics_lib.auc(labels, predictions['probabilities']) + } + tensor_string_repr = str(features) + classes = constant_op.constant( + re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1), + dtype=dtypes.string) + + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(predictions), + 'classification_output': + export_output.ClassificationOutput(predictions['probabilities'], + classes), + 'classification_scores': + export_output.ClassificationOutput( + scores=predictions['probabilities']), + 'classification_classes': + export_output.ClassificationOutput(classes=classes), + 'regression_output': + export_output.RegressionOutput(predictions['probabilities']), + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=math_ops.reduce_sum(loss), + eval_metric_ops=metrics, + predictions=predictions, + export_outputs=export_outputs) + + def replicate_estimator_spec(self, session): + features = np.array([0.01, 0.002]) + labels = np.array([0.01, 0.02]) + + replicated_model_fn = replicate_model_fn._replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.PREDICT, {}) + session.run(variables.global_variables_initializer()) + return estimator_spec + + def test_merge_predict_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + { + 'probabilities': np.array([0.1, 0.02]) + }, + session.run(estimator_spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs)) + + def test_merge_classification_output_scores_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_output'].scores)) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_output'].classes)) + + def test_merge_classification_output_scores(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_scores'].scores)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_scores'].classes) + + def test_merge_classification_output_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_classes'].classes)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_classes'].scores) + + def test_merge_regression_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run(estimator_spec.export_outputs['regression_output'].value)) + + +class GetLocalDevicesTest(test_util.TensorFlowTestCase): + + def test_there_is_at_least_a_cpu(self): + self.assertTrue(replicate_model_fn._get_local_devices('CPU')) + + def test_there_is_no_xpu(self): + self.assertFalse( + replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. + + def test_whether_there_is_a_gpu(self): + if test.is_gpu_available(): + self.assertTrue(len(replicate_model_fn._get_local_devices('GPU'))) + + +class LocalDeviceSetterTest(test_util.TensorFlowTestCase): + + def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) + + c_op = array_ops.concat(c, axis=0) + self.assertEqual('/device:GPU:2', c_op.device) + + +class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): + + def test_vectors(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertEqual(10.0, session.run(total)) + + def test_tensors(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertAllEqual([4.0, 6.0], session.run(total)) + + def test_indexedslices(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 6.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_higher_dimensions(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1], + dense_shape=constant_op.constant([2, 4])) + b = ops_lib.IndexedSlices( + constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_some_dont_overlap(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 3], + dense_shape=constant_op.constant([4])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 4.0, 0.0, 2.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_no_name_for_indexslices(self): + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'): + _ = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0', name='cant_name_indexslices') + + +class ConcatTensorDictsTest(test_util.TensorFlowTestCase): + + def test_example(self): + tensor_dicts = [ + { + 'a': np.array([1.0, 2.0]), + 'b': np.array([11.0]), + 'c': np.array([21.0]), + }, + { + 'a': np.array([3.0]), + 'b': np.array([12.0, 13.0]), + }, + { + 'b': np.array([14.0]), + }, + ] + + with self.test_session() as session: + self.assertAllClose({ + 'a': np.array([1.0, 2.0, 3.0]), + 'b': np.array([11.0, 12.0, 13.0, 14.0]), + 'c': np.array([21.0]), + }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 3e021242c4cc914990c6b38736b8f725213b5b7e..141eaeff649414412a4277f8945dcb4445985170 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -43,7 +43,8 @@ _DEFAULT_REPLACEABLE_LIST = [ 'session_config', 'keep_checkpoint_max', 'keep_checkpoint_every_n_hours', - 'log_step_count_steps' + 'log_step_count_steps', + 'distribute' ] _SAVE_CKPT_ERR = ( @@ -300,7 +301,8 @@ class RunConfig(object): session_config=None, keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000, - log_step_count_steps=100): + log_step_count_steps=100, + distribute=None): """Constructs a RunConfig. All distributed training related properties `cluster_spec`, `is_chief`, @@ -345,7 +347,7 @@ class RunConfig(object): os.environ['TF_CONFIG'] = json.dumps( {'cluster': cluster, 'task': {'type': 'worker', 'index': 1}}) - config = ClusterConfig() + config = RunConfig() assert config.master == 'host4:2222' assert config.task_id == 1 assert config.num_ps_replicas == 2 @@ -363,7 +365,7 @@ class RunConfig(object): os.environ['TF_CONFIG'] = json.dumps( {'cluster': cluster, 'task': {'type': 'chief', 'index': 0}}) - config = ClusterConfig() + config = RunConfig() assert config.master == 'host0:2222' assert config.task_id == 0 assert config.num_ps_replicas == 2 @@ -381,7 +383,7 @@ class RunConfig(object): os.environ['TF_CONFIG'] = json.dumps( {'cluster': cluster, 'task': {'type': 'evaluator', 'index': 0}}) - config = ClusterConfig() + config = RunConfig() assert config.master == '' assert config.evaluator_master == '' assert config.task_id == 0 @@ -423,8 +425,11 @@ class RunConfig(object): to be saved. The default value of 10,000 hours effectively disables the feature. log_step_count_steps: The frequency, in number of global steps, that the - global step/sec will be logged during training. - + global step/sec and the loss will be logged during training. + distribute: an optional instance of + `tf.contrib.distribute.DistributionStrategy`. If specified, + then Estimator will distribute the user's model according to the policy + specified by that strategy. Raises: ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` @@ -460,7 +465,8 @@ class RunConfig(object): session_config=session_config, keep_checkpoint_max=keep_checkpoint_max, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, - log_step_count_steps=log_step_count_steps) + log_step_count_steps=log_step_count_steps, + distribute=distribute) self._init_distributed_setting_from_environment_var(tf_config) @@ -671,6 +677,12 @@ class RunConfig(object): """Returns the platform defined (in TF_CONFIG) service dict.""" return self._service + @property + def distribute(self): + """Returns the optional `tf.contrib.distribute.DistributionStrategy` object. + """ + return self._distribute + def replace(self, **kwargs): """Returns a new instance of `RunConfig` replacing specified properties. diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 2cc3331a15867e9a984847391857bf84baee7424..e38b765da52a7b6957a4fb8a02087c5d1fd5a781 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -128,9 +128,16 @@ class TrainSpec( """Creates a validated `TrainSpec` instance. Args: - input_fn: Training input function returning a tuple of: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. + input_fn: A function that provides input data for training as minibatches. + See @{$get_started/premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where features is a `Tensor` or a + dictionary of string feature name to `Tensor` and labels is a + `Tensor` or a dictionary of string label name to `Tensor`. + max_steps: Int. Positive number of total steps for which to train model. If `None`, train forever. The training `input_fn` is not expected to generate `OutOfRangeError` or `StopIteration` exceptions. See the @@ -185,9 +192,16 @@ class EvalSpec( """Creates a validated `EvalSpec` instance. Args: - input_fn: Evaluation input function returning a tuple of: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. + input_fn: A function that constructs the input data for evaluation. + See @{$get_started/premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where features is a `Tensor` or a + dictionary of string feature name to `Tensor` and labels is a + `Tensor` or a dictionary of string label name to `Tensor`. + steps: Int. Positive number of steps for which to evaluate model. If `None`, evaluates until `input_fn` raises an end-of-input exception. See `Estimator.evaluate` for details. diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index 3ce8eea84b6bf601ce89dfaa7d8e3a5d193468b3..bb4bdd3fdfb2e19dbc1c581d7771f2e1ac4442ba 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -20,7 +20,12 @@ from __future__ import division from __future__ import print_function import functools +import os +import time +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -56,3 +61,48 @@ def fn_args(fn): if _is_bounded_method(fn): args.remove('self') return tuple(args) + + +# When we create a timestamped directory, there is a small chance that the +# directory already exists because another process is also creating these +# directories. In this case we just wait one second to get a new timestamp and +# try again. If this fails several times in a row, then something is seriously +# wrong. +MAX_DIRECTORY_CREATION_ATTEMPTS = 10 + + +def get_timestamped_dir(dir_base): + """Builds a path to a new subdirectory within the base directory. + + The subdirectory will be named using the current time. + This guarantees monotonically increasing directory numbers even across + multiple runs of the pipeline. + The timestamp used is the number of seconds since epoch UTC. + + Args: + dir_base: A string containing a directory to create the subdirectory under. + + Returns: + The full path of the new subdirectory (which is not actually created yet). + + Raises: + RuntimeError: if repeated attempts fail to obtain a unique timestamped + directory name. + """ + attempts = 0 + while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS: + timestamp = int(time.time()) + + result_dir = os.path.join( + compat.as_bytes(dir_base), compat.as_bytes(str(timestamp))) + if not gfile.Exists(result_dir): + # Collisions are still possible (though extremely unlikely): this + # directory is not actually created yet, but it will be almost + # instantly on return from this function. + return result_dir + time.sleep(1) + attempts += 1 + logging.warn('Directory {} already exists; retrying (attempt {}/{})'.format( + result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) + raise RuntimeError('Failed to obtain a unique export directory name after ' + '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index a758f8a4fc4898713772c4e919acda48b0f6ad0b..238a90b67d9d0039c25a6f3800aad25a2db9e36f 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -74,7 +74,10 @@ py_test( srcs = ["feature_column_test.py"], data = [":vocabulary_testdata"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_cuda_on_cpu_tap", + "no_pip", + ], deps = [ ":feature_column", ":feature_column_py", diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index c416881c3119c160d28f4b8e37cd2aeb22f239a6..7d99fcb3e79318c2fecabaa9bdd0347aa67cf309 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -16,7 +16,7 @@ FeatureColumns provide a high level abstraction for ingesting and representing features. FeatureColumns are also the primary way of encoding features for -canned ${tf.estimator.Estimator}s. +canned @{tf.estimator.Estimator}s. When using FeatureColumns with `Estimators`, the type of feature column you should choose depends on (1) the feature type and (2) the model type. @@ -1626,7 +1626,7 @@ class _FeatureColumn(object): It is used for get_parsing_spec for `tf.parse_example`. Returned spec is a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other - supported objects. Please check documentation of ${tf.parse_example} for all + supported objects. Please check documentation of @{tf.parse_example} for all supported spec objects. Let's say a Feature column depends on raw feature ('raw') and another @@ -1677,7 +1677,7 @@ class _DenseColumn(_FeatureColumn): weight_collections: List of graph collections to which Variables (if any will be created) are added. trainable: If `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see ${tf.Variable}). + `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}). Returns: `Tensor` of shape [batch_size] + `_variable_shape`. @@ -1735,7 +1735,7 @@ class _CategoricalColumn(_FeatureColumn): WARNING: Do not subclass this layer unless you know what you are doing: the API is subject to future changes. - A categorical feature typically handled with a ${tf.SparseTensor} of IDs. + A categorical feature typically handled with a @{tf.SparseTensor} of IDs. """ __metaclass__ = abc.ABCMeta @@ -1770,7 +1770,7 @@ class _CategoricalColumn(_FeatureColumn): weight_collections: List of graph collections to which variables (if any will be created) are added. trainable: If `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see ${tf.get_variable}). + `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.get_variable}). """ pass @@ -1804,6 +1804,21 @@ def _create_categorical_column_weighted_sum( name='weighted_sum') +class _SequenceDenseColumn(_FeatureColumn): + """Represents dense sequence data.""" + + __metaclass__ = abc.ABCMeta + + TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name + 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length']) + + @abc.abstractmethod + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + """Returns a `TensorSequenceLengthPair`.""" + pass + + class _LazyBuilder(object): """Handles caching of transformations while building the model. @@ -1874,12 +1889,12 @@ class _LazyBuilder(object): self._feature_tensors[key] = feature_tensor return feature_tensor - if not isinstance(key, (str, _FeatureColumn)): - raise TypeError('"key" must be either a "str" or "_FeatureColumn". ' - 'Provided: {}'.format(key)) + if isinstance(key, str): + raise ValueError('Feature {} is not in features dictionary.'.format(key)) if not isinstance(key, _FeatureColumn): - raise ValueError('Feature {} is not in features dictionary.'.format(key)) + raise TypeError('"key" must be either a "str" or "_FeatureColumn". ' + 'Provided: {}'.format(key)) column = key logging.debug('Transforming feature_column %s.', column) @@ -2152,7 +2167,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn, class _EmbeddingColumn( - _DenseColumn, + _DenseColumn, _SequenceDenseColumn, collections.namedtuple('_EmbeddingColumn', ( 'categorical_column', 'dimension', 'combiner', 'initializer', 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable' @@ -2178,7 +2193,9 @@ class _EmbeddingColumn( self._shape = tensor_shape.vector(self.dimension) return self._shape - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _get_dense_tensor_internal( + self, inputs, weight_collections=None, trainable=None): + """Private method that follows the signature of _get_dense_tensor.""" # Get sparse IDs and weights. sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access inputs, weight_collections=weight_collections, trainable=trainable) @@ -2210,6 +2227,43 @@ class _EmbeddingColumn( name='%s_weights' % self.name, max_norm=self.max_norm) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format( + self.name, type(self.categorical_column), + self.categorical_column)) + return self._get_dense_tensor_internal( + inputs=inputs, weight_collections=weight_collections, + trainable=trainable) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + if not isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format( + self.name, type(self.categorical_column), + self.categorical_column)) + dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + class _SharedEmbeddingColumn( _DenseColumn, @@ -2890,7 +2944,7 @@ def _prune_invalid_ids(sparse_ids, sparse_weights): return sparse_ids, sparse_weights -class _IndicatorColumn(_DenseColumn, +class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, collections.namedtuple('_IndicatorColumn', ['categorical_column'])): """Represents a one-hot column for use in deep networks. @@ -2966,15 +3020,53 @@ class _IndicatorColumn(_DenseColumn, Returns: Dense `Tensor` created within `_transform_feature`. + + Raises: + ValueError: If `categorical_column` is a `_SequenceCategoricalColumn`. """ # Do nothing with weight_collections and trainable since no variables are # created in this function. del weight_collections del trainable + if isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In indicator_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format( + self.name, type(self.categorical_column), + self.categorical_column)) # Feature has been already transformed. Return the intermediate # representation created by _transform_feature. return inputs.get(self) + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + if not isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In indicator_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format( + self.name, type(self.categorical_column), + self.categorical_column)) + # Feature has been already transformed. Return the intermediate + # representation created by _transform_feature. + dense_tensor = inputs.get(self) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + def _verify_static_batch_size_equality(tensors, columns): # bath_size is a tf.Dimension object. @@ -2990,3 +3082,68 @@ def _verify_static_batch_size_equality(tensors, columns): 'Batch size of columns ({}, {}): ({}, {})'.format( columns[bath_size_column_index].name, columns[i].name, expected_batch_size, tensors[i].shape[0])) + + +def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): + """Returns a [batch_size] Tensor with per-example sequence length.""" + with ops.name_scope(None, 'sequence_length') as name_scope: + row_ids = sp_tensor.indices[:, 0] + column_ids = sp_tensor.indices[:, 1] + column_ids += array_ops.ones_like(column_ids) + seq_length = math_ops.to_int64( + math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # If the last n rows do not have ids, seq_length will have shape + # [batch_size - n]. Pad the remaining values with zeros. + n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] + padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) + return array_ops.concat([seq_length, padding], axis=0, name=name_scope) + + +class _SequenceCategoricalColumn( + _CategoricalColumn, + collections.namedtuple( + '_SequenceCategoricalColumn', ['categorical_column'])): + """Represents sequences of categorical data.""" + + @property + def name(self): + return self.categorical_column.name + + @property + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec # pylint: disable=protected-access + + def _transform_feature(self, inputs): + return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access + + @property + def _num_buckets(self): + return self.categorical_column._num_buckets # pylint: disable=protected-access + + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + id_tensor = sparse_tensors.id_tensor + weight_tensor = sparse_tensors.weight_tensor + # Expands final dimension, so that embeddings are not combined during + # embedding lookup. + check_id_rank = check_ops.assert_equal( + array_ops.rank(id_tensor), 2, + data=[ + 'Column {} expected ID tensor of rank 2. '.format(self.name), + 'id_tensor shape: ', array_ops.shape(id_tensor)]) + with ops.control_dependencies([check_id_rank]): + id_tensor = sparse_ops.sparse_reshape( + id_tensor, + shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + if weight_tensor is not None: + check_weight_rank = check_ops.assert_equal( + array_ops.rank(weight_tensor), 2, + data=[ + 'Column {} expected weight tensor of rank 2.'.format(self.name), + 'weight_tensor shape:', array_ops.shape(weight_tensor)]) + with ops.control_dependencies([check_weight_rank]): + weight_tensor = sparse_ops.sparse_reshape( + weight_tensor, + shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index d3d8c9c154fbfcc9613acce4e1bdab7df2e7d56d..782b505d6c1d0b576b7734f088c4d2c9625f4be2 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -181,7 +181,7 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): TypeError: if shape is incorrectly specified or unsupported. """ ctx = context.context() - if not ctx.in_graph_mode(): + if ctx.executing_eagerly(): t = convert_to_eager_tensor(value, ctx, dtype) if shape is None: return t diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 99ae8b24f11c4955379ae532ba7b921ebec63385..0edae92fd4a86e7d10a180ce64364d3ea552bf60 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -343,7 +343,9 @@ tf_export("uint8").export_constant(__name__, "uint8") uint16 = DType(types_pb2.DT_UINT16) tf_export("uint16").export_constant(__name__, "uint16") uint32 = DType(types_pb2.DT_UINT32) +tf_export("uint32").export_constant(__name__, "uint32") uint64 = DType(types_pb2.DT_UINT64) +tf_export("uint64").export_constant(__name__, "uint32") int16 = DType(types_pb2.DT_INT16) tf_export("int16").export_constant(__name__, "int16") int8 = DType(types_pb2.DT_INT8) diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index 3172f3c2c3d259d2c3f2b340b101aef043d0fc33..392a4f65c6e62c3cb70f8e02a9b24f015a09f649 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -48,6 +48,7 @@ ## Graph collections @@add_to_collection +@@add_to_collections @@get_collection @@get_collection_ref @@GraphKeys @@ -92,6 +93,7 @@ from tensorflow.python.framework.ops import get_default_graph from tensorflow.python.framework.ops import reset_default_graph from tensorflow.python.framework.ops import GraphKeys from tensorflow.python.framework.ops import add_to_collection +from tensorflow.python.framework.ops import add_to_collections from tensorflow.python.framework.ops import get_collection from tensorflow.python.framework.ops import get_collection_ref from tensorflow.python.framework.ops import convert_to_tensor diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index caa604999c2fad4ce111d910a77e4b99399c11ca..14d72d8a3de7e22bee4f9961c2f66044c217f641 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -489,10 +489,10 @@ class _DefinedFunction(object): # Adds this function into 'g'. # pylint: disable=protected-access - if context.in_graph_mode(): - g._add_function(self) - else: + if context.executing_eagerly(): context.context().add_function_def(self.definition) + else: + g._add_function(self) # pylint: enable=protected-access # Ensures related sub-routines are defined in 'g', too. diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 52052ba77d42fa91692e7699f49898d0c01c22be..65ca801cbe922b36e3bc72bc2fbcd88f66aa5290 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -193,7 +193,7 @@ class FunctionTest(test.TestCase): @function.Defun(dtypes.float32, dtypes.float32) def XSquarePlusOneGrad(x, dy): - dx = functional_ops._symbolic_gradient( + dx = functional_ops.symbolic_gradient( input=[x, dy], Tout=[dtypes.float32], f="XSquarePlusOneFn", name="dx") return dx @@ -295,7 +295,7 @@ class FunctionTest(test.TestCase): # gradient function is (x, y, dz) -> (dx, dy). dx's shape # should be the same as x's; and dy's shape should be the same # as y's. - dx, dy = functional_ops._symbolic_gradient( + dx, dy = functional_ops.symbolic_gradient( input=[x, y, dz], Tout=[dtypes.float32] * 2, f="Foo") self.assertEqual(x.get_shape(), dx.get_shape()) self.assertEqual(y.get_shape(), dy.get_shape()) diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index 5a543317e665a940841714fd72d834a430f8406a..910364364c8be84b1a629dbdaae5e69443d07e75 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -235,7 +235,7 @@ def convert_variables_to_constants(sess, variable_names = [] variable_dict_names = [] for node in inference_graph.node: - if node.op in ["Variable", "VariableV2"]: + if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name if ((variable_names_whitelist is not None and variable_name not in variable_names_whitelist) or @@ -243,7 +243,10 @@ def convert_variables_to_constants(sess, variable_name in variable_names_blacklist)): continue variable_dict_names.append(variable_name) - variable_names.append(variable_name + ":0") + if node.op == "VarHandleOp": + variable_names.append(variable_name + "/Read/ReadVariableOp:0") + else: + variable_names.append(variable_name + ":0") if variable_names: returned_variables = sess.run(variable_names) else: @@ -266,6 +269,17 @@ def convert_variables_to_constants(sess, tensor=tensor_util.make_tensor_proto( data, dtype=dtype.type, shape=data.shape))) how_many_converted += 1 + elif input_node.op == "ReadVariableOp" and ( + input_node.input[0] in found_variables): + # The preceding branch converts all VarHandleOps of ResourceVariables to + # constants, so we need to convert the associated ReadVariableOps to + # Identity ops. + output_node.op = "Identity" + output_node.name = input_node.name + output_node.input.extend([input_node.input[0]]) + output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) + if "_class" in input_node.attr: + output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) else: output_node.CopyFrom(input_node) output_graph_def.node.extend([output_node]) diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py index 0421837d49de753d642aed59d1524619a243dcb8..b618152b0256fd043dc7259960d867278ba55b0a 100644 --- a/tensorflow/python/framework/graph_util_test.py +++ b/tensorflow/python/framework/graph_util_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops # pylint: disable=unused-import from tensorflow.python.ops import math_ops as math_ops_lib +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -47,46 +48,46 @@ class DeviceFunctionsTest(test.TestCase): def testTwoDeviceFunctions(self): with ops.Graph().as_default() as g: - var_0 = gen_state_ops._variable( + var_0 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_0", container="", shared_name="") with g.device(test_device_func_pin_variable_to_cpu): - var_1 = gen_state_ops._variable( + var_1 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_1", container="", shared_name="") - var_2 = gen_state_ops._variable( + var_2 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_2", container="", shared_name="") - var_3 = gen_state_ops._variable( + var_3 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_3", container="", shared_name="") with g.device(test_device_func_pin_variable_to_cpu): - var_4 = gen_state_ops._variable( + var_4 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_4", container="", shared_name="") with g.device("/device:GPU:0"): - var_5 = gen_state_ops._variable( + var_5 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_5", container="", shared_name="") - var_6 = gen_state_ops._variable( + var_6 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="var_6", @@ -226,52 +227,62 @@ class DeviceFunctionsTest(test.TestCase): constant_graph_def.library) def testConvertVariablesToConsts(self): - with ops.Graph().as_default(): - variable_node = variables.Variable(1.0, name="variable_node") - _ = variables.Variable(1.0, name="unused_variable_node") - output_node = math_ops_lib.multiply( - variable_node, 2.0, name="output_node") - with session.Session() as sess: - init = variables.initialize_variables([variable_node]) - sess.run(init) - output = sess.run(output_node) - self.assertNear(2.0, output, 0.00001) - variable_graph_def = sess.graph.as_graph_def() - # First get the constant_graph_def when variable_names_whitelist is set, - # note that if variable_names_whitelist is not set an error will be - # thrown because unused_variable_node is not initialized. - constant_graph_def = graph_util.convert_variables_to_constants( - sess, - variable_graph_def, ["output_node"], - variable_names_whitelist=set(["variable_node"])) + self._test_variable_to_const_conversion(use_resource=False) - # Then initialize the unused variable, and get another - # constant_graph_def when variable_names_whitelist is not set. - sess.run(variables.global_variables_initializer()) - constant_graph_def_without_variable_whitelist = ( - graph_util.convert_variables_to_constants(sess, variable_graph_def, - ["output_node"])) - - # The unused variable should be cleared so the two graphs should be - # equivalent. - self.assertEqual( - str(constant_graph_def), - str(constant_graph_def_without_variable_whitelist)) - - # Test variable name black list. This should result in the variable not - # being a const. - sess.run(variables.global_variables_initializer()) - constant_graph_def_with_blacklist = ( - graph_util.convert_variables_to_constants( - sess, - variable_graph_def, ["output_node"], - variable_names_blacklist=set(["variable_node"]))) - variable_node = None - for node in constant_graph_def_with_blacklist.node: - if node.name == "variable_node": - variable_node = node - self.assertIsNotNone(variable_node) - self.assertEqual(variable_node.op, "VariableV2") + def testConvertResourceVariablesToConsts(self): + self._test_variable_to_const_conversion(use_resource=True) + + def _test_variable_to_const_conversion(self, use_resource): + with ops.Graph().as_default(): + with variable_scope.variable_scope("", use_resource=use_resource): + variable_node = variable_scope.get_variable( + "variable_node", initializer=1.0) + another_variable = variable_scope.get_variable( + "unused_variable_node", initializer=1.0) + output_node = math_ops_lib.multiply( + variable_node, 2.0, name="output_node") + with session.Session() as sess: + sess.run(variable_node.initializer) + output = sess.run(output_node) + self.assertNear(2.0, output, 0.00001) + variable_graph_def = sess.graph.as_graph_def() + # First get the constant_graph_def when variable_names_whitelist is + # set, note that if variable_names_whitelist is not set an error will + # be thrown because unused_variable_node is not initialized. + constant_graph_def = graph_util.convert_variables_to_constants( + sess, + variable_graph_def, ["output_node"], + variable_names_whitelist=set(["variable_node"])) + + # Then initialize the unused variable, and get another + # constant_graph_def when variable_names_whitelist is not set. + sess.run(another_variable.initializer) + constant_graph_def_without_variable_whitelist = ( + graph_util.convert_variables_to_constants( + sess, variable_graph_def, ["output_node"])) + + # The unused variable should be cleared so the two graphs should be + # equivalent. + self.assertEqual( + str(constant_graph_def), + str(constant_graph_def_without_variable_whitelist)) + + # Test variable name black list. This should result in the variable + # not being a const. + constant_graph_def_with_blacklist = ( + graph_util.convert_variables_to_constants( + sess, + variable_graph_def, ["output_node"], + variable_names_blacklist=set(["variable_node"]))) + variable_node = None + for node in constant_graph_def_with_blacklist.node: + if node.name == "variable_node": + variable_node = node + self.assertIsNotNone(variable_node) + if use_resource: + self.assertEqual(variable_node.op, "VarHandleOp") + else: + self.assertEqual(variable_node.op, "VariableV2") # Now we make sure the variable is now a constant, and that the graph still # produces the expected result. @@ -279,8 +290,9 @@ class DeviceFunctionsTest(test.TestCase): _ = importer.import_graph_def(constant_graph_def, name="") self.assertEqual(4, len(constant_graph_def.node)) for node in constant_graph_def.node: - self.assertNotEqual("Variable", node.op) - self.assertNotEqual("VariableV2", node.op) + self.assertNotIn( + node.op, + ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"]) with session.Session() as sess: output_node = sess.graph.get_tensor_by_name("output_node:0") output = sess.run(output_node) diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 6ecc1a40ae14760dd39242aaf595b32a9decdc9f..4ea34d7bb2831845aec1f40fcdb7f64a8f8c438a 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -301,14 +301,17 @@ def _ProcessNewOps(graph): colocation_pairs = {} for new_op in graph._add_new_tf_operations(compute_devices=False): # pylint: disable=protected-access + original_device = new_op.device + new_op._set_device('') # pylint: disable=protected-access colocation_names = _GetColocationNames(new_op) if colocation_names: colocation_pairs[new_op] = colocation_names - # Don't apply this op's device function, since colocation constraints - # override device functions. Note that this op's device may still be set - # by the loop below. + # Don't set a device for this op, since colocation constraints override + # device functions and the original device. Note that this op's device may + # still be set by the loop below. + # TODO(skyewm): why does it override the original device? else: - with _MaybeDevice(new_op.device): + with _MaybeDevice(original_device): graph._apply_device_functions(new_op) # pylint: disable=protected-access # The following loop populates the device field of ops that are colocated @@ -475,32 +478,39 @@ def import_graph_def(graph_def, _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements) - with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: - try: - with errors.raise_exception_on_not_ok_status() as status: - results = c_api.TF_GraphImportGraphDefWithResults( - graph._c_graph, serialized, options, status) # pylint: disable=protected-access - except errors.InvalidArgumentError as e: - # Convert to ValueError for backwards compatibility. - raise ValueError(str(e)) - - _ProcessNewOps(graph) + # _ProcessNewOps mutates the new operations. _lock ensures a Session.run + # call cannot occur between creating the TF_Operations in the + # TF_GraphImportGraphDefWithResults call and mutating the them in + # _ProcessNewOps. + with graph._lock: # pylint: disable=protected-access + with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: + try: + with errors.raise_exception_on_not_ok_status() as status: + results = c_api.TF_GraphImportGraphDefWithResults( + graph._c_graph, serialized, options, status) # pylint: disable=protected-access + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) + + # Create _DefinedFunctions for any imported functions. + # + # We do this by creating _DefinedFunctions directly from `graph_def`, and + # adding them to `graph`. Adding an existing function to a TF_Graph is a + # no-op, so this only has the effect of updating the Python state (usually + # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). + # + # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph + # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph + # TODO(b/74620627): move this after _ProcessNewOps outside the lock once + # _USE_C_SHAPES is removed. + if graph_def.library and graph_def.library.function: + # pylint: disable=protected-access + functions = function._from_library(graph_def.library) + for f in functions: + f.add_to_graph(graph) + # pylint: enable=protected-access - # Create _DefinedFunctions for any imported functions. - # - # We do this by creating _DefinedFunctions directly from `graph_def`, and - # adding them to `graph`. Adding an existing function to a TF_Graph is a - # no-op, so this only has the effect of updating the Python state (usually - # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). - # - # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph - # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph - if graph_def.library and graph_def.library.function: - # pylint: disable=protected-access - functions = function._from_library(graph_def.library) - for f in functions: - f.add_to_graph(graph) - # pylint: enable=protected-access + _ProcessNewOps(graph) # Treat input mappings that don't appear in the graph as an error, because # they are likely to be due to a typo. diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index bf5d9fe0936882c242198bdc7118f9f3a4e79260..6593b1718434fd2035133f65aa08b17774e9e806 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -680,6 +680,49 @@ class ImportGraphDefTest(test.TestCase): "list { s: 'loc:@imported_graph/A' }", b.node_def.attr["_class"]) + def testColocationAndDevice(self): + # A and B are colocated, device set on A. + original_graph_def = self._MakeGraphDef(""" + node { name: 'A' op: 'None' device: '/device:CPU:0' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } } + node { name: 'B' op: 'None' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } }""") + + with ops.Graph().as_default(): + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="") + self.assertEqual(a.device, "/device:CPU:0") + self.assertEqual(b.device, "/device:CPU:0") + self.assertEqual(a.colocation_groups(), [b"loc:@A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@A"]) + + # A and B are colocated, device set on B. + original_graph_def = self._MakeGraphDef(""" + node { name: 'A' op: 'None' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } } + node { name: 'B' op: 'None' device: '/device:CPU:0' attr { + key: '_class' + value { list { s: 'loc:@A' } } + } }""") + + with ops.Graph().as_default(): + a, b = importer.import_graph_def(original_graph_def, + return_elements=["A", "B"], + name="") + # TODO(skyewm): this behavior seems inconsistent with the above. Why is + # B's device ignored? + self.assertEqual(a.device, "") + self.assertEqual(b.device, "") + self.assertEqual(a.colocation_groups(), [b"loc:@A"]) + self.assertEqual(b.colocation_groups(), [b"loc:@A"]) + def testColocationWithDeviceFn(self): original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' attr { diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 4c1bd736d727e974375ad9008a579361137fb9d6..391b17720c6f5925fe6cab02ac2a784257177a27 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -695,7 +695,7 @@ def import_scoped_meta_graph(meta_graph_or_file, Raises: ValueError: If the graph_def contains unbound inputs. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError("Exporting/importing meta graphs is not supported when " "eager execution is enabled.") if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): @@ -737,7 +737,9 @@ def import_scoped_meta_graph(meta_graph_or_file, import_scope or "", mark_as_used=False) importer.import_graph_def( - input_graph_def, name=(import_scope or ""), input_map=input_map, + input_graph_def, + name=(import_scope or scope_to_prepend_to_names), + input_map=input_map, producer_op_list=producer_op_list) # Restores all the other collections. @@ -856,7 +858,7 @@ def export_scoped_meta_graph(filename=None, Raises: ValueError: When the `GraphDef` is larger than 2GB. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError("Exporting/importing meta graphs is not supported when " "Eager Execution is enabled.") graph = graph or ops.get_default_graph() diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 19dcd6a1b34741290b2578d93b79883c103fdb1b..5d5fb037fc217849ea32102bf60796c47d565f3b 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -537,6 +537,21 @@ class ScopedMetaGraphTest(test.TestCase): self.assertEqual(list(imported_variables.values())[0].name, "foo/bar/myvar:0") + def testScopedImportUnderNameScopeNoVarScope(self): + graph = ops.Graph() + with graph.as_default(): + variables.Variable(initial_value=1.0, trainable=True, name="myvar") + meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph) + + graph = ops.Graph() + with graph.as_default(): + with ops.name_scope("foo"): + imported_variables = meta_graph.import_scoped_meta_graph( + meta_graph_def) + self.assertEqual(len(imported_variables), 1) + self.assertEqual(list(imported_variables.values())[0].name, + "foo/myvar:0") + def testImportsUsingSameScopeName(self): with ops.Graph().as_default(): variables.Variable(0, name="v") @@ -905,20 +920,6 @@ class ExportImportAcrossScopesTest(test.TestCase): with variable_scope.variable_scope("importA/keepA"): graph_fn(use_resource=use_resource) - if use_resource: - # Bringing in collections that contain ResourceVariables will adds ops - # to the graph the first time a variable is encountered, so mimic the - # same behavior. - seen_variables = set() - for collection_key in sorted([ - ops.GraphKeys.GLOBAL_VARIABLES, - ops.GraphKeys.TRAINABLE_VARIABLES, - ]): - for var in expected_graph.get_collection(collection_key): - if var not in seen_variables: - var._read_variable_op() - seen_variables.add(var) - result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0] expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0] diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 5a14ea417626265b8cee4bbec025f2bb9f5d4307..25a951a2de10c0c549b02c686a02415c7ce5b2ec 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -63,6 +63,7 @@ from tensorflow.python.util.tf_export import tf_export # in code or via the environment variable. This will be removed once all # functionality is supported and there's no performance penalty with it enabled. _USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "0") is not "0" +_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0" def tensor_id(tensor): @@ -369,7 +370,7 @@ class Tensor(_TensorLike): """ graph = self._op._graph._c_graph # pylint: disable=protected-access - if graph: + if graph and _USE_C_SHAPES: with errors.raise_exception_on_not_ok_status() as status: num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), status) @@ -395,10 +396,10 @@ class Tensor(_TensorLike): "Tensor._shape cannot be assigned, use Tensor.set_shape instead.") def __iter__(self): - if context.in_graph_mode(): + if not context.executing_eagerly(): raise TypeError( - "`Tensor` objects are not iterable when eager execution is not " - "enabled. To iterate over this tensor use `tf.map_fn`.") + "Tensor objects are not iterable when eager execution is not " + "enabled. To iterate over this tensor use tf.map_fn.") shape = self._shape_tuple() if shape is None: raise TypeError("Cannot iterate over a tensor with unknown shape.") @@ -466,9 +467,13 @@ class Tensor(_TensorLike): ValueError: If `shape` is not compatible with the current shape of this tensor. """ - if not self._op._graph._c_graph: # pylint: disable=protected-access # ASIM + if not _USE_C_SHAPES: # pylint: disable=protected-access self._shape_val = self._shape_val.merge_with(shape) - return + + if not self._op._graph._c_graph: return + + # Update C shape even if _USE_C_SHAPES = False, since we still want + # set_shape to be reflected in the C API graph for when we run it. if not isinstance(shape, tensor_shape.TensorShape): shape = tensor_shape.TensorShape(shape) dim_list = [] @@ -772,7 +777,7 @@ class _EagerTensorBase(Tensor): six.raise_from(core._status_to_exception(e.code, e.message), None) # Record the copy on tape and define backprop copy as well. - if not context.in_graph_mode(): + if context.executing_eagerly(): self_device = self.device def grad_fun(dresult): return [dresult._copy(device_name=self_device)] @@ -782,7 +787,11 @@ class _EagerTensorBase(Tensor): @property def shape(self): - return tensor_shape.TensorShape(self._shape_tuple()) + if self._tensor_shape is None: # pylint: disable=access-member-before-definition + # `_tensor_shape` is declared and defined in the definition of + # `EagerTensor`, in C. + self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple()) + return self._tensor_shape def get_shape(self): """Alias of Tensor.shape.""" @@ -829,41 +838,51 @@ class _EagerTensorBase(Tensor): def set_shape(self, shape): if not self.shape.is_compatible_with(shape): raise ValueError( - "EagerTensor's shape %s is not compatible with supplied shape %s" % + "Tensor's shape %s is not compatible with supplied shape %s" % (self.shape, shape)) # Methods not supported / implemented for Eager Tensors. @property def op(self): - raise AttributeError("op not supported for Eager Tensors.") + raise AttributeError( + "Tensor.op is meaningless when eager execution is enabled.") @property def graph(self): - raise AttributeError("graph not supported for Eager Tensors.") + raise AttributeError( + "Tensor.graph is meaningless when eager execution is enabled.") @property def name(self): - raise AttributeError("name not supported for Eager Tensors.") + raise AttributeError( + "Tensor.name is meaningless when eager execution is enabled.") @property def value_index(self): - raise AttributeError("value_index not supported for Eager Tensors.") + raise AttributeError( + "Tensor.value_index is meaningless when eager execution is enabled.") def consumers(self): - raise NotImplementedError("consumers not supported for Eager Tensors.") + raise NotImplementedError( + "Tensor.consumers is meaningless when eager execution is enabled.") def _add_consumer(self, consumer): - raise NotImplementedError("_add_consumer not supported for Eager Tensors.") + raise NotImplementedError( + "_add_consumer not supported when eager execution is enabled.") def _as_node_def_input(self): raise NotImplementedError( - "_as_node_def_input not supported for Eager Tensors.") + "_as_node_def_input not supported when eager execution is enabled.") def _as_tf_output(self): - raise NotImplementedError("_as_tf_output not supported for Eager Tensors.") + raise NotImplementedError( + "_as_tf_output not supported when eager execution is enabled.") def eval(self, feed_dict=None, session=None): - raise NotImplementedError("eval not supported for Eager Tensors.") + raise NotImplementedError( + "eval is not supported when eager execution is enabled, " + "is .numpy() what you're looking for?" + ) # This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and @@ -989,7 +1008,7 @@ def internal_convert_to_tensor(value, """ if ctx is None: ctx = context.context() - if ctx.in_eager_mode(): + if ctx.executing_eagerly(): # Fast path for EagerTensors that don't need any conversion. if isinstance(value, EagerTensor): # Note that we don't check that value's dtype matches the dtype @@ -1897,7 +1916,8 @@ class Operation(object): tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() - def _update_input(self, index, tensor): + # TODO(skyewm): Remove `update_dtype` when we enable the C API. + def _update_input(self, index, tensor, update_dtype=True): """Update the input to this operation at the given index. NOTE: This is for TF internal use only. Please don't use it. @@ -1905,6 +1925,7 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. + update_dtype: If `False`, the type for this input is not updated. Raises: TypeError: if tensor is not a Tensor, @@ -1924,7 +1945,8 @@ class Operation(object): else: self._inputs_val[index].consumers().remove(self) self._inputs_val[index] = tensor - self._input_types_val[index] = tensor.dtype + if update_dtype: + self._input_types_val[index] = tensor.dtype tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() @@ -2486,7 +2508,7 @@ def _set_shapes_for_outputs(op): def set_shapes_for_outputs(op): """Set the shapes for op's outputs.""" - if op._c_op: # pylint: disable=protected-access + if op._c_op and _USE_C_SHAPES: # pylint: disable=protected-access return _set_shapes_for_outputs_c_api(op) else: return _set_shapes_for_outputs(op) @@ -2690,21 +2712,24 @@ class Graph(object): def __init__(self): """Creates a new, empty Graph.""" - # Protects the core state that may be accessed by multiple readers. - # Only state that can be returned via public accessors (`as_graph_def()`, - # `get_operations()`, `as_graph_element()`, `get_collection()`, and - # `get_collection_ref()`) is by the lock. Thread-safety is provided on a - # best-effort basis to support buggy programs, and is not guaranteed by the - # public `tf.Graph` API. + # Protects core state that can be returned via public accessors, as well as + # synchronizes Session.run calls with methods that create and mutate ops + # (e.g. Graph.create_op()). This synchronization is necessary because it's + # illegal to modify an operation after it's been run. Thread-safety is + # provided on a best-effort basis to support buggy programs, and is not + # guaranteed by the public `tf.Graph` API. + # + # The lock must be reentrant because create_op can be called recursively due + # to control flow. Without a reentrant lock, many methods would also need a + # "locked" version or parameter (including generated code). + # # NOTE(mrry): This does not protect the various stacks. A warning will # be reported if these are used from multiple threads - self._lock = threading.Lock() + self._lock = threading.RLock() self._nodes_by_id = dict() # GUARDED_BY(self._lock) self._next_id_counter = 0 # GUARDED_BY(self._lock) self._nodes_by_name = dict() # GUARDED_BY(self._lock) self._version = 0 # GUARDED_BY(self._lock) - # Current name stack: uniquified names - self._name_stack = "" # Maps a name used in the graph to the next id to use for that name. self._names_in_use = {} self._stack_state_is_thread_local = False @@ -2763,6 +2788,9 @@ class Graph(object): # being called inside function definitions behave as if they were seeing the # actual outside graph). self._graph_key = "grap-key-%d/" % (uid(),) + # A string with the last reduction method passed to + # losses.compute_weighted_loss(), or None. + self._last_loss_reduction = None self._container = "" self._registered_ops = op_def_registry.get_registered_ops() @@ -2776,7 +2804,6 @@ class Graph(object): c_api.SetRequireShapeInferenceFns(self._c_graph, False) else: self._scoped_c_graph = None - self._variable_creator_stack = [] # TODO(apassos) remove once the C API is used by default. def _use_c_api_hack(self): @@ -2817,17 +2844,26 @@ class Graph(object): # frozen, and this functionality is still not ready for public visibility. @tf_contextlib.contextmanager def _variable_creator_scope(self, creator): + # This step makes a copy of the existing stack, and it also initializes + # self._thread_local._variable_creator_stack if it doesn't exist yet. old = list(self._variable_creator_stack) - self._variable_creator_stack.append(creator) + self._thread_local._variable_creator_stack.append(creator) try: yield finally: - self._variable_creator_stack = old + self._thread_local._variable_creator_stack = old # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. - def _get_variable_creator_stack(self): - return list(self._variable_creator_stack) + @property + def _variable_creator_stack(self): + if not hasattr(self._thread_local, "_variable_creator_stack"): + self._thread_local._variable_creator_stack = [] + return list(self._thread_local._variable_creator_stack) + + @_variable_creator_stack.setter + def _variable_creator_stack(self, variable_creator_stack): + self._thread_local._variable_creator_stack = variable_creator_stack def _extract_stack(self): """A lightweight, extensible re-implementation of traceback.extract_stack. @@ -3259,17 +3295,34 @@ class Graph(object): input_ops = set([t.op for t in inputs]) control_inputs = self._control_dependencies_for_inputs(input_ops) - ret = Operation( - node_def, - self, - inputs=inputs, - output_types=dtypes, - control_inputs=control_inputs, - input_types=input_types, - original_op=self._default_original_op, - op_def=op_def) - self._create_op_helper(ret, compute_shapes=compute_shapes, - compute_device=compute_device) + # _create_op_helper mutates the new Operation. _lock ensures a Session.run + # call cannot occur between creating and mutating the op. + with self._lock: + ret = Operation( + node_def, + self, + inputs=inputs, + output_types=dtypes, + control_inputs=control_inputs, + input_types=input_types, + original_op=self._default_original_op, + op_def=op_def) + + # TODO(vrv): Instead of eagerly filling in shape property for every op, + # only populate the shape when requested. + # + # TODO(skyewm): unlike in the original Python implementation, the C API + # always computes shape information (even for function calls, which the + # original Python shape inference code doesn't handle). Deprecate the + # compute_shapes argument. + # + # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES + # is removed + if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access + set_shapes_for_outputs(ret) + + self._create_op_helper(ret, compute_shapes=compute_shapes, + compute_device=compute_device) return ret def _create_op_from_tf_operation(self, c_op, compute_device=True): @@ -3301,15 +3354,6 @@ class Graph(object): def _create_op_helper(self, op, compute_shapes=True, compute_device=True): """Common logic for creating an op in this graph.""" - # TODO(vrv): Instead of eagerly filling in shape property for every op, only - # populate the shape when requested. - # - # TODO(skyewm): unlike in the original Python implementation, the C API - # always computes shape information (even for function calls, which the - # original Python shape inference code doesn't handle). Deprecate the - # compute_shapes argument. - if op._c_op or compute_shapes: # pylint: disable=protected-access - set_shapes_for_outputs(op) # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. self._add_op(op) @@ -3414,6 +3458,12 @@ class Graph(object): ] for op in new_ops: + # Operations created by the C API always retrieve shapes from the C API so + # we preserve the shapes of ops created in import_graph_def (from the + # "_output_shapes" attr of the imported NodeDef). + # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES + # is removed. + _set_shapes_for_outputs_c_api(op) new_control_inputs = self._control_dependencies_for_inputs(op.inputs) # pylint: disable=protected-access op._add_control_inputs(new_control_inputs) @@ -3861,6 +3911,17 @@ class Graph(object): finally: self._default_original_op = old_original_op + @property + def _name_stack(self): + # This may be called from a thread where name_stack doesn't yet exist. + if not hasattr(self._thread_local, "_name_stack"): + self._thread_local._name_stack = "" + return self._thread_local._name_stack + + @_name_stack.setter + def _name_stack(self, name_stack): + self._thread_local._name_stack = name_stack + # pylint: disable=g-doc-return-or-yield,line-too-long @tf_contextlib.contextmanager def name_scope(self, name): @@ -4777,15 +4838,15 @@ def device(device_name_or_function): Raises: RuntimeError: If eager execution is enabled and a function is passed in. """ - if context.in_graph_mode(): - return get_default_graph().device(device_name_or_function) - else: + if context.executing_eagerly(): # TODO(agarwal): support device functions in EAGER mode. if callable(device_name_or_function): raise RuntimeError( "tf.device does not support functions when eager execution " "is enabled.") return context.device(device_name_or_function) + else: + return get_default_graph().device(device_name_or_function) @tf_export("container") @@ -4804,13 +4865,20 @@ def container(container_name): @tf_export("colocate_with") def colocate_with(op, ignore_existing=False): - if context.in_graph_mode(): - return get_default_graph().colocate_with(op, ignore_existing) - else: + if context.executing_eagerly(): if op is not None: return device(op.device) else: return _NullContextmanager() + else: + default_graph = get_default_graph() + if isinstance(op, EagerTensor): + if default_graph.building_function: + op = internal_convert_to_tensor(op) + else: + raise ValueError("Encountered an Eager-defined Tensor during graph " + "construction, but a function was not being built.") + return default_graph.colocate_with(op, ignore_existing) @tf_export("control_dependencies") @@ -4820,20 +4888,29 @@ def control_dependencies(control_inputs): See @{tf.Graph.control_dependencies} for more details. + When eager execution is enabled, any callable object in the `control_inputs` + list will be called. + Args: control_inputs: A list of `Operation` or `Tensor` objects which must be executed or computed before running the operations defined in the context. Can also be `None` to clear the control - dependencies. + dependencies. If eager execution is enabled, any callable object in the + `control_inputs` list will be called. Returns: A context manager that specifies control dependencies for all operations constructed within the context. """ - if context.in_graph_mode(): - return get_default_graph().control_dependencies(control_inputs) - else: + if context.executing_eagerly(): + if control_inputs: + # Excute any pending callables. + for control in control_inputs: + if callable(control): + control() return _NullContextmanager() + else: + return get_default_graph().control_dependencies(control_inputs) class _DefaultStack(threading.local): @@ -5054,11 +5131,12 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access @tf_contextlib.contextmanager def get_controller(self, default): try: - context.context_stack.push(default.building_function, default.as_default) + context.context().context_switches.push(default.building_function, + default.as_default) with super(_DefaultGraphStack, self).get_controller(default) as g: yield g finally: - context.context_stack.pop() + context.context().context_switches.pop() _default_graph_stack = _DefaultGraphStack() @@ -5084,87 +5162,130 @@ def init_scope(): graph function. Here, a context is defined as either a graph or an eager context. Every context switch, i.e., every installation of a graph as the default graph and every switch into eager mode, is logged in a - thread-local stack called the `context_stack`; the log entry for a + thread-local stack called `context_switches`; the log entry for a context switch is popped from the stack when the context is exited. - Entering an `init_scope` is equivalent to crawling up the - `context_stack`, finding the first context that is not building a graph - function, and entering it. A caveat is that if graph mode is enabled - but the default graph stack is empty, then entering an `init_scope` - will simply install a fresh graph as the default one. + Entering an `init_scope` is equivalent to crawling up + `context_switches`, finding the first context that is not building a + graph function, and entering it. A caveat is that if graph mode is + enabled but the default graph stack is empty, then entering an + `init_scope` will simply install a fresh graph as the default one. (3) The gradient tape is paused while the scope is active. """ # pylint: enable=g-doc-return-or-yield,line-too-long - in_graph_mode = context.in_graph_mode() - # Retrieve the active name scope: entering an `init_scope` preserves - # the name scope of the current context. - if in_graph_mode: + if context.executing_eagerly(): + # Fastpath. + with tape.stop_recording(): + yield + else: + # Retrieve the active name scope: entering an `init_scope` preserves + # the name scope of the current context. default_graph = get_default_graph() scope = default_graph.get_name_scope() - else: - scope = context.context().scope_name - if scope and scope[-1] != '/': - # Names that end with trailing slashes are treated by `name_scope` as - # absolute. - scope = scope + '/' - - outer_context = None - if in_graph_mode and not _default_graph_stack.stack: - outer_context = default_graph.as_default - else: - for stack_entry in reversed(context.context_stack.stack): - if not stack_entry.is_building_function: - outer_context = stack_entry.enter_context_fn - break + if scope and scope[-1] != '/': + # Names that end with trailing slashes are treated by `name_scope` as + # absolute. + scope = scope + '/' + + outer_context = None + if not _default_graph_stack.stack: + # If the default graph stack is empty, then we cannot be building a + # function. Install the global graph (which, in this case, is also the + # default graph) as the outer context. + if default_graph.building_function: + raise RuntimeError("The global graph is building a function.") + outer_context = default_graph.as_default + else: + # Find a context that is not building a function. + for stack_entry in reversed(context.context().context_switches.stack): + if not stack_entry.is_building_function: + outer_context = stack_entry.enter_context_fn + break + + if outer_context is None: + # As a last resort, obtain the global default graph; this graph doesn't + # necessarily live on the graph stack (and hence it doesn't necessarily + # live on the context stack), but it is stored in the graph stack's + # encapsulating object. + outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default # pylint: disable=protected-access - if outer_context is None: - raise AssertionError("All graphs are building functions, and no " + if outer_context is None: + # Sanity check; this shouldn't be triggered. + raise RuntimeError("All graphs are building functions, and no " "eager context was previously active.") - try: with outer_context(), name_scope(scope), control_dependencies( None), tape.stop_recording(): yield - finally: - pass -def enable_eager_execution(config=None, device_policy=None): - """Enables, for the rest of the lifetime of this program, eager execution. +@tf_export("enable_eager_execution") +def enable_eager_execution(config=None, device_policy=None, + execution_mode=None): + """Enables eager execution for the lifetime of this program. - If not called immediately on startup risks creating breakage and bugs. + Eager execution provides an imperative interface to TensorFlow. With eager + execution enabled, TensorFlow functions execute operations immediately (as + opposed to adding to a graph to be executed later in a @{tf.Session}) and + return concrete values (as opposed to symbolic references to a node in a + computational graph). - Example: + For example: ```python - tfe.enable_eager_execution() + tf.enable_eager_execution() # After eager execution is enabled, operations are executed as they are - # defined and `Tensor`s hold concrete values, which can be accessed as - # `numpy.ndarray`s through the `numpy()` method. + # defined and Tensor objects hold concrete values, which can be accessed as + # numpy.ndarray`s through the numpy() method. assert tf.multiply(6, 7).numpy() == 42 ``` + Eager execution cannot be enabled after TensorFlow APIs have been used to + create or execute graphs. It is typically recommended to invoke this function + at program startup and not in a library (as most libraries should be usable + both with and without eager execution). + Args: - config: (Optional.) A `ConfigProto` protocol buffer with configuration - options for the Context. Note that a lot of these options may be - currently unimplemented or irrelevant when eager execution is enabled. - device_policy: (Optional.) What policy to use when trying to run an - operation on a device with inputs which are not on that device. + config: (Optional.) A @{tf.ConfigProto} to use to configure the environment + in which operations are executed. Note that @{tf.ConfigProto} is also + used to configure graph execution (via @{tf.Session}) and many options + within `tf.ConfigProto` are not implemented (or are irrelevant) when + eager execution is enabled. + device_policy: (Optional.) Policy controlling how operations requiring + inputs on a specific device (e.g., a GPU 0) handle inputs on a different + device (e.g. GPU 1 or CPU). When set to None, an appropriate value will be + picked automatically. The value picked may change between TensorFlow + releases. Valid values: - tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not - correct. - tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the - right device but raises a warning. - tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might - hide performance problems. - tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, - raising errors on the other ones. + + - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the + placement is not correct. + + - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not + on the right device but logs a warning. + + - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. + Note that this may hide performance problems as there is no notification + provided when operations are blocked on the tensor being copied between + devices. + + - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies + int32 tensors, raising errors on the other ones. + execution_mode: (Optional.) Policy controlling how operations dispatched are + actually executed. When set to None, an appropriate value will be picked + automatically. The value picked may change between TensorFlow releases. + Valid values: + + - tf.contrib.eager.SYNC: executes each operation synchronously. + + - tf.contrib.eager.ASYNC: executes each operation asynchronously. These + operations may return "non-ready" handles. Raises: - ValueError: If trying to create a context after using graph operations - or if trying to create a context with nontrivial options which differ - from those of the existing context. + ValueError: If eager execution is enabled after creating/executing a + TensorFlow graph, or if options provided conflict with a previous call + to this function. """ if config is not None and not isinstance(config, config_pb2.ConfigProto): raise TypeError( @@ -5174,8 +5295,12 @@ def enable_eager_execution(config=None, device_policy=None): context.DEVICE_PLACEMENT_SILENT, context.DEVICE_PLACEMENT_SILENT_FOR_INT32): raise ValueError( - "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*" + "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*" ) + if execution_mode not in (None, context.SYNC, context.ASYNC): + raise ValueError( + "execution_mode must be one of None, tf.contrib.eager.SYNC, " + "tf.contrib.eager.ASYNC") # pylint: disable=protected-access if context._default_mode == context.GRAPH_MODE: graph_mode_has_been_used = ( @@ -5183,30 +5308,29 @@ def enable_eager_execution(config=None, device_policy=None): _default_graph_stack._global_default_graph is not None) if graph_mode_has_been_used: raise ValueError( - "tfe.enable_eager_execution has to be called at program startup.") + "tf.enable_eager_execution must be called at program startup.") context._default_mode = context.EAGER_MODE if context._context is None: - context._context = context.Context(config=config, - device_policy=device_policy) - if context.context_stack.stack: - raise AssertionError("Invariant violated: The context stack must " - "be empty when eager execution is enabled.") - # Log that eager execution has been enabled by pushing an entry onto the - # context stack; this entry won't ever be popped, as it's impossible to - # disable eager execution - context.context_stack.push(False, context.eager_mode) - elif ((config is not None and config is not context._context._config) - or (device_policy is not None - and device_policy is not context._context._device_policy)): + context._context = context.Context( + config=config, + device_policy=device_policy, + execution_mode=execution_mode) + elif ((config is not None and config is not context._context._config) or + (device_policy is not None and + device_policy is not context._context._device_policy) or + (execution_mode is not None and + execution_mode is not context._context._execution_mode)): raise ValueError("Trying to change the options of an active eager" " execution. Context config: %s, specified config:" - " %s. Context device policy: %s; specified device" - " policy: %s." % (config, context._context._config, - device_policy, - context._context._device_policy)) + " %s. Context device policy: %s, specified device" + " policy: %s. Context execution mode: %s, " + " specified execution mode %s." % + (context._context._config, config, + context._context._device_policy, device_policy, + context._context._execution_mode, execution_mode)) else: raise ValueError( - "tfe.enable_eager_execution has to be called at program startup.") + "tf.enable_eager_execution must be called at program startup.") def eager_run(main=None, argv=None): @@ -5290,6 +5414,8 @@ def get_name_scope(): Returns: A string representing the current name scope. """ + if context.executing_eagerly(): + return context.context().scope_name.rstrip("/") return get_default_graph().get_name_scope() @@ -5544,7 +5670,7 @@ def add_to_collection(name, value): """ get_default_graph().add_to_collection(name, value) - +@tf_export("add_to_collections") def add_to_collections(names, value): """Wrapper for `Graph.add_to_collections()` using the default graph. @@ -5666,7 +5792,7 @@ class name_scope(object): # pylint: disable=invalid-name self._default_name = default_name self._values = values self._ctx = context.context() - self._in_eager_mode = self._ctx.in_eager_mode() + self._in_eager_mode = self._ctx.executing_eagerly() def __enter__(self): """Start the scope block. @@ -5740,6 +5866,9 @@ def strip_name_scope(name, export_scope): is None. """ if export_scope: + if export_scope[-1] == "/": + export_scope = export_scope[:-1] + try: # Strips export_scope/, export_scope///, # ^export_scope/, loc:@export_scope/. @@ -5765,6 +5894,9 @@ def prepend_name_scope(name, import_scope): is None. """ if import_scope: + if import_scope[-1] == "/": + import_scope = import_scope[:-1] + try: str_to_replace = r"([\^]|loc:@|^)(.*)" return re.sub(str_to_replace, r"\1" + import_scope + r"/\2", @@ -5845,10 +5977,11 @@ def get_from_proto_function(collection_name): def _assert_collection_is_ok(collection_name): - if context.in_eager_mode(): + if context.executing_eagerly(): if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access - raise ValueError("When Eager Execution is enabled, variable " - "collections are not supported.") + raise ValueError( + "variable collections are not supported when eager execution is enabled." + ) def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index a141fe6340c70efde84411db4efb1f80cb0a61c5..aa51391871f4c12d34b86311cc5b8ea9aabd5434 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -763,6 +763,7 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): self.assertEqual(g.get_operation_by_name("myop"), op) self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) + @test_util.enable_c_shapes def testShape(self): g = ops.Graph() with g.as_default(): @@ -1555,6 +1556,35 @@ class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): input: "^ColocateWithMe_2" } """, gd) + def testNameStack(self): + + class NameSettingThread(self.TestThread): + + def run(self): + with g.name_scope("foo"): + op1 = g.create_op("FloatOutput", [], [dtypes.float32]) + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + op2 = g.create_op("FloatOutput", [], [dtypes.float32]) + self.result = (op1, op2) + + g = ops.Graph() + threads = [NameSettingThread(g, i) for i in range(3)] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + + for t in threads: + t.should_continue.set() + t.join() + + suffixes = ["", "_1", "_2"] + for t, s in zip(threads, suffixes): + self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name) + self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name) + @test_util.with_c_api class ObjectWithName(object): @@ -1763,7 +1793,13 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): return constant_op.constant(2.0) future.calls = 0 - if context.in_graph_mode(): + if context.executing_eagerly(): + a = constant_op.constant(1.0) + b = future + with ops.control_dependencies([a, b]): + c = constant_op.constant(3.0) + self.assertEqual(future.calls, 1) + else: g = ops.Graph() with g.as_default(): a = constant_op.constant(1.0) @@ -1772,12 +1808,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): c = constant_op.constant(3.0) self.assertEqual(c.op.control_inputs, [a.op, b.op]) self.assertEqual(future.calls, 1) - else: - a = constant_op.constant(1.0) - b = future() - with ops.control_dependencies([a, b]): - c = constant_op.constant(3.0) - self.assertEqual(future.calls, 1) def testBasicWithConversion(self): g = ops.Graph() @@ -2150,19 +2180,11 @@ class InitScopeTest(test_util.TensorFlowTestCase): with ops.init_scope(): # Because g is building a function, init_scope should # escape out to the eager context. - self.assertTrue(context.in_eager_mode()) + self.assertTrue(context.executing_eagerly()) # g should be reinstated as the default graph, and the # graph context should be re-entered. self.assertIs(g, ops.get_default_graph()) - self.assertTrue(context.in_graph_mode()) - - def testAllGraphsBuildingFunctionsRaisesError(self): - g = ops.Graph() - g._building_function = True # pylint: disable=protected-access - with g.as_default(): - with self.assertRaises(AssertionError): - with ops.init_scope(): - pass + self.assertFalse(context.executing_eagerly()) def testStaysInEagerWhenOnlyEagerContextActive(self): with context.eager_mode(): @@ -2241,6 +2263,29 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) + def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): + with context.graph_mode(): + ops.reset_default_graph() + # This doesn't push anything onto the graph stack, but it does + # set the stack's global graph. + global_graph = ops.get_default_graph() + fn_graph = ops.Graph() + + # pylint: disable=protected-access + fn_graph._building_function = True + self.assertEqual(len(ops._default_graph_stack.stack), 0) + with fn_graph.as_default(): + self.assertEqual(len(ops._default_graph_stack.stack), 1) + with ops.init_scope(): + self.assertGreater(len(ops._default_graph_stack.stack), 1) + dummy = constant_op.constant(1.0) + self.assertEqual(len(ops._default_graph_stack.stack), 1) + # Note that the global graph is _not_ on the graph stack. + self.assertEqual(len(ops._default_graph_stack.stack), 0) + # Ensure that `dummy` was added to the global graph. + self.assertEqual(global_graph, dummy.graph) + # pylint: enable=protected-access + def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): with context.graph_mode(): # pylint: disable=protected-access @@ -2262,12 +2307,13 @@ class InitScopeTest(test_util.TensorFlowTestCase): with context.eager_mode(): def foo(): with ops.name_scope("inner"), ops.init_scope(): - if context.in_graph_mode(): - self.assertEqual(ops.get_name_scope(), "inner") - else: + if context.executing_eagerly(): # A trailing slash is always appended when eager execution is # enabled. self.assertEqual(context.context().scope_name, "inner/") + else: + self.assertEqual(ops.get_name_scope(), "inner") + foo() self.assertEqual(ops.get_name_scope(), "") foo_compiled = eager_function.defun(foo) @@ -2877,7 +2923,7 @@ class OutputTypesTest(test_util.TensorFlowTestCase): with g.as_default(): x = constant_op.constant([1, 1, 2, 4, 4, 4, 7, 8, 8], dtype=dtypes.double) - y, _ = gen_array_ops._unique(x) + y, _ = gen_array_ops.unique(x) self.assertEqual([types_pb2.DT_DOUBLE, types_pb2.DT_INT32], y.op._output_types) # pylint: disable=protected-access @@ -2902,6 +2948,9 @@ class EnableEagerExecutionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): c = config_pb2.ConfigProto() ops.enable_eager_execution(c, c) + with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"): + c = config_pb2.ConfigProto() + ops.enable_eager_execution(c, execution_mode=c) if __name__ == "__main__": diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index c95149d177990e364c3d6b9daeae5dc535cf0070..9850f0becc69ff1f53b70f0ad2296aead8b5152c 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -75,6 +75,33 @@ bool IsPythonReserved(const string& s) { return kPythonReserved->count(s) > 0; } +bool IsOpWithUnderscorePrefix(const string& s) { + static const std::set* const kUnderscoreOps = new std::set( + {// Lowercase built-in functions and types in Python, from: + // [x for x in dir(__builtins__) if x[0].islower()] except "round". + // These need to be excluded so they don't conflict with actual built-in + // functions since we use '*' imports. + "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray", + "bytes", "callable", "chr", "classmethod", "cmp", "coerce", "compile", + "complex", "copyright", "credits", "delattr", "dict", "dir", "divmod", + "enumerate", "eval", "execfile", "exit", "file", "filter", "float", + "format", "frozenset", "getattr", "globals", "hasattr", "hash", "help", + "hex", "id", "input", "int", "intern", "isinstance", "issubclass", + "iter", "len", "license", "list", "locals", "long", "map", "max", + "memoryview", "min", "next", "object", "oct", "open", "ord", "pow", + "print", "property", "quit", "range", "raw_input", "reduce", "reload", + "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod", + "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars", + "xrange", "zip", + // These have the same name as ops defined in Python and might be used + // incorrectly depending on order of '*' imports. + // TODO(annarev): reduce usage of '*' imports and remove these from the + // list. + "fused_batch_norm", "histogram_fixed_width", "stack", + "batch_norm_with_global_normalization"}); + return kUnderscoreOps->count(s) > 0; +} + string AvoidPythonReserved(const string& s) { if (IsPythonReserved(s)) return strings::StrCat(s, "_"); return s; @@ -816,6 +843,7 @@ from tensorflow.python.util.tf_export import tf_export // An op is hidden if either its ApiDef visibility is HIDDEN // or it is in the hidden_ops list. bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + bool hidden_by_api_def = is_hidden; if (!is_hidden) { for (const string& hidden : hidden_ops) { if (op_def.name() == hidden) { @@ -828,13 +856,22 @@ from tensorflow.python.util.tf_export import tf_export string function_name; python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), &function_name); - if (is_hidden) function_name = strings::StrCat("_", function_name); - - // When users create custom python wrappers, they may link in the - // default op registry by accident, and because they can't - // enumerate all 'hidden' symbols, this guard is to prevent - // instantiating a python reserved word in their wrapper. - if (python_op_gen_internal::IsPythonReserved(function_name)) { + bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name); + + // Prefix an op with underscore if the op is listed in hidden_ops or + // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix. + // Do not add underscores to ops set to HIDDEN in ApiDef otherwise. + // TODO(annarev): don't prefix with underscores even if op is in hidden_ops. + if (is_hidden) { + if (!hidden_by_api_def || is_reserved || + python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) { + function_name = strings::StrCat("_", function_name); + } + } else if (is_reserved) { + // When users create custom python wrappers, they may link in the + // default op registry by accident, and because they can't + // enumerate all 'hidden' symbols, this guard is to prevent + // instantiating a python reserved word in their wrapper. continue; } diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index 4319e5a7820b33283df8153fdc76e0e567813a17..e0cfb05f4bdf8afd09957c62a9ba3af1fd0882a6 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -29,6 +29,9 @@ namespace python_op_gen_internal { // Returns true if s is a Python keyword or built-in. bool IsPythonReserved(const string& s); +// Whether the op should be prefixed with underscore. +bool IsOpWithUnderscorePrefix(const string& s); + // Add a _ to the end of s if necessary to avoid a Python keyword or built-in. string AvoidPythonReserved(const string& s); diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py index 1e74a790a3fb0c72b7c0fb1127ffac95f386d85e..b724432e00b0d11de86a0fff9ff31758ad36479f 100644 --- a/tensorflow/python/framework/random_seed.py +++ b/tensorflow/python/framework/random_seed.py @@ -52,20 +52,20 @@ def get_seed(op_seed): A tuple of two integers that should be used for the local seed of this operation. """ - is_graph_mode = context.in_graph_mode() + eager = context.executing_eagerly() - if is_graph_mode: - global_seed = ops.get_default_graph().seed - else: + if eager: global_seed = context.global_seed() + else: + global_seed = ops.get_default_graph().seed if global_seed is not None: if op_seed is None: # pylint: disable=protected-access - if is_graph_mode: - op_seed = ops.get_default_graph()._last_id - else: + if eager: op_seed = context.internal_operation_seed() + else: + op_seed = ops.get_default_graph()._last_id seeds = _truncate_seed(global_seed), _truncate_seed(op_seed) else: @@ -176,7 +176,7 @@ def set_random_seed(seed): Args: seed: integer. """ - if context.in_graph_mode(): - ops.get_default_graph().seed = seed - else: + if context.executing_eagerly(): context.set_global_seed(seed) + else: + ops.get_default_graph().seed = seed diff --git a/tensorflow/python/framework/random_seed_test.py b/tensorflow/python/framework/random_seed_test.py index b4c98ab8b289c850c6171425167bb17606a4162d..194492268631abfa911bd45f13a302c09a2c8bda 100644 --- a/tensorflow/python/framework/random_seed_test.py +++ b/tensorflow/python/framework/random_seed_test.py @@ -40,13 +40,13 @@ class RandomSeedTest(test.TestCase): ((2**31 - 1, 0), (0, 2**31 - 1)), # Don't wrap to (0, 0) either ((0, 2**31 - 1), (0, 2**31 - 1)), # Wrapping for the other argument ] - if context.in_graph_mode(): - # 0 will be the default_graph._lastid. - test_cases.append(((1, None), (1, 0))) - else: + if context.executing_eagerly(): # operation seed is random number generated based on global seed. # it's not tested due to possibility of platform or version difference. pass + else: + # 0 will be the default_graph._lastid. + test_cases.append(((1, None), (1, 0))) for tc in test_cases: tinput, toutput = tc[0], tc[1] random_seed.set_random_seed(tinput[0]) diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ff23e4ff809ed7bc57259fa3ec9feb921b5a71 --- /dev/null +++ b/tensorflow/python/framework/smart_cond.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""smart_cond and related utilties.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops + + +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. + + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `true_fn` or `false_fn`. + + Raises: + TypeError: If `true_fn` or `false_fn` is not callable. + """ + if not callable(true_fn): + raise TypeError("`true_fn` must be callable.") + if not callable(false_fn): + raise TypeError("`false_fn` must be callable.") + + pred_value = smart_constant_value(pred) + if pred_value is not None: + if pred_value: + return true_fn() + else: + return false_fn() + else: + return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, + name=name) + + +def smart_constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or tensor. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError: If `pred` is not a Tensor or bool. + """ + if pred in {0, 1}: # Accept 1/0 as valid boolean values + pred_value = bool(pred) + elif isinstance(pred, bool): + pred_value = pred + elif isinstance(pred, ops.Tensor): + pred_value = tensor_util.constant_value(pred) + # TODO(skyewm): consider folding this into tensor_util.constant_value when + # _USE_C_API is removed (there may be performance and correctness bugs, so I + # wanted to limit the change hidden behind _USE_C_API). + # pylint: disable=protected-access + if pred_value is None and ops._USE_C_API: + with errors.raise_exception_on_not_ok_status() as status: + pred_value = c_api.TF_TryEvaluateConstant_wrapper( + pred.graph._c_graph, pred._as_tf_output(), status) + # pylint: enable=protected-access + + else: + raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. " + "Found instead: %s" % pred) + return pred_value + + +def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"): + """Like tf.case, except attempts to statically evaluate predicates. + + If any predicate in `pred_fn_pairs` is a bool or has a constant value, the + associated callable will be called or omitted depending on its value. + Otherwise this functions like tf.case. + + Args: + pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a + callable which returns a list of tensors. + default: Optional callable that returns a list of tensors. + exclusive: True iff at most one predicate is allowed to evaluate to `True`. + name: A name for this operation (optional). + + Returns: + The tensors returned by the first pair whose predicate evaluated to True, or + those returned by `default` if none does. + + Raises: + TypeError: If `pred_fn_pairs` is not a list/dictionary. + TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. + TypeError: If `fns[i]` is not callable for any i, or `default` is not + callable. + """ + return control_flow_ops._case_helper( # pylint: disable=protected-access + smart_cond, pred_fn_pairs, default, exclusive, name, + allow_python_preds=True) diff --git a/tensorflow/python/framework/smart_cond_test.py b/tensorflow/python/framework/smart_cond_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1170a41c99995ae875e58a2d5491e05bc1e40df6 --- /dev/null +++ b/tensorflow/python/framework/smart_cond_test.py @@ -0,0 +1,166 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +def raise_exception(): + raise RuntimeError("did not expect to be called") + + +@test_util.with_c_api +class SmartCondTest(test_util.TensorFlowTestCase): + + def testTrue(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 5)) + self.assertEqual(z.eval(), 32) + + def testFalse(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(4) + y = constant_op.constant(3) + z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 3)) + self.assertEqual(z.eval(), 9) + + def testUnknown(self): + with ops.Graph().as_default(): + with session.Session(): + x = array_ops.placeholder(dtype=dtypes.int32) + y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + self.assertEqual(y.eval(feed_dict={x: 1}), 1) + self.assertEqual(y.eval(feed_dict={x: -1}), 2) + + def testEval(self): + # Constant expression evaluation only works with the C API enabled. + if not ops._USE_C_API: return + + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + y = constant_op.constant(2) + # x * y > 0 can be evaluated at graph construction time, so the false + # branch shouldn't be evaluated at all. + z = smart_cond.smart_cond(x * y > 0, lambda: constant_op.constant(1), + raise_exception) + self.assertEqual(z.eval(feed_dict={x: 1}), 1) + + def testPlaceholderWithDefault(self): + with ops.Graph().as_default(): + with session.Session(): + x = array_ops.placeholder_with_default(1, shape=()) + y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + self.assertEqual(y.eval(), 1) + self.assertEqual(y.eval(feed_dict={x: -1}), 2) + + def testMissingArg1(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + smart_cond.smart_cond(True, false_fn=lambda: x) + + def testMissingArg2(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + smart_cond.smart_cond(True, lambda: x) + + +@test_util.with_c_api +class SmartCaseTest(test_util.TensorFlowTestCase): + + def testTrue(self): + x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + conditions = [(True, lambda: constant_op.constant(1)), + (x == 0, raise_exception)] + y = smart_cond.smart_case(conditions, default=raise_exception, + exclusive=False) + z = smart_cond.smart_case(conditions, default=raise_exception, + exclusive=True) + with session.Session() as sess: + # No feed_dict necessary + self.assertEqual(sess.run(y), 1) + self.assertEqual(sess.run(z), 1) + + def testFalse(self): + conditions = [(False, raise_exception)] + y = smart_cond.smart_case(conditions, + default=lambda: constant_op.constant(1), + exclusive=False) + z = smart_cond.smart_case(conditions, + default=lambda: constant_op.constant(1), + exclusive=True) + with session.Session() as sess: + self.assertEqual(sess.run(y), 1) + self.assertEqual(sess.run(z), 1) + + def testMix(self): + # Constant expression evaluation only works with the C API enabled. + if not ops._USE_C_API: return + + x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + y = constant_op.constant(10) + conditions = [(x > 1, lambda: constant_op.constant(1)), + (y < 1, raise_exception), + (False, raise_exception), + (True, lambda: constant_op.constant(3))] + z = smart_cond.smart_case(conditions, default=raise_exception) + with session.Session() as sess: + self.assertEqual(sess.run(z, feed_dict={x: 2}), 1) + self.assertEqual(sess.run(z, feed_dict={x: 0}), 3) + + +@test_util.with_c_api +class SmartConstantValueTest(test_util.TensorFlowTestCase): + + # TODO(skyewm): this is essentially a regression test for + # TF_TryEvaluateConstant, and is not really a valid smart_constant_value test + # (smart_constant_value is only supposed to return bools). Move the + # TF_TryEvaluateConstant call to tensor_util.constant_value and make this a + # constant_value test instead. + def testCond(self): + with ops.Graph().as_default(): + pred = array_ops.placeholder_with_default(True, shape=()) + x = control_flow_ops.cond(pred, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + self.assertIsNone(smart_cond.smart_constant_value(x)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 222071cb9e87aa0fdd9788d1c72df4c66ea61547..af2a5b1a7ef9a70c0baf5d02257951803a7a76fa 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -156,7 +156,7 @@ class Dimension(object): ``` Args: - other: Another Dimension. + other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the sum of `self` and `other`. @@ -167,6 +167,17 @@ class Dimension(object): else: return Dimension(self._value + other.value) + def __radd__(self, other): + """Returns the sum of `other` and `self`. + + Args: + other: Another Dimension, or a value accepted by `as_dimension`. + + Returns: + A Dimension whose value is the sum of `self` and `other`. + """ + return self + other + def __sub__(self, other): """Returns the subtraction of `other` from `self`. @@ -180,10 +191,10 @@ class Dimension(object): ``` Args: - other: Another Dimension. + other: Another Dimension, or a value accepted by `as_dimension`. Returns: - A Dimension whose value is the subtraction of sum of `other` from `self`. + A Dimension whose value is the subtraction of `other` from `self`. """ other = as_dimension(other) if self._value is None or other.value is None: @@ -191,6 +202,21 @@ class Dimension(object): else: return Dimension(self._value - other.value) + def __rsub__(self, other): + """Returns the subtraction of `self` from `other`. + + Args: + other: Another Dimension, or a value accepted by `as_dimension`. + + Returns: + A Dimension whose value is the subtraction of `self` from `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(other.value - self._value) + def __mul__(self, other): """Returns the product of `self` and `other`. @@ -204,17 +230,32 @@ class Dimension(object): ``` Args: - other: Another Dimension. + other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is the product of `self` and `other`. """ - other = as_dimension(other) + try: + other = as_dimension(other) + except (TypeError, ValueError): + return NotImplemented + if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value * other.value) + def __rmul__(self, other): + """Returns the product of `self` and `other`. + + Args: + other: Another Dimension, or a value accepted by `as_dimension`. + + Returns: + A Dimension whose value is the product of `self` and `other`. + """ + return self * other + def __floordiv__(self, other): """Returns the quotient of `self` and `other` rounded down. @@ -228,17 +269,35 @@ class Dimension(object): ``` Args: - other: Another `Dimension`. + other: Another Dimension, or a value accepted by `as_dimension`. Returns: A `Dimension` whose value is the integer quotient of `self` and `other`. """ - other = as_dimension(other) + try: + other = as_dimension(other) + except (TypeError, ValueError): + return NotImplemented if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value // other.value) + def __rfloordiv__(self, other): + """Returns the quotient of `other` and `self` rounded down. + + Args: + other: Another Dimension, or a value accepted by `as_dimension`. + + Returns: + A `Dimension` whose value is the integer quotient of `self` and `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(other.value // self._value) + def __div__(self, other): """DEPRECATED: Use `__floordiv__` via `x // y` instead. @@ -256,7 +315,7 @@ class Dimension(object): return self // other def __mod__(self, other): - """Returns `self` modulo `other. + """Returns `self` modulo `other`. Dimension moduli are computed as follows: @@ -268,17 +327,35 @@ class Dimension(object): ``` Args: - other: Another Dimension. + other: Another Dimension, or a value accepted by `as_dimension`. Returns: A Dimension whose value is `self` modulo `other`. """ - other = as_dimension(other) + try: + other = as_dimension(other) + except (TypeError, ValueError): + return NotImplemented if self._value is None or other.value is None: return Dimension(None) else: return Dimension(self._value % other.value) + def __rmod__(self, other): + """Returns `other` modulo `self`. + + Args: + other: Another Dimension, or a value accepted by `as_dimension`. + + Returns: + A Dimension whose value is `other` modulo `self`. + """ + try: + other = as_dimension(other) + except (TypeError, ValueError): + return NotImplemented + return other % self + def __lt__(self, other): """Returns True if `self` is known to be less than `other`. @@ -456,6 +533,7 @@ class TensorShape(object): else: # Got a list of dimensions self._dims = [as_dimension(d) for d in dims_iter] + self._ndims = None def __repr__(self): return "TensorShape(%r)" % self._dims @@ -473,19 +551,26 @@ class TensorShape(object): """Returns a list of Dimensions, or None if the shape is unspecified.""" return self._dims + @dims.setter + def dims(self, dims): + self._dims = dims + self._ndims = None + @property def ndims(self): """Returns the rank of this shape, or None if it is unspecified.""" if self._dims is None: return None else: - return len(self._dims) + if self._ndims is None: + self._ndims = len(self._dims) + return self._ndims def __len__(self): """Returns the rank of this shape, or raises ValueError if unspecified.""" if self._dims is None: raise ValueError("Cannot take the length of Shape with unknown rank.") - return len(self._dims) + return self.ndims def __bool__(self): """Returns True if this shape contains non-zero information.""" diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py index fffd86c7a6241b8be92ad33852da244ab9b5284d..4e8ce4d889c4ef0c6e56806587a64e8f9be7e10a 100644 --- a/tensorflow/python/framework/tensor_shape_test.py +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -34,12 +34,20 @@ class DimensionTest(test_util.TensorFlowTestCase): self.assertEqual(tensor_shape.Dimension(15), dim + tensor_shape.Dimension(3)) self.assertEqual(tensor_shape.Dimension(15), dim + 3) + self.assertEqual(tensor_shape.Dimension(15), 3 + dim) + self.assertEqual(tensor_shape.Dimension(9), dim - 3) + self.assertEqual(tensor_shape.Dimension(1), 13 - dim) self.assertEqual(tensor_shape.Dimension(24), dim * tensor_shape.Dimension(2)) self.assertEqual(tensor_shape.Dimension(24), dim * 2) + self.assertEqual(tensor_shape.Dimension(24), 2 * dim) + self.assertEqual([4] * 12, [4] * dim) + self.assertEqual(12 * [4], dim * [4]) + self.assertEqual(tensor_shape.Dimension(24), 2 * dim) self.assertEqual( tensor_shape.Dimension(6), dim // tensor_shape.Dimension(2)) self.assertEqual(tensor_shape.Dimension(6), dim // 2) + self.assertEqual(tensor_shape.Dimension(0), 2 // dim) self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(tensor_shape.Dimension(12))) self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12)) @@ -176,6 +184,14 @@ class DimensionTest(test_util.TensorFlowTestCase): self.assertEqual(str(tensor_shape.Dimension(7)), "7") self.assertEqual(str(tensor_shape.Dimension(None)), "?") + def testMod(self): + four = tensor_shape.Dimension(4) + nine = tensor_shape.Dimension(9) + self.assertEqual(nine % four, 1) + # test both __mod__ and __rmod__. + self.assertEqual(nine % 4, 1) + self.assertEqual(4 % nine, 4) + class ShapeTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py index a0411bc3d9b4b2b87e5a31e9f201154f28ccf1cc..6676cfcaa334e02208d9ec346de7d266c4700f24 100644 --- a/tensorflow/python/framework/tensor_spec.py +++ b/tensorflow/python/framework/tensor_spec.py @@ -65,6 +65,11 @@ class TensorSpec(object): else: raise ValueError("`tensor` should be a tf.Tensor") + @classmethod + def is_bounded(cls): + del cls + return False + @property def shape(self): """Returns the `TensorShape` that represents the shape of the tensor.""" @@ -80,6 +85,16 @@ class TensorSpec(object): """Returns the name of the described tensor.""" return self._name + @property + def is_discrete(self): + """Whether spec is discrete.""" + return self.dtype.is_integer + + @property + def is_continuous(self): + """Whether spec is continuous.""" + return self.dtype.is_floating + def is_compatible_with(self, spec_or_tensor): """True if the shape and dtype of `spec_or_tensor` are compatible.""" return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and @@ -95,6 +110,9 @@ class TensorSpec(object): def __ne__(self, other): return not self == other + def __reduce__(self): + return TensorSpec, (self._shape, self._dtype, self._name) + class BoundedTensorSpec(TensorSpec): """A `TensorSpec` that specifies minimum and maximum values. @@ -163,19 +181,16 @@ class BoundedTensorSpec(TensorSpec): self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype()) self._maximum.setflags(write=False) + @classmethod + def is_bounded(cls): + del cls + return True + @classmethod def from_spec(cls, spec): dtype = dtypes.as_dtype(spec.dtype) - if dtype in [dtypes.float64, dtypes.float32]: - # Avoid under/over-flow for `dtype.maximum - dtype.minimum`. - low = dtype.min / 2 - high = dtype.max / 2 - else: - low = dtype.min - high = dtype.max - - minimum = getattr(spec, "minimum", low) - maximum = getattr(spec, "maximum", high) + minimum = getattr(spec, "minimum", dtype.min) + maximum = getattr(spec, "maximum", dtype.max) return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name) @property @@ -198,4 +213,7 @@ class BoundedTensorSpec(TensorSpec): return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and np.allclose(self.maximum, other.maximum)) + def __reduce__(self): + return BoundedTensorSpec, (self._shape, self._dtype, self._minimum, + self._maximum, self._name) diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py index 54ca4d9a19c2e1c879c05cfb828085951bdd8444..2e9e43e12279fe833d640d4163c5474c398e70cd 100644 --- a/tensorflow/python/framework/tensor_spec_test.py +++ b/tensorflow/python/framework/tensor_spec_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import pickle + import numpy as np from tensorflow.python.framework import constant_op @@ -127,6 +129,26 @@ class TensorSpecTest(test_util.TensorFlowTestCase): self.assertEqual(bounded_spec.dtype, spec.dtype) self.assertEqual(bounded_spec.name, spec.name) + def testIsDiscrete(self): + discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32) + self.assertTrue(discrete_spec.is_discrete) + self.assertFalse(continuous_spec.is_discrete) + + def testIsContinuous(self): + discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32) + self.assertFalse(discrete_spec.is_continuous) + self.assertTrue(continuous_spec.is_continuous) + + def testIsBounded(self): + unbounded_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + self.assertFalse(unbounded_spec.is_bounded()) + + def testSerialization(self): + desc = tensor_spec.TensorSpec([1, 5], dtypes.float32, "test") + self.assertEqual(pickle.loads(pickle.dumps(desc)), desc) + class BoundedTensorSpecTest(test_util.TensorFlowTestCase): @@ -138,6 +160,11 @@ class BoundedTensorSpecTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "not compatible"): tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1)) + def testIsBounded(self): + bounded_spec = tensor_spec.BoundedTensorSpec( + (1, 2), dtypes.int32, minimum=0, maximum=1) + self.assertTrue(bounded_spec.is_bounded()) + def testMinimumMaximumAttributes(self): spec = tensor_spec.BoundedTensorSpec( (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) @@ -222,6 +249,10 @@ class BoundedTensorSpecTest(test_util.TensorFlowTestCase): self.assertEqual(spec.dtype.max, bounded_spec.maximum) self.assertEqual(spec.name, bounded_spec.name) + def testSerialization(self): + desc = tensor_spec.BoundedTensorSpec([1, 5], dtypes.float32, -1, 1, "test") + self.assertEqual(pickle.loads(pickle.dumps(desc)), desc) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 27afaa074a6becd5c8b7db94be59e8da1611c13a..984bcecdfe05efd79bdf218197c410b14abe3516 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -559,16 +559,16 @@ def MakeNdarray(tensor): if tensor.tensor_content: return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy() .reshape(shape)) - elif tensor_dtype == dtypes.float16: + elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16: # the half_val field of the TensorProto stores the binary representation # of the fp16: we need to reinterpret this as a proper float16 if len(tensor.half_val) == 1: tmp = np.array(tensor.half_val[0], dtype=np.uint16) - tmp.dtype = np.float16 + tmp.dtype = tensor_dtype.as_numpy_dtype return np.repeat(tmp, num_elements).reshape(shape) else: tmp = np.fromiter(tensor.half_val, dtype=np.uint16) - tmp.dtype = np.float16 + tmp.dtype = tensor_dtype.as_numpy_dtype return tmp.reshape(shape) elif tensor_dtype == dtypes.float32: if len(tensor.float_val) == 1: @@ -586,8 +586,7 @@ def MakeNdarray(tensor): return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape) elif tensor_dtype in [ dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8, - dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16, - dtypes.bfloat16 + dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16 ]: if len(tensor.int_val) == 1: return np.repeat(np.array(tensor.int_val[0], dtype=dtype), @@ -829,7 +828,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name Returns: A `TensorShape` based on the constant value of the given `tensor`. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return tensor_shape.as_shape( [dim if dim != -1 else None for dim in tensor.numpy()]) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index bea0ee34fd4900cc9d4d5d52348ba4512368e81f..35fff80c61b98e7603d3b7b5df3cabdb59059a72 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -235,6 +235,26 @@ class TensorUtilTest(test.TestCase): self.assertEquals(np.float16, a.dtype) self.assertAllClose(np.array([10.0, 20.0], dtype=np.float16), a) + def testBfloat16(self): + test_type = dtypes.bfloat16.as_numpy_dtype + t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type)) + # 10.0: 16672 = 010000010(130) 0100000: (1+0/2+1/4) * 2^(130-127) + # 20.0: 16800 = 010000011(131) 0100000: (1+0/2+1/4) * 2^(131-127) + self.assertProtoEquals(""" + dtype: DT_BFLOAT16 + tensor_shape { + dim { + size: 2 + } + } + half_val: 16672 + half_val: 16800 + """, t) + + a = tensor_util.MakeNdarray(t) + self.assertEquals(test_type, a.dtype) + self.assertAllClose(np.array([10.0, 20.0], dtype=test_type), a) + def testInt(self): t = tensor_util.make_tensor_proto(10) self.assertProtoEquals(""" @@ -768,7 +788,7 @@ class ConstantValueTest(test.TestCase): self.assertAllClose(np_val, tensor_util.constant_value(tf_val)) def testUnknown(self): - tf_val = gen_state_ops._variable( + tf_val = gen_state_ops.variable( shape=[3, 4, 7], dtype=dtypes.float32, name="tf_val", diff --git a/tensorflow/python/framework/test_file_system.cc b/tensorflow/python/framework/test_file_system.cc index 094ea6f658ab800736eebce2db7ee80da151a033..6e9915adbb619c5c4891742ddda700da47ed590f 100644 --- a/tensorflow/python/framework/test_file_system.cc +++ b/tensorflow/python/framework/test_file_system.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/null_file_system.h" namespace tensorflow { diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4fd70039817ad7e0e7077077cb1037625330121c..43106b6e598d464b15d0fe00265ccec906fff9a7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -53,9 +53,11 @@ from tensorflow.python.eager import tape # pylint: disable=unused-import from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -205,6 +207,10 @@ def CudaSupportsHalfMatMulAndConv(): return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv() +def IsMklEnabled(): + return pywrap_tensorflow.IsMklEnabled() + + def InstallStackTraceHandler(): pywrap_tensorflow.InstallStacktraceHandler() @@ -331,6 +337,8 @@ def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs): # Make sure default graph reflects prev_value in case next test doesn't call # reset_default_graph(). ops.reset_default_graph() + + # pylint: disable=protected-access @@ -403,6 +411,31 @@ def enable_c_api(fn): return wrapper +def enable_c_shapes(fn): + """Decorator for enabling C shapes on a test. + + Note this enables the C shapes after running the test class's setup/teardown + methods. + + Args: + fn: the function to be wrapped + + Returns: + The wrapped function + """ + + def wrapper(*args, **kwargs): + prev_value = ops._USE_C_SHAPES + # Only use C shapes if the C API is already enabled. + ops._USE_C_SHAPES = ops._USE_C_API + try: + fn(*args, **kwargs) + finally: + ops._USE_C_SHAPES = prev_value + + return wrapper + + # This decorator is a hacky way to run all the test methods in a decorated # class with and without C API enabled. # TODO(iga): Remove this and its uses once we switch to using C API by default. @@ -422,7 +455,8 @@ def with_c_api(cls): # If the C API is already enabled, don't do anything. Some tests break if the # same test is run twice, so this allows us to turn on the C API by default # without breaking these tests. - if ops._USE_C_API: return cls + if ops._USE_C_API: + return cls for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): @@ -430,6 +464,35 @@ def with_c_api(cls): return cls +def assert_no_new_pyobjects_executing_eagerly(f): + """Decorator for asserting that no new Python objects persist after a test. + + Runs the test multiple times executing eagerly, first as a warmup and then + several times to let objects accumulate. The warmup helps ignore caches which + do not grow as the test is run repeatedly. + + Useful for checking that there are no missing Py_DECREFs in the C exercised by + a bit of Python. + """ + + def decorator(self, **kwargs): + """Warms up, gets an object count, runs the test, checks for new objects.""" + with context.eager_mode(): + gc.disable() + f(self, **kwargs) + gc.collect() + previous_count = len(gc.get_objects()) + for _ in range(3): + f(self, **kwargs) + gc.collect() + # There should be no new Python objects hanging around. + new_count = len(gc.get_objects()) + self.assertEqual(previous_count, new_count) + gc.enable() + + return decorator + + def assert_no_new_tensors(f): """Decorator for asserting that no new Tensors persist after a test. @@ -451,15 +514,17 @@ def assert_no_new_tensors(f): def decorator(self, **kwargs): """Finds existing Tensors, runs the test, checks for new Tensors.""" - def _is_tensor(obj): + def _is_tensorflow_object(obj): try: - return (isinstance(obj, ops.Tensor) or - isinstance(obj, variables.Variable)) + return isinstance(obj, + (ops.Tensor, variables.Variable, + tensor_shape.Dimension, tensor_shape.TensorShape)) except ReferenceError: # If the object no longer exists, we don't care about it. return False - tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj)) + tensors_before = set( + id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) outside_graph_key = ops.get_default_graph()._graph_key with ops.Graph().as_default(): # Run the test in a new graph so that collections get cleared when it's @@ -469,11 +534,12 @@ def assert_no_new_tensors(f): # Make an effort to clear caches, which would otherwise look like leaked # Tensors. backprop._zeros_cache.flush() + context.get_default_context().ones_rank_cache().flush() context.get_default_context().scalar_cache().clear() gc.collect() tensors_after = [ obj for obj in gc.get_objects() - if _is_tensor(obj) and id(obj) not in tensors_before + if _is_tensorflow_object(obj) and id(obj) not in tensors_before ] if tensors_after: raise AssertionError(("%d Tensors not deallocated after test: %s" % ( @@ -506,6 +572,30 @@ def assert_no_garbage_created(f): previous_garbage = len(gc.garbage) f(self, **kwargs) gc.collect() + if len(gc.garbage) > previous_garbage: + logging.error( + "The decorated test created work for Python's garbage collector, " + "likely due to a reference cycle. New objects in cycle(s):") + for i, obj in enumerate(gc.garbage[previous_garbage:]): + try: + logging.error("Object %d of %d", i, + len(gc.garbage) - previous_garbage) + + def _safe_object_str(obj): + return "<%s %d>" % (obj.__class__.__name__, id(obj)) + + logging.error(" Object type: %s", _safe_object_str(obj)) + logging.error(" Referrer types: %s", ", ".join( + [_safe_object_str(ref) for ref in gc.get_referrers(obj)])) + logging.error(" Referent types: %s", ", ".join( + [_safe_object_str(ref) for ref in gc.get_referents(obj)])) + logging.error(" Object attribute names: %s", dir(obj)) + logging.error(" Object __str__:") + logging.error(obj) + logging.error(" Object __repr__:") + logging.error(repr(obj)) + except Exception: + logging.error("(Exception while printing object)") # This will fail if any garbage has been created, typically because of a # reference cycle. self.assertEqual(previous_garbage, len(gc.garbage)) @@ -564,6 +654,7 @@ def run_in_graph_and_eager_modes(__unused__=None, # This decorator runs the wrapped test twice. # Reset the test environment between runs. self.tearDown() + self._tempdir = None self.setUp() def run_eager_mode(self, **kwargs): @@ -620,15 +711,23 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): return 0, 0 return int(match.group(1)), int(match.group(2)) - for local_device in device_lib.list_local_devices(): - if local_device.device_type == "GPU": - if (min_cuda_compute_capability is None or - compute_capability_from_device_desc(local_device.physical_device_desc) - >= min_cuda_compute_capability): + try: + for local_device in device_lib.list_local_devices(): + if local_device.device_type == "GPU": + if (min_cuda_compute_capability is None or + compute_capability_from_device_desc( + local_device.physical_device_desc) >= + min_cuda_compute_capability): + return True + if local_device.device_type == "SYCL" and not cuda_only: return True - if local_device.device_type == "SYCL" and not cuda_only: - return True - return False + return False + except errors_impl.NotFoundError as e: + if not all([x in str(e) for x in ["CUDA", "not find"]]): + raise e + else: + logging.error(str(e)) + return False @contextlib.contextmanager @@ -787,7 +886,7 @@ class TensorFlowTestCase(googletest.TestCase): Returns: tensors numpy values. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return self._eval_helper(tensors) else: sess = ops.get_default_session() @@ -817,9 +916,9 @@ class TensorFlowTestCase(googletest.TestCase): Use the `use_gpu` and `force_gpu` options to control where ops are run. If `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if - `use_gpu` - is True, TensorFlow tries to run as many ops on the GPU as possible. If both - `force_gpu and `use_gpu` are False, all ops are pinned to the CPU. + `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as + possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to + the CPU. Example: ```python @@ -1171,9 +1270,9 @@ class TensorFlowTestCase(googletest.TestCase): msg="Mismatched value: a%s is different from b%s." % (path_str, path_str)) except TypeError as e: - msg = "Error: a%s has %s, but b%s has %s" % ( - path_str, type(a), path_str, type(b)) - e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:]) + msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a), + path_str, type(b)) + e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): @@ -1353,8 +1452,7 @@ class TensorFlowTestCase(googletest.TestCase): """ device1 = pydev.canonical_name(device1) device2 = pydev.canonical_name(device2) - self.assertEqual(device1, device2, - "Devices %s and %s are not equal. %s" % + self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" % (device1, device2, msg)) # Fix Python 3 compatibility issues diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index a717eb39513ac3369ae133b6090ff82597f12eb7..02ffa93baee5c643ebdceaa274710f9d58e6eecb 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -82,6 +82,14 @@ class TestUtilTest(test_util.TensorFlowTestCase): else: print("GoogleCuda is disabled") + def testIsMklEnabled(self): + # This test doesn't assert anything. + # It ensures the py wrapper function is generated correctly. + if test_util.IsMklEnabled(): + print("MKL is enabled") + else: + print("MKL is disabled") + def testAssertProtoEqualsStr(self): graph_str = "node { name: 'w1' op: 'params' }" @@ -440,6 +448,26 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): LeakedTensorTest().test_has_no_leak() + def test_no_new_objects_decorator(self): + + class LeakedObjectTest(object): + + def __init__(inner_self): # pylint: disable=no-self-argument + inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name + inner_self.accumulation = [] + + @test_util.assert_no_new_pyobjects_executing_eagerly + def test_has_leak(self): + self.accumulation.append([1.]) + + @test_util.assert_no_new_pyobjects_executing_eagerly + def test_has_no_leak(self): + self.not_accumulating = [1.] + + with self.assertRaises(AssertionError): + LeakedObjectTest().test_has_leak() + + LeakedObjectTest().test_has_no_leak() if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py index 655e43e78f3e1e02fc5dc1982487f7291ab1e23d..c0866c1069ac7f7e25cbd12cb5a490e2ed5e4bec 100644 --- a/tensorflow/python/grappler/hierarchical_controller.py +++ b/tensorflow/python/grappler/hierarchical_controller.py @@ -258,9 +258,11 @@ class HierarchicalController(Controller): "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) seq2seq_input_layer = array_ops.placeholder_with_default( - array_ops.zeros([1, self.num_groups, self.group_emb_size], + array_ops.zeros([self.hparams.num_children, + self.num_groups, + self.group_emb_size], dtypes.float32), - shape=(1, self.num_groups, self.group_emb_size)) + shape=(self.hparams.num_children, self.num_groups, self.group_emb_size)) self.seq2seq_input_layer = seq2seq_input_layer def compute_reward(self, run_time): @@ -585,12 +587,29 @@ class HierarchicalController(Controller): """Approximating the blocks of a TF graph from a graph_def. Args: - grouping_actions: grouping predictions + grouping_actions: grouping predictions. verbose: print stuffs. Returns: groups: list of groups. """ + groups = [ + self._create_group_embeddings(grouping_actions, i, verbose) for + i in range(self.hparams.num_children) + ] + return np.stack(groups, axis=0) + + def _create_group_embeddings(self, grouping_actions, child_id, verbose=False): + """Approximating the blocks of a TF graph from a graph_def for each child. + + Args: + grouping_actions: grouping predictions. + child_id: child_id for the group. + verbose: print stuffs. + + Returns: + groups: group embedding for the child_id. + """ if verbose: print("Processing input_graph") @@ -599,23 +618,23 @@ class HierarchicalController(Controller): dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32) for op in self.important_ops: topo_op_index = self.name_to_topo_order_index[op.name] - # TODO(agoldie) child_id - group_index = grouping_actions[0][topo_op_index] + group_index = grouping_actions[child_id][topo_op_index] for output_op in self.get_node_fanout(op): if output_op.name not in self.important_op_names: continue - output_group_index = grouping_actions[0][self.name_to_topo_order_index[ - output_op.name]] + output_group_index = ( + grouping_actions[child_id][self.name_to_topo_order_index[ + output_op.name]]) dag_matrix[group_index, output_group_index] += 1.0 num_connections = np.sum(dag_matrix) num_intra_group_connections = dag_matrix.trace() num_inter_group_connections = num_connections - num_intra_group_connections if verbose: print("grouping evaluation metric") - print("num_connections={} num_intra_group_connections={} " - "num_inter_group_connections={}").format( - num_connections, num_intra_group_connections, - num_inter_group_connections) + print(("num_connections={} num_intra_group_connections={} " + "num_inter_group_connections={}").format( + num_connections, num_intra_group_connections, + num_inter_group_connections)) self.dag_matrix = dag_matrix # output_shape @@ -648,7 +667,8 @@ class HierarchicalController(Controller): ], dtype=np.float32) for op_index, op in enumerate(self.important_ops): - group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]] + group_index = grouping_actions[child_id][ + self.name_to_topo_order_index[op.name]] type_name = str(op.op) type_index = self.type_dict[type_name] group_embedding[group_index, type_index] += 1 @@ -675,7 +695,7 @@ class HierarchicalController(Controller): shape=[num_children, self.num_groups], trainable=False) - x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1]) + x = self.seq2seq_input_layer last_c, last_h, attn_mem = self.encode(x) actions, log_probs = {}, {} actions["sample"], log_probs["sample"] = ( @@ -972,8 +992,8 @@ class HierarchicalController(Controller): controller_ops["reward"]["ph"][child_id]: reward, }) if verbose: - print("run_time={:<.5f} reward={:<.5f} " - "best_reward={:<.5f}").format(run_time, reward, best_reward) + print(("run_time={:<.5f} reward={:<.5f} " + "best_reward={:<.5f}").format(run_time, reward, best_reward)) # Reward is a double, best_reward a float: allow for some slack in the # comparison. @@ -988,8 +1008,7 @@ class HierarchicalController(Controller): def generate_placement(self, grouping, sess): controller_ops = self.ops["controller"] feed_seq2seq_input_dict = {} - feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims( - grouping, axis=0) + feed_seq2seq_input_dict[self.seq2seq_input_layer] = grouping sess.run( controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict) diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index 9a84c60b04029a64ed35a01f045a6eec5e492504..593d38206d127978f1982a0f2cc22e17daee1a3d 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -83,7 +83,6 @@ static GItem TF_NewItem( tensorflow::grappler::ItemConfig cfg; cfg.ignore_user_placement = ignore_user_placement; cfg.ignore_colocation = ignore_colocation; - cfg.inline_functions = true; std::unique_ptr item = tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg); if (!item) { diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py index 7c3efd6249cbdaa2675632f7fc8e25fb88658a24..c40de9da0abca3bb99a82a1456261f45b1c45c99 100644 --- a/tensorflow/python/grappler/item_test.py +++ b/tensorflow/python/grappler/item_test.py @@ -111,7 +111,7 @@ class ItemTest(test.TestCase): with ops.Graph().as_default() as g: c = constant_op.constant([10]) v = variables.Variable([3], dtype=dtypes.int32) - i = gen_array_ops._ref_identity(v) + i = gen_array_ops.ref_identity(v) a = state_ops.assign(i, c) train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) train_op.append(a) diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 0f5150174049250e86bbac0a49eb998339058326..5a84b16a23f567fba6d08aaefd3b816a76907735 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -321,7 +321,7 @@ class LayoutOptimizerTest(test.TestCase): conv = _two_layer_model(x) dim = array_ops.placeholder(dtype='int32') sizes = constant_op.constant([50, 10, 4], shape=[3]) - split = gen_array_ops._split_v( + split = gen_array_ops.split_v( value=conv, size_splits=sizes, axis=dim, num_split=3) output = math_ops.reduce_sum(split[0]) @@ -896,7 +896,7 @@ class LayoutOptimizerTest(test.TestCase): add = math_ops.add(conv, conv) mean = math_ops.reduce_mean(conv) condition = math_ops.less(conv, mean) - select = gen_math_ops._select(condition, conv, add) + select = gen_math_ops.select(condition, conv, add) output = array_ops.identity(select) with session.Session(config=_get_config(False)) as sess: @@ -926,7 +926,7 @@ class LayoutOptimizerTest(test.TestCase): conv = _two_layer_model(x) add = math_ops.add(conv, conv) condition = array_ops.placeholder(dtype='bool') - select = gen_math_ops._select(condition, conv, add) + select = gen_math_ops.select(condition, conv, add) output = array_ops.identity(select) condition_val = np.zeros((1, 7, 7, 64)) @@ -957,7 +957,7 @@ class LayoutOptimizerTest(test.TestCase): conv = _two_layer_model(x) add = math_ops.add(conv, conv) condition = constant_op.constant(True) - select = gen_math_ops._select(condition, conv, add) + select = gen_math_ops.select(condition, conv, add) output = array_ops.identity(select) with session.Session(config=_get_config(False)) as sess: @@ -1023,7 +1023,7 @@ class LayoutOptimizerTest(test.TestCase): conv = _two_layer_model(x) ksize = constant_op.constant([1, 2, 3, 1], shape=[4]) strides = array_ops.placeholder(dtype='int32', shape=[4]) - max_pool = gen_nn_ops._max_pool_v2(conv, ksize, strides, 'VALID') + max_pool = gen_nn_ops.max_pool_v2(conv, ksize, strides, 'VALID') output = array_ops.identity(max_pool) strides_val = [1, 3, 2, 1] diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 948911f099674af4c6dd19bfdac75e5fc1f75c78..4df959ce04169395589aeebaef9e3e7839e2300c 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -162,7 +162,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase): arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS, - memory_optimizer_target_node_name_prefix='optimizer/gradients/'), + # Checks that name scope "gradients/" also match sub-scope. + memory_optimizer_target_node_name_scope='gradients/'), original_metagraph) self.assertGreater( len(rewritten_graph_def.node), @@ -176,6 +177,35 @@ class MemoryOptimizerRecomputeTest(test.TestCase): len([node for node in rewritten_graph_def.node if 'Recomputed/' in node.name])) + def testRewritingNameScopedGradientNamesScope(self): + """Tests that rewriting occurs with non-standard gradient names.""" + (original_metagraph, _, _, + _) = self._GetMetaGraph(optimizer_scope_name='foo/bar') + rewritten_graph_def = tf_optimizer.OptimizeGraph( + rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + constant_folding=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + memory_optimization=rewriter_config_pb2.RewriterConfig. + RECOMPUTATION_HEURISTICS, + # This should not match anything. + memory_optimizer_target_node_name_scope='r/gradients/'), + original_metagraph) + self.assertEqual( + len(rewritten_graph_def.node), len(original_metagraph.graph_def.node)) + self.assertEqual(0, + len([ + node for node in original_metagraph.graph_def.node + if 'Recomputed/' in node.name + ])) + self.assertEqual(0, + len([ + node for node in rewritten_graph_def.node + if 'Recomputed/' in node.name + ])) + def _GetMemoryOptimizerSessionConfig(self): rewrite_options = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index d23eb811ac2b0a6a8802979b4d966b5617c8a8d9..5a76cdd8fb29361cd800dea60cb9ebc0e39f6487 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -26,9 +26,10 @@ namespace grappler { ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {} -Status ModelAnalyzer::GenerateReport(bool debug, std::ostream& os) { +Status ModelAnalyzer::GenerateReport(bool debug, bool assume_valid_feeds, + std::ostream& os) { GraphProperties properties(item_); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + TF_RETURN_IF_ERROR(properties.InferStatically(assume_valid_feeds)); for (const auto& node : item_.MainOpsFanin()) { PrintNodeInfo(node, properties, debug, os); diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h index 5bc551927d88db723e21b29903d6f5b941048139..97ffafabe1f785e3b2c3044143b8fb8006b59225 100644 --- a/tensorflow/python/grappler/model_analyzer.h +++ b/tensorflow/python/grappler/model_analyzer.h @@ -31,7 +31,7 @@ class GraphProperties; class ModelAnalyzer { public: explicit ModelAnalyzer(const GrapplerItem& item); - Status GenerateReport(bool debug, std::ostream& os); + Status GenerateReport(bool debug, bool assume_valid_feeds, std::ostream& os); private: void PrintNodeInfo(const NodeDef* node, const GraphProperties& properties, diff --git a/tensorflow/python/grappler/model_analyzer.i b/tensorflow/python/grappler/model_analyzer.i index 7c3a692d0efc501341ff1dff3cf24b8a4830ec84..4955780764be802b9e4be3598bf114b227757194 100644 --- a/tensorflow/python/grappler/model_analyzer.i +++ b/tensorflow/python/grappler/model_analyzer.i @@ -40,7 +40,8 @@ limitations under the License. %} %{ -string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug) { +string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, + bool assume_valid_feeds, bool debug) { tensorflow::grappler::ItemConfig cfg; cfg.apply_optimizations = false; std::unique_ptr item = @@ -53,10 +54,11 @@ string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug tensorflow::grappler::ModelAnalyzer analyzer(*item); std::stringstream os; - analyzer.GenerateReport(debug, os); + analyzer.GenerateReport(debug, assume_valid_feeds, os); return os.str(); } %} -string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug); +string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, + bool assume_valid_feeds, bool debug); diff --git a/tensorflow/python/grappler/model_analyzer.py b/tensorflow/python/grappler/model_analyzer.py index 535889e1c4034952562a05e4d044fcafeddbc0ca..98cdc5785011dcebbaaf43704772b3de00c9d6ca 100644 --- a/tensorflow/python/grappler/model_analyzer.py +++ b/tensorflow/python/grappler/model_analyzer.py @@ -22,11 +22,12 @@ from tensorflow.python import pywrap_tensorflow as tf_wrap from tensorflow.python.framework import errors -def GenerateModelReport(metagraph, debug=False): +def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False): """Report what's known statically about each node in the provided metagraph. Args: metagraph: A TensorFlow MetaGraphDef. + assume_valid_feeds: If True, assume that the shape of the fed nodes is valid debug: Add some information useful for debugging. Returns: @@ -34,6 +35,6 @@ def GenerateModelReport(metagraph, debug=False): """ with errors.raise_exception_on_not_ok_status(): ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(), - debug) + assume_valid_feeds, debug) return ret_from_swig diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index 1b657983a4690dd0ddb7f569ce514b08cb10400a..39ca71e99af06c19fb7fe5bf185c29106729f5e9 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -98,8 +98,8 @@ PyObject* TF_OptimizeGraph( const tensorflow::MetaGraphDef& metagraph, bool verbose, const string& graph_id, TF_Status* out_status) { tensorflow::grappler::ItemConfig item_config; - item_config.inline_functions = false; item_config.apply_optimizations = false; + item_config.ignore_user_placement = false; std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 16738066ced26cd02267806af54171044b20ccce..16033e9b8f3b6970f92c40a5b61db815a97cf6aa 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -8,6 +8,12 @@ exports_files(["LICENSE"]) package(default_visibility = ["//visibility:public"]) load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +config_setting( + name = "empty_condition", + values = {"define": "UNUSED=unused"}, +) py_library( name = "keras", @@ -39,9 +45,16 @@ py_library( "_impl/keras/datasets/mnist.py", "_impl/keras/datasets/reuters.py", "_impl/keras/engine/__init__.py", - "_impl/keras/engine/topology.py", + "_impl/keras/engine/base_layer.py", + "_impl/keras/engine/input_layer.py", + "_impl/keras/engine/network.py", + "_impl/keras/engine/saving.py", + "_impl/keras/engine/sequential.py", "_impl/keras/engine/training.py", + "_impl/keras/engine/training_arrays.py", "_impl/keras/engine/training_eager.py", + "_impl/keras/engine/training_generator.py", + "_impl/keras/engine/training_utils.py", "_impl/keras/estimator.py", "_impl/keras/initializers.py", "_impl/keras/layers/__init__.py", @@ -74,8 +87,8 @@ py_library( "_impl/keras/utils/generic_utils.py", "_impl/keras/utils/io_utils.py", "_impl/keras/utils/layer_utils.py", + "_impl/keras/utils/multi_gpu_utils.py", "_impl/keras/utils/np_utils.py", - "_impl/keras/utils/training_utils.py", "_impl/keras/utils/vis_utils.py", "_impl/keras/wrappers/__init__.py", "_impl/keras/wrappers/scikit_learn.py", @@ -119,7 +132,11 @@ py_library( ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ + deps = select({ + ":empty_condition": [], + "//conditions:default": [], + }) + [ + "@six_archive//:six", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", @@ -158,7 +175,6 @@ py_library( "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/saved_model", - "@six_archive//:six", ], ) @@ -386,11 +402,10 @@ py_test( py_test( name = "convolutional_recurrent_test", - size = "medium", + size = "large", srcs = ["_impl/keras/layers/convolutional_recurrent_test.py"], shard_count = 2, srcs_version = "PY2AND3", - tags = ["noasan"], # times out b/63678675 deps = [ ":keras", "//tensorflow/python:client_testlib", @@ -597,6 +612,7 @@ py_test( "no_windows", "noasan", # times out "notsan", + "optonly", # times out ], deps = [ ":keras", @@ -641,16 +657,17 @@ py_test( ], ) -py_test( - name = "training_utils_test", - size = "medium", - srcs = ["_impl/keras/utils/training_utils_test.py"], - srcs_version = "PY2AND3", - tags = ["multi_gpu"], - deps = [ +cuda_py_test( + name = "multi_gpu_utils_test", + srcs = ["_impl/keras/utils/multi_gpu_utils_test.py"], + additional_deps = [ ":keras", - "//tensorflow/python:client_testlib", "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + ], + tags = [ + "guitar", + "multi_gpu", ], ) @@ -759,11 +776,36 @@ py_test( size = "small", srcs = ["_impl/keras/engine/topology_test.py"], srcs_version = "PY2AND3", + tags = [ + "no-internal-py3", + ], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_test( + name = "saving_test", + size = "small", + srcs = ["_impl/keras/engine/saving_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sequential_test", + size = "small", + srcs = ["_impl/keras/engine/sequential_test.py"], + srcs_version = "PY2AND3", deps = [ ":keras", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py index b63907b2e60acfc80ee411b9193b2829f0224c3e..53f5d31e9c5b861c551a7a9ca3700c383ea679d7 100644 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ b/tensorflow/python/keras/_impl/keras/__init__.py @@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.models import Sequential -__version__ = '2.1.4-tf' +__version__ = '2.1.5-tf' diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/_impl/keras/applications/densenet.py index 6521f8410435fd13393b9991d3ee9a6342a912d0..ca83e8691237216e799f2ca738dcb6822506e2cb 100644 --- a/tensorflow/python/keras/_impl/keras/applications/densenet.py +++ b/tensorflow/python/keras/_impl/keras/applications/densenet.py @@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py index bf3901fc54419c2b401bf9c4d6311b39a18f1aba..17e407dd58460e6d6802a3e137a96faf38a6f576 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py @@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py index e268e97bc663773a218f01b958b08f8e43c74ee2..2897c6058eb445ceacc34084b53dc89f556e3e9c 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py @@ -37,7 +37,7 @@ from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py index 1bbbedb85e47902b9e6d3dd741e9d52ab9209080..ad96b53a4528d99a014a0214b52a78d6a60076f8 100644 --- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py @@ -79,8 +79,8 @@ from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization from tensorflow.python.keras._impl.keras.layers import Conv2D diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py index 08dae57f006c64021cbca26404770cd89b1ce176..dd33230a7eb9272f8fc60daee63e1f92574cf5e3 100644 --- a/tensorflow/python/keras/_impl/keras/applications/nasnet.py +++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py @@ -49,7 +49,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import add from tensorflow.python.keras._impl.keras.layers import AveragePooling2D diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py index a47dd657bb9ea0627d82831b7ee5d0b33788b5b7..46c0e635578c7f4707b027247943d75b16d703ad 100644 --- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py +++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py @@ -34,7 +34,7 @@ from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py index 9da74253abc2124844ab89b7727ddda4f754d8e2..cefb25063e30505c9c34b49fd2df6eb7210d7ca8 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py +++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py @@ -32,7 +32,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Conv2D from tensorflow.python.keras._impl.keras.layers import Dense from tensorflow.python.keras._impl.keras.layers import Flatten diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py index 961c1f991893dbc0df858e9f72b61202c9fee500..dadaf4fdf0cc5922752c6867720c5d8cdbcab19a 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py +++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py @@ -32,7 +32,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Conv2D from tensorflow.python.keras._impl.keras.layers import Dense from tensorflow.python.keras._impl.keras.layers import Flatten diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py index 7e7ca5a18a31622ac79d61ab01ce65341a4a46c5..971063a16d1f5ba0e25189f1ef2f6c24eb5f8d61 100644 --- a/tensorflow/python/keras/_impl/keras/applications/xception.py +++ b/tensorflow/python/keras/_impl/keras/applications/xception.py @@ -44,7 +44,7 @@ from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization from tensorflow.python.keras._impl.keras.layers import Conv2D diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index a2db05f6cfd0c20fef3b18832834d06990b7a512..7baf27642a475eb3a09687a1d19a6ed05de046e9 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -55,10 +55,10 @@ from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-im from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables as variables_module from tensorflow.python.training import moving_averages +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export - py_all = all py_sum = sum @@ -343,7 +343,7 @@ def learning_phase(): Returns: Learning phase (scalar integer tensor or Python integer). """ - if context.in_eager_mode(): + if context.executing_eagerly(): if 'eager' not in _GRAPH_LEARNING_PHASES: # Fallback to inference mode as default. return 0 @@ -369,13 +369,42 @@ def set_learning_phase(value): """ global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned if value not in {0, 1}: - raise ValueError('Expected learning phase to be ' '0 or 1.') - if context.in_eager_mode(): + raise ValueError('Expected learning phase to be 0 or 1.') + if context.executing_eagerly(): _GRAPH_LEARNING_PHASES['eager'] = value else: _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value +@tf_contextlib.contextmanager +def learning_phase_scope(value): + """Provides a scope within which the learning phase is equal to `value`. + + The learning phase gets restored to its original value upon exiting the scope. + + Arguments: + value: Learning phase value, either 0 or 1 (integers). + + Yields: + The provided value. + + Raises: + ValueError: if `value` is neither `0` nor `1`. + """ + if value not in {0, 1}: + raise ValueError('Expected learning phase to be 0 or 1.') + previous_value = learning_phase() + try: + set_learning_phase(value) + yield value + finally: + # Restore learning phase to initial value. + if context.executing_eagerly(): + _GRAPH_LEARNING_PHASES['eager'] = previous_value + else: + _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value + + @tf_export('keras.backend.get_session') def get_session(): """Returns the TF session to be used by the backend. @@ -394,8 +423,9 @@ def get_session(): A TensorFlow session. """ global _SESSION - if ops.get_default_session() is not None: - session = ops.get_default_session() + default_session = ops.get_default_session() + if default_session is not None: + session = default_session else: if _SESSION is None: if not os.environ.get('OMP_NUM_THREADS'): @@ -466,7 +496,7 @@ def _is_current_explicit_device(device_type): """ device_type = device_type.upper() if device_type not in ['CPU', 'GPU']: - raise ValueError('device_type should be either "CPU" or "GPU".') + raise ValueError('`device_type` should be either "CPU" or "GPU".') device = _get_current_tf_device() return device is not None and device.device_type == device_type.upper() @@ -2596,7 +2626,7 @@ def get_value(x): Returns: A Numpy array. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return x.numpy() return x.eval(session=get_session()) @@ -2611,7 +2641,7 @@ def batch_get_value(tensors): Returns: A list of Numpy arrays. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return [x.numpy() for x in tensors] if tensors: return get_session().run(tensors) @@ -2629,7 +2659,7 @@ def set_value(x, value): (of the same shape). """ value = np.asarray(value, dtype=dtype(x)) - if context.in_eager_mode(): + if context.executing_eagerly(): x.assign(value) else: tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) @@ -2652,7 +2682,7 @@ def batch_set_value(tuples): tuples: a list of tuples `(tensor, value)`. `value` should be a Numpy array. """ - if context.in_eager_mode(): + if context.executing_eagerly(): for x, value in tuples: x.assign(np.asarray(value, dtype=dtype(x))) else: @@ -2749,7 +2779,7 @@ class Function(object): self.updates_op = control_flow_ops.group(*updates_ops) self.name = name # additional tensor substitutions - self.feed_dict = session_kwargs.pop('feed_dict', {}) + self.feed_dict = session_kwargs.pop('feed_dict', None) # additional operations self.fetches = session_kwargs.pop('fetches', []) if not isinstance(self.fetches, list): @@ -2759,8 +2789,15 @@ class Function(object): def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') - feed_dict = self.feed_dict.copy() + + if self.feed_dict: + feed_dict = self.feed_dict.copy() + else: + feed_dict = {} + for tensor, value in zip(self.inputs, inputs): + if value is None: + continue if is_sparse(tensor): sparse_coo = value.tocoo() indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), @@ -3087,7 +3124,7 @@ def rnn(step_function, outputs_shape[1] = inputs_shape[1] outputs.set_shape(outputs_shape) - if not context.in_eager_mode(): + if not context.executing_eagerly(): last_output._uses_learning_phase = uses_learning_phase return last_output, outputs, new_states @@ -3336,7 +3373,7 @@ def categorical_crossentropy(target, output, from_logits=False): target * math_ops.log(output), axis=len(output.get_shape()) - 1) else: - return nn.softmax_cross_entropy_with_logits(labels=target, logits=output) + return nn.softmax_cross_entropy_with_logits_v2(labels=target, logits=output) @tf_export('keras.backend.sparse_categorical_crossentropy') @@ -3478,7 +3515,7 @@ def l2_normalize(x, axis=None): Returns: A tensor. """ - return nn.l2_normalize(x, dim=axis) + return nn.l2_normalize(x, axis=axis) @tf_export('keras.backend.in_top_k') diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py index f29ca49378bc43385b9e90d3f1cefb7937df64cd..fb4b2a0e1dc06c904d4b93038840dbf688d42ed4 100644 --- a/tensorflow/python/keras/_impl/keras/backend_test.py +++ b/tensorflow/python/keras/_impl/keras/backend_test.py @@ -128,6 +128,22 @@ class BackendUtilsTest(test.TestCase): sess.run(variables.global_variables_initializer()) sess.run(y, feed_dict={x: np.random.random((2, 3))}) + def test_learning_phase_scope(self): + with self.test_session(): + initial_learning_phase = keras.backend.learning_phase() + with keras.backend.learning_phase_scope(1) as lp: + self.assertEqual(lp, 1) + self.assertEqual(keras.backend.learning_phase(), 1) + self.assertEqual(keras.backend.learning_phase(), initial_learning_phase) + with keras.backend.learning_phase_scope(0) as lp: + self.assertEqual(lp, 0) + self.assertEqual(keras.backend.learning_phase(), 0) + self.assertEqual(keras.backend.learning_phase(), initial_learning_phase) + with self.assertRaises(ValueError): + with keras.backend.learning_phase_scope(None): + pass + self.assertEqual(keras.backend.learning_phase(), initial_learning_phase) + def test_int_shape(self): x = keras.backend.placeholder(shape=(3, 4)) self.assertEqual(keras.backend.int_shape(x), (3, 4)) diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py index f6c466142522927135d66f73f9f5c697671649ec..deb1e8867dba3d52816ebda02bd9a3bf2ec7bc09 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks.py +++ b/tensorflow/python/keras/_impl/keras/callbacks.py @@ -778,16 +778,24 @@ class TensorBoard(Callback): while i < val_size: step = min(self.batch_size, val_size - i) batch_val = [] - batch_val.append(val_data[0][i:i + step]) - batch_val.append(val_data[1][i:i + step]) - batch_val.append(val_data[2][i:i + step]) + batch_val.append(val_data[0][i:i + step] + if val_data[0] is not None else None) + batch_val.append(val_data[1][i:i + step] + if val_data[1] is not None else None) + batch_val.append(val_data[2][i:i + step] + if val_data[2] is not None else None) if self.model.uses_learning_phase: # do not slice the learning phase - batch_val = [x[i:i + step] for x in val_data[:-1]] + batch_val = [x[i:i + step] if x is not None else None + for x in val_data[:-1]] batch_val.append(val_data[-1]) else: - batch_val = [x[i:i + step] for x in val_data] - feed_dict = dict(zip(tensors, batch_val)) + batch_val = [x[i:i + step] if x is not None else None + for x in val_data] + feed_dict = {} + for key, val in zip(tensors, batch_val): + if val is not None: + feed_dict[key] = val result = self.sess.run([self.merged], feed_dict=feed_dict) summary_str = result[0] self.writer.add_summary(summary_str, epoch) diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py index b9ae41a0d4d0e8d9df70e3fc1952e81c5f57e8d9..508e95f719a02977960b80c283495ced642293c5 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py @@ -24,8 +24,10 @@ import os import numpy as np from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.datasets.fashion_mnist.load_data') def load_data(): """Loads the Fashion-MNIST dataset. diff --git a/tensorflow/python/keras/_impl/keras/engine/__init__.py b/tensorflow/python/keras/_impl/keras/engine/__init__.py index 31f624f9af65cac60b6466d4eb5753cbdee984c6..1bc533ab8f7ba37948d82bc69fe1c9bfe00d6834 100644 --- a/tensorflow/python/keras/_impl/keras/engine/__init__.py +++ b/tensorflow/python/keras/_impl/keras/engine/__init__.py @@ -18,13 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs -from tensorflow.python.keras._impl.keras.engine.topology import Input -from tensorflow.python.keras._impl.keras.engine.topology import InputLayer -from tensorflow.python.keras._impl.keras.engine.topology import InputSpec -from tensorflow.python.keras._impl.keras.engine.topology import Layer +from tensorflow.python.keras._impl.keras.engine.base_layer import InputSpec +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.input_layer import Input +from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import Network from tensorflow.python.keras._impl.keras.engine.training import Model - - -# Note: topology.Node is an internal class, -# it isn't meant to be used by Keras users. diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..5615241ae3077102ef40f9c0619161964a62a335 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -0,0 +1,505 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Base layer code (`Layer`). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import zip # pylint: disable=redefined-builtin + +from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import constraints +from tensorflow.python.keras._impl.keras import initializers +from tensorflow.python.keras._impl.keras import regularizers +from tensorflow.python.keras._impl.keras.utils import generic_utils +from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + + +# pylint: disable=invalid-name +InputSpec = tf_base_layers.InputSpec +Node = tf_base_layers.Node +TFBaseLayer = tf_base_layers.Layer +# pylint: enable=invalid-name + + +@tf_export('keras.layers.Layer') +class Layer(tf_base_layers.Layer): + """Abstract base layer class. + + # Properties + name: String, must be unique within a model. + input_spec: List of InputSpec class instances + each entry describes one required input: + - ndim + - dtype + A layer with `n` input tensors must have + an `input_spec` of length `n`. + trainable: Boolean, whether the layer weights + will be updated during training. + uses_learning_phase: Whether any operation + of the layer uses `K.in_training_phase()` + or `K.in_test_phase()`. + input_shape: Shape tuple. Provided for convenience, + but note that there may be cases in which this + attribute is ill-defined (e.g. a shared layer + with multiple input shapes), in which case + requesting `input_shape` will raise an Exception. + Prefer using `layer.get_input_shape_for(input_shape)`, + or `layer.get_input_shape_at(node_index)`. + output_shape: Shape tuple. See above. + inbound_nodes: List of nodes. + outbound_nodes: List of nodes. + input, output: Input/output tensor(s). Note that if the layer is used + more than once (shared layer), this is ill-defined + and will raise an exception. In such cases, use + `layer.get_input_at(node_index)`. + input_mask, output_mask: Same as above, for masks. + trainable_weights: List of variables. + non_trainable_weights: List of variables. + weights: The concatenation of the lists trainable_weights and + non_trainable_weights (in this order). + + # Methods + call(x, mask=None): Where the layer's logic lives. + __call__(x, mask=None): Wrapper around the layer logic (`call`). + If x is a Keras tensor: + - Connect current layer with last layer from tensor: + `self._add_inbound_node(last_layer)` + - Add layer to tensor history + If layer is not built: + - Build from inputs shape + get_weights() + set_weights(weights) + get_config() + count_params() + compute_output_shape(input_shape) + compute_mask(x, mask) + get_input_at(node_index) + get_output_at(node_index) + get_input_shape_at(node_index) + get_output_shape_at(node_index) + get_input_mask_at(node_index) + get_output_mask_at(node_index) + + # Class Methods + from_config(config) + + # Internal methods: + build(input_shape) + _add_inbound_node(layer, index=0) + """ + + def __init__(self, **kwargs): + # These properties should be set by the user via keyword arguments. + # note that 'dtype', 'input_shape' and 'batch_input_shape' + # are only applicable to input layers: do not pass these keywords + # to non-input layers. + allowed_kwargs = { + 'activity_regularizer', + 'input_shape', + 'batch_input_shape', + 'batch_size', + 'dtype', + 'name', + 'trainable', + 'weights', + } + # Validate optional keyword arguments. + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError('Keyword argument not understood:', kwarg) + + # Get layer name. + name = kwargs.get('name') + + # Get `trainable` status. + trainable = kwargs.get('trainable', True) + + # Get `dtype`. + dtype = kwargs.get('dtype') + if dtype is None: + dtype = K.floatx() + + # Call super, which will set all properties common to Keras layers + # and core TF layers. + super(Layer, self).__init__( + name=name, dtype=dtype, trainable=trainable, + activity_regularizer=kwargs.get('activity_regularizer')) + + # Add properties that are Keras-only for now. + self.supports_masking = False + + # Manage input shape information if passed. + if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: + # In this case we will later create an input layer + # to insert before the current layer + if 'batch_input_shape' in kwargs: + batch_input_shape = tuple(kwargs['batch_input_shape']) + elif 'input_shape' in kwargs: + if 'batch_size' in kwargs: + batch_size = kwargs['batch_size'] + else: + batch_size = None + batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) + self._batch_input_shape = batch_input_shape + + # Manage initial weight values if passed. + if 'weights' in kwargs: + self._initial_weights = kwargs['weights'] + else: + self._initial_weights = None + + def add_weight(self, + name, + shape, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + constraint=None): + """Adds a weight variable to the layer. + + Arguments: + name: String, the name for the weight variable. + shape: The shape tuple of the weight. + dtype: The dtype of the weight. + initializer: An Initializer instance (callable). + regularizer: An optional Regularizer instance. + trainable: A boolean, whether the weight should + be trained via backprop or not (assuming + that the layer itself is also trainable). + constraint: An optional Constraint instance. + + Returns: + The created weight variable. + """ + if dtype is None: + dtype = K.floatx() + weight = self.add_variable(name, shape, + dtype=dtype, + initializer=initializers.get(initializer), + regularizer=regularizers.get(regularizer), + constraint=constraints.get(constraint), + trainable=trainable) + return weight + + def call(self, inputs, **kwargs): # pylint: disable=unused-argument + """This is where the layer's logic lives. + + Arguments: + inputs: Input tensor, or list/tuple of input tensors. + **kwargs: Additional keyword arguments. + + Returns: + A tensor or list/tuple of tensors. + """ + return inputs + + def __call__(self, inputs, **kwargs): + """Wrapper around self.call(), for handling internal references. + + If a Keras tensor is passed: + - We call self._add_inbound_node(). + - If necessary, we `build` the layer to match + the shape of the input(s). + - We update the _keras_history of the output tensor(s) + with the current layer. + This is done as part of _add_inbound_node(). + + Arguments: + inputs: Can be a tensor or list/tuple of tensors. + **kwargs: Additional keyword arguments to be passed to `call()`. + + Returns: + Output of the layer's `call` method. + + Raises: + ValueError: in case the layer is missing shape information + for its `build` call. + """ + # Actually call the layer (optionally building it). + output = super(Layer, self).__call__(inputs, **kwargs) + if context.executing_eagerly(): + return output + + if hasattr(self, '_symbolic_set_inputs') and not self.inputs: + # Subclassed network: explicitly set metadata normally set by a call to + # self._set_inputs(). + self._symbolic_set_inputs(inputs, output) + + # Update learning phase info. + output_tensors = generic_utils.to_list(output) + uses_lp = any( + [getattr(x, '_uses_learning_phase', False) + for x in generic_utils.to_list(inputs)]) + uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp + for i in range(len(output_tensors)): + output_tensors[i]._uses_learning_phase = getattr( + output_tensors[i], '_uses_learning_phase', False) or uses_lp + + # Optionally load weight values that were specified at layer instantiation. + if hasattr(self, '_initial_weights') and self._initial_weights is not None: + self.set_weights(self._initial_weights) + del self._initial_weights + return output + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer. + + Assumes that the layer will be built + to match that input shape provided. + + Arguments: + input_shape: Shape tuple (tuple of integers) + or list of shape tuples (one per output tensor of the layer). + Shape tuples can include None for free dimensions, + instead of an integer. + + Returns: + An input shape tuple. + """ + logging.warning( + 'All custom layers should implement the ' + '`compute_output_shape` method. This layer (' + self.name + ') ' + 'is relying on the base `Layer.compute_output_shape` implementation, ' + 'which will start raising a `NotImplementedError` ' + 'as of July 1st, 2018.') + return input_shape + + def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument + """Computes an output mask tensor. + + Arguments: + inputs: Tensor or list of tensors. + mask: Tensor or list of tensors. + + Returns: + None or a tensor (or list of tensors, + one per output tensor of the layer). + """ + if not self.supports_masking: + if mask is not None: + if isinstance(mask, list): + if any(m is not None for m in mask): + raise TypeError('Layer ' + self.name + ' does not support masking, ' + 'but was passed an input_mask: ' + str(mask)) + else: + raise TypeError('Layer ' + self.name + ' does not support masking, ' + 'but was passed an input_mask: ' + str(mask)) + # masking not explicitly supported: return None as mask + return None + # if masking is explicitly supported, by default + # carry over the input mask + return mask + + def get_input_mask_at(self, node_index): + """Retrieves the input mask tensor(s) of a layer at a given node. + + Arguments: + node_index: Integer, index of the node + from which to retrieve the attribute. + E.g. `node_index=0` will correspond to the + first time the layer was called. + + Returns: + A mask tensor + (or list of tensors if the layer has multiple inputs). + """ + inputs = self.get_input_at(node_index) + if isinstance(inputs, list): + return [getattr(x, '_keras_mask', None) for x in inputs] + else: + return getattr(inputs, '_keras_mask', None) + + def get_output_mask_at(self, node_index): + """Retrieves the output mask tensor(s) of a layer at a given node. + + Arguments: + node_index: Integer, index of the node + from which to retrieve the attribute. + E.g. `node_index=0` will correspond to the + first time the layer was called. + + Returns: + A mask tensor + (or list of tensors if the layer has multiple outputs). + """ + output = self.get_output_at(node_index) + if isinstance(output, list): + return [getattr(x, '_keras_mask', None) for x in output] + else: + return getattr(output, '_keras_mask', None) + + @property + def input_mask(self): + """Retrieves the input mask tensor(s) of a layer. + + Only applicable if the layer has exactly one inbound node, + i.e. if it is connected to one incoming layer. + + Returns: + Input mask tensor (potentially None) or list of input + mask tensors. + + Raises: + AttributeError: if the layer is connected to + more than one incoming layers. + """ + inputs = self.input + if isinstance(inputs, list): + return [getattr(x, '_keras_mask', None) for x in inputs] + else: + return getattr(inputs, '_keras_mask', None) + + @property + def output_mask(self): + """Retrieves the output mask tensor(s) of a layer. + + Only applicable if the layer has exactly one inbound node, + i.e. if it is connected to one incoming layer. + + Returns: + Output mask tensor (potentially None) or list of output + mask tensors. + + Raises: + AttributeError: if the layer is connected to + more than one incoming layers. + """ + output = self.output + if isinstance(output, list): + return [getattr(x, '_keras_mask', None) for x in output] + else: + return getattr(output, '_keras_mask', None) + + def set_weights(self, weights): + """Sets the weights of the layer, from Numpy arrays. + + Arguments: + weights: a list of Numpy arrays. The number + of arrays and their shape must match + number of the dimensions of the weights + of the layer (i.e. it should match the + output of `get_weights`). + + Raises: + ValueError: If the provided weights list does not match the + layer's specifications. + """ + params = self.weights + if len(params) != len(weights): + raise ValueError('You called `set_weights(weights)` on layer "' + + self.name + '" with a weight list of length ' + + str(len(weights)) + ', but the layer was expecting ' + + str(len(params)) + ' weights. Provided weights: ' + + str(weights)[:50] + '...') + if not params: + return + weight_value_tuples = [] + param_values = K.batch_get_value(params) + for pv, p, w in zip(param_values, params, weights): + if pv.shape != w.shape: + raise ValueError('Layer weight shape ' + str(pv.shape) + + ' not compatible with ' + 'provided weight shape ' + str(w.shape)) + weight_value_tuples.append((p, w)) + K.batch_set_value(weight_value_tuples) + + def get_weights(self): + """Returns the current weights of the layer. + + Returns: + Weights values as a list of numpy arrays. + """ + params = self.weights + return K.batch_get_value(params) + + def get_config(self): + """Returns the config of the layer. + + A layer config is a Python dictionary (serializable) + containing the configuration of a layer. + The same layer can be reinstantiated later + (without its trained weights) from this configuration. + + The config of a layer does not include connectivity + information, nor the layer class name. These are handled + by `Network` (one layer of abstraction above). + + Returns: + Python dictionary. + """ + config = {'name': self.name, 'trainable': self.trainable} + if hasattr(self, '_batch_input_shape'): + config['batch_input_shape'] = self._batch_input_shape + if hasattr(self, 'dtype'): + config['dtype'] = self.dtype + return config + + @classmethod + def from_config(cls, config): + """Creates a layer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same layer from the config + dictionary. It does not handle layer connectivity + (handled by Network), nor weights (handled by `set_weights`). + + Arguments: + config: A Python dictionary, typically the + output of get_config. + + Returns: + A layer instance. + """ + return cls(**config) + + @tf_base_layers.Layer.activity_regularizer.setter + def activity_regularizer(self, activity_regularizer): + self._activity_regularizer = activity_regularizer + + +def shape_type_conversion(fn): + """Decorator that handles tuple/TensorShape conversion. + + Used in `compute_output_shape` and `build`. + + Arguments: + fn: function to wrap. + + Returns: + Wrapped function. + """ + + def wrapper(instance, input_shape): + if input_shape is not None: + if isinstance(input_shape, list): + input_shape = [ + tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] + else: + input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) + output_shape = fn(instance, input_shape) + if output_shape is not None: + if isinstance(output_shape, list): + return [tensor_shape.TensorShape(x) for x in output_shape] + return tensor_shape.TensorShape(output_shape) + + return wrapper diff --git a/tensorflow/python/keras/_impl/keras/engine/input_layer.py b/tensorflow/python/keras/_impl/keras/engine/input_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b51dd8a2189d0c8542c84dfeac9be0d72b96ff1b --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/input_layer.py @@ -0,0 +1,231 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Input layer code (`Input` and `InputLayer`). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export('keras.layers.InputLayer') +class InputLayer(base_layer.Layer): + """Layer to be used as an entry point into a Network (a graph of layers). + + It can either wrap an existing tensor (pass an `input_tensor` argument) + or create its a placeholder tensor (pass arguments `input_shape`, and + optionally, `dtype`). + + It is generally recommend to use the functional layer API via `Input`, + (which creates an `InputLayer`) without directly using `InputLayer`. + + Arguments: + input_shape: Shape tuple (not including the batch axis), or `TensorShape` + instance (not including the batch axis). + batch_size: Optional input batch size (integer or None). + dtype: Datatype of the input. + input_tensor: Optional tensor to use as layer input + instead of creating a placeholder. + sparse: Boolean, whether the placeholder created + is meant to be sparse. + name: Name of the layer (string). + """ + + def __init__(self, + input_shape=None, + batch_size=None, + dtype=None, + input_tensor=None, + sparse=False, + name=None, + **kwargs): + if 'batch_input_shape' in kwargs: + batch_input_shape = kwargs.pop('batch_input_shape') + if input_shape and batch_input_shape: + raise ValueError('Only provide the input_shape OR ' + 'batch_input_shape argument to ' + 'InputLayer, not both at the same time.') + batch_size = batch_input_shape[0] + input_shape = batch_input_shape[1:] + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + + if not name: + prefix = 'input' + name = prefix + '_' + str(K.get_uid(prefix)) + + if not dtype: + if input_tensor is None: + dtype = K.floatx() + else: + dtype = K.dtype(input_tensor) + super(InputLayer, self).__init__(dtype=dtype, name=name) + self.built = True + self.sparse = sparse + self.batch_size = batch_size + + if isinstance(input_shape, tensor_shape.TensorShape): + input_shape = tuple(input_shape.as_list()) + + if input_tensor is None: + if input_shape is not None: + batch_input_shape = (batch_size,) + tuple(input_shape) + else: + batch_input_shape = None + + if context.executing_eagerly(): + # In eager mode, create a temporary placeholder to call the layer on. + input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + # In graph mode, create a graph placeholder to call the layer on. + if sparse: + input_tensor = array_ops.sparse_placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + input_tensor = array_ops.placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + + # For compatibility with Keras API. + self.is_placeholder = True + self._batch_input_shape = batch_input_shape + else: + # For compatibility with Keras API. + self.is_placeholder = False + self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) + + # Create an input node to add to self.outbound_node + # and set output_tensors' _keras_history. + input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access + tf_base_layers.Node( + self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=[input_tensor], + output_tensors=[input_tensor]) + + def get_config(self): + config = { + 'batch_input_shape': self._batch_input_shape, + 'dtype': self.dtype, + 'sparse': self.sparse, + 'name': self.name + } + return config + + +@tf_export('keras.layers.Input', 'keras.Input') +def Input( # pylint: disable=invalid-name + shape=None, + batch_size=None, + name=None, + dtype=None, + sparse=False, + tensor=None, + **kwargs): + """`Input()` is used to instantiate a Keras tensor. + + A Keras tensor is a tensor object from the underlying backend + (Theano or TensorFlow), which we augment with certain + attributes that allow us to build a Keras model + just by knowing the inputs and outputs of the model. + + For instance, if a, b and c are Keras tensors, + it becomes possible to do: + `model = Model(input=[a, b], output=c)` + + The added Keras attribute is: + `_keras_history`: Last layer applied to the tensor. + the entire layer graph is retrievable from that layer, + recursively. + + Arguments: + shape: A shape tuple (integers), not including the batch size. + For instance, `shape=(32,)` indicates that the expected input + will be batches of 32-dimensional vectors. + batch_size: optional static batch size (integer). + name: An optional name string for the layer. + Should be unique in a model (do not reuse the same name twice). + It will be autogenerated if it isn't provided. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) + sparse: A boolean specifying whether the placeholder + to be created is sparse. + tensor: Optional existing tensor to wrap into the `Input` layer. + If set, the layer will not create a placeholder tensor. + **kwargs: deprecated arguments support. + + Returns: + A tensor. + + Example: + + ```python + # this is a logistic regression in Keras + x = Input(shape=(32,)) + y = Dense(16, activation='softmax')(x) + model = Model(x, y) + ``` + + Raises: + ValueError: in case of invalid arguments. + """ + if 'batch_shape' in kwargs: + batch_shape = kwargs.pop('batch_shape') + if shape and batch_shape: + raise ValueError('Only provide the shape OR ' + 'batch_shape argument to ' + 'Input, not both at the same time.') + batch_size = batch_shape[0] + shape = batch_shape[1:] + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + + if dtype is None: + dtype = K.floatx() + if not shape and tensor is None: + raise ValueError('Please provide to Input either a `shape`' + ' or a `tensor` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.') + input_layer = InputLayer( + input_shape=shape, + batch_size=batch_size, + name=name, + dtype=dtype, + sparse=sparse, + input_tensor=tensor) + # Return tensor including `_keras_history`. + # Note that in this case train_output and test_output are the same pointer. + outputs = input_layer._inbound_nodes[0].output_tensors + if len(outputs) == 1: + return outputs[0] + else: + return outputs diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/network.py similarity index 58% rename from tensorflow/python/keras/_impl/keras/engine/topology.py rename to tensorflow/python/keras/_impl/keras/engine/network.py index f562a19cf5ea5c3573962f565c2a92be28bbd591..ea4be0d293b7c4f50cec47eb067f7a928375be0b 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=protected-access -"""Base layer code and base model (Network) code. +"""A `Network` is way to compose layers: the topological form of a `Model`. """ from __future__ import absolute_import from __future__ import division @@ -30,19 +30,17 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras._impl.keras.engine import saving +from tensorflow.python.keras._impl.keras.utils import generic_utils from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.layers import utils as tf_layers_util -from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect -from tensorflow.python.util.tf_export import tf_export # pylint: disable=g-import-not-at-top @@ -57,684 +55,12 @@ except ImportError: yaml = None # pylint: enable=g-import-not-at-top -# pylint: disable=invalid-name -InputSpec = tf_base_layers.InputSpec -Node = tf_base_layers.Node -TFBaseLayer = tf_base_layers.Layer -# pylint: enable=invalid-name - - -@tf_export('keras.layers.Layer') -class Layer(tf_base_layers.Layer): - """Abstract base layer class. - - # Properties - name: String, must be unique within a model. - input_spec: List of InputSpec class instances - each entry describes one required input: - - ndim - - dtype - A layer with `n` input tensors must have - an `input_spec` of length `n`. - trainable: Boolean, whether the layer weights - will be updated during training. - uses_learning_phase: Whether any operation - of the layer uses `K.in_training_phase()` - or `K.in_test_phase()`. - input_shape: Shape tuple. Provided for convenience, - but note that there may be cases in which this - attribute is ill-defined (e.g. a shared layer - with multiple input shapes), in which case - requesting `input_shape` will raise an Exception. - Prefer using `layer.get_input_shape_for(input_shape)`, - or `layer.get_input_shape_at(node_index)`. - output_shape: Shape tuple. See above. - inbound_nodes: List of nodes. - outbound_nodes: List of nodes. - input, output: Input/output tensor(s). Note that if the layer is used - more than once (shared layer), this is ill-defined - and will raise an exception. In such cases, use - `layer.get_input_at(node_index)`. - input_mask, output_mask: Same as above, for masks. - trainable_weights: List of variables. - non_trainable_weights: List of variables. - weights: The concatenation of the lists trainable_weights and - non_trainable_weights (in this order). - - # Methods - call(x, mask=None): Where the layer's logic lives. - __call__(x, mask=None): Wrapper around the layer logic (`call`). - If x is a Keras tensor: - - Connect current layer with last layer from tensor: - `self._add_inbound_node(last_layer)` - - Add layer to tensor history - If layer is not built: - - Build from inputs shape - get_weights() - set_weights(weights) - get_config() - count_params() - compute_output_shape(input_shape) - compute_mask(x, mask) - get_input_at(node_index) - get_output_at(node_index) - get_input_shape_at(node_index) - get_output_shape_at(node_index) - get_input_mask_at(node_index) - get_output_mask_at(node_index) - - # Class Methods - from_config(config) - - # Internal methods: - build(input_shape) - _add_inbound_node(layer, index=0) - """ - - def __init__(self, **kwargs): - # These properties should be set by the user via keyword arguments. - # note that 'dtype', 'input_shape' and 'batch_input_shape' - # are only applicable to input layers: do not pass these keywords - # to non-input layers. - allowed_kwargs = { - 'activity_regularizer', - 'input_shape', - 'batch_input_shape', - 'batch_size', - 'dtype', - 'name', - 'trainable', - 'weights', - } - # Validate optional keyword arguments. - for kwarg in kwargs: - if kwarg not in allowed_kwargs: - raise TypeError('Keyword argument not understood:', kwarg) - - # Get layer name. - name = kwargs.get('name') - - # Get `trainable` status. - trainable = kwargs.get('trainable', True) - - # Get `dtype`. - dtype = kwargs.get('dtype') - if dtype is None: - dtype = K.floatx() - - # Call super, which will set all properties common to Keras layers - # and core TF layers. - super(Layer, self).__init__( - name=name, dtype=dtype, trainable=trainable, - activity_regularizer=kwargs.get('activity_regularizer')) - - # Add properties that are Keras-only for now. - self.supports_masking = False - - # Manage input shape information if passed. - if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: - # In this case we will later create an input layer - # to insert before the current layer - if 'batch_input_shape' in kwargs: - batch_input_shape = tuple(kwargs['batch_input_shape']) - elif 'input_shape' in kwargs: - if 'batch_size' in kwargs: - batch_size = kwargs['batch_size'] - else: - batch_size = None - batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) - self._batch_input_shape = batch_input_shape - - # Manage initial weight values if passed. - if 'weights' in kwargs: - self._initial_weights = kwargs['weights'] - else: - self._initial_weights = None - - def add_weight(self, - name, - shape, - dtype=None, - initializer=None, - regularizer=None, - trainable=True, - constraint=None): - """Adds a weight variable to the layer. - - Arguments: - name: String, the name for the weight variable. - shape: The shape tuple of the weight. - dtype: The dtype of the weight. - initializer: An Initializer instance (callable). - regularizer: An optional Regularizer instance. - trainable: A boolean, whether the weight should - be trained via backprop or not (assuming - that the layer itself is also trainable). - constraint: An optional Constraint instance. - - Returns: - The created weight variable. - """ - if dtype is None: - dtype = K.floatx() - weight = self.add_variable(name, shape, - dtype=dtype, - initializer=initializers.get(initializer), - regularizer=regularizers.get(regularizer), - constraint=constraints.get(constraint), - trainable=trainable) - return weight - - def call(self, inputs, **kwargs): # pylint: disable=unused-argument - """This is where the layer's logic lives. - - Arguments: - inputs: Input tensor, or list/tuple of input tensors. - **kwargs: Additional keyword arguments. - - Returns: - A tensor or list/tuple of tensors. - """ - return inputs - - def __call__(self, inputs, **kwargs): - """Wrapper around self.call(), for handling internal references. - - If a Keras tensor is passed: - - We call self._add_inbound_node(). - - If necessary, we `build` the layer to match - the shape of the input(s). - - We update the _keras_history of the output tensor(s) - with the current layer. - This is done as part of _add_inbound_node(). - - Arguments: - inputs: Can be a tensor or list/tuple of tensors. - **kwargs: Additional keyword arguments to be passed to `call()`. - - Returns: - Output of the layer's `call` method. - - Raises: - ValueError: in case the layer is missing shape information - for its `build` call. - """ - # Actually call the layer (optionally building it). - output = super(Layer, self).__call__(inputs, **kwargs) - if context.in_eager_mode(): - return output - - # Un-built subclassed network: build it - if isinstance(self, Network) and not self.inputs: - self._set_inputs(inputs, training=kwargs.get('training')) - - # Update learning phase info. - output_tensors = to_list(output) - uses_lp = any( - [getattr(x, '_uses_learning_phase', False) for x in to_list(inputs)]) - uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp - for i in range(len(output_tensors)): - output_tensors[i]._uses_learning_phase = getattr( - output_tensors[i], '_uses_learning_phase', False) or uses_lp - - # Optionally load weight values that were specified at layer instantiation. - if hasattr(self, '_initial_weights') and self._initial_weights is not None: - self.set_weights(self._initial_weights) - del self._initial_weights - return output - - def compute_output_shape(self, input_shape): - """Computes the output shape of the layer. - - Assumes that the layer will be built - to match that input shape provided. - - Arguments: - input_shape: Shape tuple (tuple of integers) - or list of shape tuples (one per output tensor of the layer). - Shape tuples can include None for free dimensions, - instead of an integer. - - Returns: - An input shape tuple. - """ - logging.warning( - 'All custom layers should implement the ' - '`compute_output_shape` method. This layer (' + self.name + ') ' - 'is relying on the base `Layer.compute_output_shape` implementation, ' - 'which will start raising a `NotImplementedError` ' - 'as of July 1st, 2018.') - return input_shape - - def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument - """Computes an output mask tensor. - - Arguments: - inputs: Tensor or list of tensors. - mask: Tensor or list of tensors. - - Returns: - None or a tensor (or list of tensors, - one per output tensor of the layer). - """ - if not self.supports_masking: - if mask is not None: - if isinstance(mask, list): - if any(m is not None for m in mask): - raise TypeError('Layer ' + self.name + ' does not support masking, ' - 'but was passed an input_mask: ' + str(mask)) - else: - raise TypeError('Layer ' + self.name + ' does not support masking, ' - 'but was passed an input_mask: ' + str(mask)) - # masking not explicitly supported: return None as mask - return None - # if masking is explicitly supported, by default - # carry over the input mask - return mask - - def get_input_mask_at(self, node_index): - """Retrieves the input mask tensor(s) of a layer at a given node. - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. +class Network(base_layer.Layer): + """A `Network` is a composition of layers. - Returns: - A mask tensor - (or list of tensors if the layer has multiple inputs). - """ - inputs = self.get_input_at(node_index) - if isinstance(inputs, list): - return [getattr(x, '_keras_mask', None) for x in inputs] - else: - return getattr(inputs, '_keras_mask', None) - - def get_output_mask_at(self, node_index): - """Retrieves the output mask tensor(s) of a layer at a given node. - - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A mask tensor - (or list of tensors if the layer has multiple outputs). - """ - output = self.get_output_at(node_index) - if isinstance(output, list): - return [getattr(x, '_keras_mask', None) for x in output] - else: - return getattr(output, '_keras_mask', None) - - @property - def input_mask(self): - """Retrieves the input mask tensor(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Input mask tensor (potentially None) or list of input - mask tensors. - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - inputs = self.input - if isinstance(inputs, list): - return [getattr(x, '_keras_mask', None) for x in inputs] - else: - return getattr(inputs, '_keras_mask', None) - - @property - def output_mask(self): - """Retrieves the output mask tensor(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Output mask tensor (potentially None) or list of output - mask tensors. - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - output = self.output - if isinstance(output, list): - return [getattr(x, '_keras_mask', None) for x in output] - else: - return getattr(output, '_keras_mask', None) - - def set_weights(self, weights): - """Sets the weights of the layer, from Numpy arrays. - - Arguments: - weights: a list of Numpy arrays. The number - of arrays and their shape must match - number of the dimensions of the weights - of the layer (i.e. it should match the - output of `get_weights`). - - Raises: - ValueError: If the provided weights list does not match the - layer's specifications. - """ - params = self.weights - if len(params) != len(weights): - raise ValueError('You called `set_weights(weights)` on layer "' + - self.name + '" with a weight list of length ' + - str(len(weights)) + ', but the layer was expecting ' + - str(len(params)) + ' weights. Provided weights: ' + - str(weights)[:50] + '...') - if not params: - return - weight_value_tuples = [] - param_values = K.batch_get_value(params) - for pv, p, w in zip(param_values, params, weights): - if pv.shape != w.shape: - raise ValueError('Layer weight shape ' + str(pv.shape) + - ' not compatible with ' - 'provided weight shape ' + str(w.shape)) - weight_value_tuples.append((p, w)) - K.batch_set_value(weight_value_tuples) - - def get_weights(self): - """Returns the current weights of the layer. - - Returns: - Weights values as a list of numpy arrays. - """ - params = self.weights - return K.batch_get_value(params) - - def get_config(self): - """Returns the config of the layer. - - A layer config is a Python dictionary (serializable) - containing the configuration of a layer. - The same layer can be reinstantiated later - (without its trained weights) from this configuration. - - The config of a layer does not include connectivity - information, nor the layer class name. These are handled - by `Network` (one layer of abstraction above). - - Returns: - Python dictionary. - """ - config = {'name': self.name, 'trainable': self.trainable} - if hasattr(self, '_batch_input_shape'): - config['batch_input_shape'] = self._batch_input_shape - if hasattr(self, 'dtype'): - config['dtype'] = self.dtype - return config - - @classmethod - def from_config(cls, config): - """Creates a layer from its config. - - This method is the reverse of `get_config`, - capable of instantiating the same layer from the config - dictionary. It does not handle layer connectivity - (handled by Network), nor weights (handled by `set_weights`). - - Arguments: - config: A Python dictionary, typically the - output of get_config. - - Returns: - A layer instance. - """ - return cls(**config) - - @tf_base_layers.Layer.activity_regularizer.setter - def activity_regularizer(self, activity_regularizer): - self._activity_regularizer = activity_regularizer - - -class InputLayer(Layer): - """Layer to be used as an entry point into a Network (a graph of layers). - - It can either wrap an existing tensor (pass an `input_tensor` argument) - or create its a placeholder tensor (pass arguments `input_shape`, and - optionally, `dtype`). - - It is generally recommend to use the functional layer API via `Input`, - (which creates an `InputLayer`) without directly using `InputLayer`. - - Arguments: - input_shape: Shape tuple (not including the batch axis), or `TensorShape` - instance (not including the batch axis). - batch_size: Optional input batch size (integer or None). - dtype: Datatype of the input. - input_tensor: Optional tensor to use as layer input - instead of creating a placeholder. - sparse: Boolean, whether the placeholder created - is meant to be sparse. - name: Name of the layer (string). - """ - - def __init__(self, - input_shape=None, - batch_size=None, - dtype=None, - input_tensor=None, - sparse=False, - name=None, - **kwargs): - if 'batch_input_shape' in kwargs: - batch_input_shape = kwargs.pop('batch_input_shape') - if input_shape and batch_input_shape: - raise ValueError('Only provide the input_shape OR ' - 'batch_input_shape argument to ' - 'InputLayer, not both at the same time.') - batch_size = batch_input_shape[0] - input_shape = batch_input_shape[1:] - if kwargs: - raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) - - if not name: - prefix = 'input' - name = prefix + '_' + str(K.get_uid(prefix)) - - if not dtype: - if input_tensor is None: - dtype = K.floatx() - else: - dtype = K.dtype(input_tensor) - super(InputLayer, self).__init__(dtype=dtype, name=name) - self.built = True - self.sparse = sparse - self.batch_size = batch_size - - if isinstance(input_shape, tensor_shape.TensorShape): - input_shape = tuple(input_shape.as_list()) - - if input_tensor is None: - if input_shape is not None: - batch_input_shape = (batch_size,) + tuple(input_shape) - else: - batch_input_shape = None - - if context.in_eager_mode(): - # In eager mode, create a temporary placeholder to call the layer on. - input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - # In graph mode, create a graph placeholder to call the layer on. - if sparse: - input_tensor = array_ops.sparse_placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - input_tensor = array_ops.placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - - # For compatibility with Keras API. - self.is_placeholder = True - self._batch_input_shape = batch_input_shape - else: - # For compatibility with Keras API. - self.is_placeholder = False - self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) - - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access - tf_base_layers.Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor]) - - def get_config(self): - config = { - 'batch_input_shape': self._batch_input_shape, - 'dtype': self.dtype, - 'sparse': self.sparse, - 'name': self.name - } - return config - - -@tf_export('keras.layers.Input', 'keras.Input') -def Input( # pylint: disable=invalid-name - shape=None, - batch_size=None, - name=None, - dtype=None, - sparse=False, - tensor=None, - **kwargs): - """`Input()` is used to instantiate a Keras tensor. - - A Keras tensor is a tensor object from the underlying backend - (Theano or TensorFlow), which we augment with certain - attributes that allow us to build a Keras model - just by knowing the inputs and outputs of the model. - - For instance, if a, b and c are Keras tensors, - it becomes possible to do: - `model = Model(input=[a, b], output=c)` - - The added Keras attribute is: - `_keras_history`: Last layer applied to the tensor. - the entire layer graph is retrievable from that layer, - recursively. - - Arguments: - shape: A shape tuple (integers), not including the batch size. - For instance, `shape=(32,)` indicates that the expected input - will be batches of 32-dimensional vectors. - batch_size: optional static batch size (integer). - name: An optional name string for the layer. - Should be unique in a model (do not reuse the same name twice). - It will be autogenerated if it isn't provided. - dtype: The data type expected by the input, as a string - (`float32`, `float64`, `int32`...) - sparse: A boolean specifying whether the placeholder - to be created is sparse. - tensor: Optional existing tensor to wrap into the `Input` layer. - If set, the layer will not create a placeholder tensor. - **kwargs: deprecated arguments support. - - Returns: - A tensor. - - Example: - - ```python - # this is a logistic regression in Keras - x = Input(shape=(32,)) - y = Dense(16, activation='softmax')(x) - model = Model(x, y) - ``` - - Raises: - ValueError: in case of invalid arguments. - """ - if 'batch_shape' in kwargs: - batch_shape = kwargs.pop('batch_shape') - if shape and batch_shape: - raise ValueError('Only provide the shape OR ' - 'batch_shape argument to ' - 'Input, not both at the same time.') - batch_size = batch_shape[0] - shape = batch_shape[1:] - if kwargs: - raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) - - if dtype is None: - dtype = K.floatx() - if not shape and tensor is None: - raise ValueError('Please provide to Input either a `shape`' - ' or a `tensor` argument. Note that ' - '`shape` does not include the batch ' - 'dimension.') - input_layer = InputLayer( - input_shape=shape, - batch_size=batch_size, - name=name, - dtype=dtype, - sparse=sparse, - input_tensor=tensor) - # Return tensor including `_keras_history`. - # Note that in this case train_output and test_output are the same pointer. - outputs = input_layer._inbound_nodes[0].output_tensors - if len(outputs) == 1: - return outputs[0] - else: - return outputs - - -class Network(Layer): - """A Network is a directed acyclic graph of layers. - - It is the topological form of a "model". A Model - is simply a Network with added training routines. - - # Properties - name - inputs - outputs - input_layers - output_layers - input_spec (list of class instances) - each entry describes one required input: - - ndim - - dtype - trainable (boolean) - input_shape - output_shape - inbound_nodes: list of nodes - outbound_nodes: list of nodes - trainable_weights (list of variables) - non_trainable_weights (list of variables) - - # Methods - summary - get_layer - get_weights - set_weights - get_config - compute_output_shape - - # Class Methods - from_config + It is the topological form of a "model". A `Model` + is simply a `Network` with added training routines. """ def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called @@ -766,18 +92,20 @@ class Network(Layer): self._expects_training_arg = False self.supports_masking = False - self.optimizer = None + if not hasattr(self, 'optimizer'): + # Don't reset optimizer if already set. + self.optimizer = None # Private attributes to implement compatibility with Layer. self._updates = [] # Used in symbolic mode only. self._losses = [] # Used in symbolic mode only. self._scope = None # Never used. self._reuse = None # Never used. - if context.in_eager_mode: + if context.executing_eagerly(): self._graph = None else: self._graph = ops.get_default_graph() # Used in symbolic mode only. - # A Network does not create weights of its own, thus has no dtype. + # A Network does not create weights of its own, thus has no dtype. self._dtype = None # All layers in order of horizontal graph traversal. @@ -800,7 +128,7 @@ class Network(Layer): self.outputs = [outputs] # User-prodived argument validation. - if context.in_eager_mode(): + if context.executing_eagerly(): # Check that all inputs/outputs are DeferredTensors. for tensor in self.inputs: if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access @@ -862,17 +190,6 @@ class Network(Layer): self.built = True self._is_graph_network = True - # # List of initial layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._input_layers = [] - # self._input_layers_node_indices = [] - # self._input_layers_tensor_indices = [] - # # list of layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._output_layers = [] - # self._output_layers_node_indices = [] - # self._output_layers_tensor_indices = [] - self._input_layers = [] self._output_layers = [] self._input_coordinates = [] @@ -949,7 +266,7 @@ class Network(Layer): self._feed_input_names.append(layer.name) self._feed_input_shapes.append(K.int_shape(self.inputs[i])) # layer.input gives an error in eager mode - if context.in_graph_mode(): + if not context.executing_eagerly(): self._feed_inputs.append(layer.input) for layer in self._output_layers: self.output_names.append(layer.name) @@ -977,6 +294,13 @@ class Network(Layer): if not is_graph_network: if value not in self._layers: self._layers.append(value) + if isinstance(value, checkpointable.CheckpointableBase): + # Layer (and therefore Network/Model) inherit from CheckpointableBase + # rather than Checkpointable, which means there is no Checkpointable + # __setattr__ override (it would be a performance issue for functional + # layers). Therefore Model tracks Checkpointable objects itself. + self._track_checkpointable( + checkpointable=value, name=name, overwrite=True) super(Network, self).__setattr__(name, value) def add_variable(self, name, shape, dtype=None, initializer=None, @@ -984,7 +308,7 @@ class Network(Layer): raise NotImplementedError('`add_variable` is not supported on Networks.') def add_loss(self, *args, **kwargs): - if context.in_eager_mode(): + if context.executing_eagerly(): raise NotImplementedError('`add_loss` is not supported on Networks ' 'when eager execution is enabled.') super(Network, self).add_loss(*args, **kwargs) @@ -1053,17 +377,17 @@ class Network(Layer): if not self._is_graph_network: return None - inputs = to_list(inputs) + inputs = generic_utils.to_list(inputs) if mask is None: masks = [None for _ in range(len(inputs))] else: - masks = to_list(mask) + masks = generic_utils.to_list(mask) cache_key = (tf_layers_util.object_list_uid(inputs) + '_' + tf_layers_util.object_list_uid(masks)) if cache_key in self._output_mask_cache: return self._output_mask_cache[cache_key] else: - _, output_masks = self._run_internal_graph(inputs, masks) + _, output_masks = self._run_internal_graph(inputs, mask=masks) return output_masks @property @@ -1073,6 +397,7 @@ class Network(Layer): def get_layer(self, name=None, index=None): """Retrieves a layer based on either its name (unique) or index. + If `name` and `index` are both provided, `index` will take precedence. Indices are based on order of horizontal graph traversal (bottom-up). Arguments: @@ -1104,7 +429,7 @@ class Network(Layer): @property def updates(self): - """Retrieve the network's updates. + """Retrieves the network's updates. Will only include updates that are either unconditional, or conditional on inputs to this model @@ -1150,7 +475,7 @@ class Network(Layer): Returns: A list of update ops. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return [] if not self.trainable and not self.stateful: @@ -1162,7 +487,10 @@ class Network(Layer): # `updates` might contain irrelevant updates, so it needs to be filtered # with respect to inputs the model has been called on. - relevant_inputs = self.inputs or [] + if self.inputs: + relevant_inputs = self.inputs[:] + else: + relevant_inputs = [] for i in range(1, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): @@ -1181,7 +509,7 @@ class Network(Layer): @property def losses(self): - """Retrieve the network's losses. + """Retrieves the network's losses. Will only include losses that are either unconditional, or conditional on inputs to this model @@ -1194,10 +522,13 @@ class Network(Layer): losses = [] for layer in self.layers: losses += layer.losses - if context.in_eager_mode(): + if context.executing_eagerly(): return losses - relevant_inputs = self.inputs or [] + if self.inputs: + relevant_inputs = self.inputs[:] + else: + relevant_inputs = [] for i in range(1, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): @@ -1261,7 +592,7 @@ class Network(Layer): return specs def call(self, inputs, training=None, mask=None): - """Call the model on new inputs. + """Calls the model on new inputs. In this case `call` just reapplies all ops in the graph to the new inputs @@ -1284,7 +615,7 @@ class Network(Layer): else: masks = nest.flatten(mask) - if context.in_graph_mode(): + if not context.executing_eagerly(): # Try to retrieve cached outputs if the layer has already been called # on these exact inputs. cache_key = (tf_layers_util.object_list_uid(inputs) @@ -1490,7 +821,7 @@ class Network(Layer): else: output_masks = [None for _ in range(len(output_tensors))] - if context.in_graph_mode(): + if not context.executing_eagerly(): if layer.activity_regularizer is not None: regularization_losses = [ layer.activity_regularizer(x) for x in output_tensors @@ -1520,7 +851,7 @@ class Network(Layer): if output_masks is not None: output_masks = output_masks[0] - if context.in_graph_mode(): + if not context.executing_eagerly(): # Update cache; # keys are based on ids on input tensors and inputs masks. cache_key = (tf_layers_util.object_list_uid(inputs) @@ -1691,7 +1022,7 @@ class Network(Layer): layer(input_tensors, **kwargs) def process_layer(layer_data): - """Deserialize a layer, then call it on appropriate inputs. + """Deserializes a layer, then call it on appropriate inputs. Arguments: layer_data: layer config dict. @@ -1748,7 +1079,7 @@ class Network(Layer): return cls(inputs=input_tensors, outputs=output_tensors, name=name) def save(self, filepath, overwrite=True, include_optimizer=True): - """Save the model to a single HDF5 file. + """Saves the model to a single HDF5 file. The savefile includes: - The model architecture, allowing to re-instantiate the model. @@ -1818,7 +1149,7 @@ class Network(Layer): if not proceed: return with h5py.File(filepath, 'w') as f: - save_weights_to_hdf5_group(f, self.layers) + saving.save_weights_to_hdf5_group(f, self.layers) def load_weights(self, filepath, by_name=False): """Loads all layer weights from a HDF5 save file. @@ -1849,12 +1180,12 @@ class Network(Layer): if 'layer_names' not in f.attrs and 'model_weights' in f: f = f['model_weights'] if by_name: - load_weights_from_hdf5_group_by_name(f, self.layers) + saving.load_weights_from_hdf5_group_by_name(f, self.layers) else: - load_weights_from_hdf5_group(f, self.layers) + saving.load_weights_from_hdf5_group(f, self.layers) def _updated_config(self): - """Util hared between different serialization methods. + """Util shared between different serialization methods. Returns: Model config with Keras version information added. @@ -1883,9 +1214,6 @@ class Network(Layer): Returns: A JSON string. """ - if not self._is_graph_network: - raise NotImplementedError - def get_json_type(obj): # If obj is any numpy type if type(obj).__module__ == np.__name__: @@ -1920,9 +1248,6 @@ class Network(Layer): Raises: ImportError: if yaml module is not found. """ - if not self._is_graph_network: - raise NotImplementedError - if yaml is None: raise ImportError('Requires yaml module installed.') return yaml.dump(self._updated_config(), **kwargs) @@ -1989,370 +1314,12 @@ def get_source_inputs(tensor, layer=None, node_index=None): return source_tensors -def to_list(x): - """Normalizes a list/tensor into a list. - - If a tensor is passed, we return - a list of size 1 containing the tensor. - - Arguments: - x: target object to be normalized. - - Returns: - A list. - """ - if isinstance(x, list): - return x - return [x] - - -def save_weights_to_hdf5_group(f, layers): - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top - - f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers] - f.attrs['backend'] = K.backend().encode('utf8') - f.attrs['keras_version'] = str(keras_version).encode('utf8') - - for layer in layers: - g = f.create_group(layer.name) - symbolic_weights = layer.weights - weight_values = K.batch_get_value(symbolic_weights) - weight_names = [] - for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): - if hasattr(w, 'name') and w.name: - name = str(w.name) - else: - name = 'param_' + str(i) - weight_names.append(name.encode('utf8')) - g.attrs['weight_names'] = weight_names - for name, val in zip(weight_names, weight_values): - param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) - if not val.shape: - # scalar - param_dset[()] = val - else: - param_dset[:] = val - - -def preprocess_weights_for_loading(layer, - weights, - original_keras_version=None, - original_backend=None): - """Converts layers weights from Keras 1 format to Keras 2. - - Arguments: - layer: Layer instance. - weights: List of weights values (Numpy arrays). - original_keras_version: Keras version for the weights, as a string. - original_backend: Keras backend the weights were trained with, - as a string. - - Returns: - A list of weights values (Numpy arrays). - """ - if layer.__class__.__name__ == 'Bidirectional': - num_weights_per_layer = len(weights) // 2 - forward_weights = preprocess_weights_for_loading( - layer.forward_layer, weights[:num_weights_per_layer], - original_keras_version, original_backend) - backward_weights = preprocess_weights_for_loading( - layer.backward_layer, weights[num_weights_per_layer:], - original_keras_version, original_backend) - weights = forward_weights + backward_weights - - if original_keras_version == '1': - if layer.__class__.__name__ == 'TimeDistributed': - weights = preprocess_weights_for_loading( - layer.layer, weights, original_keras_version, original_backend) - - if layer.__class__.__name__ == 'Conv1D': - shape = weights[0].shape - # Handle Keras 1.1 format - if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: - # Legacy shape: - # (filters, input_dim, filter_length, 1) - assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], - 1) - weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) - weights[0] = weights[0][:, 0, :, :] - - if layer.__class__.__name__ == 'Conv2D': - if layer.data_format == 'channels_first': - # old: (filters, stack_size, kernel_rows, kernel_cols) - # new: (kernel_rows, kernel_cols, stack_size, filters) - weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) - - if layer.__class__.__name__ == 'Conv2DTranspose': - if layer.data_format == 'channels_last': - # old: (kernel_rows, kernel_cols, stack_size, filters) - # new: (kernel_rows, kernel_cols, filters, stack_size) - weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) - if layer.data_format == 'channels_first': - # old: (filters, stack_size, kernel_rows, kernel_cols) - # new: (kernel_rows, kernel_cols, filters, stack_size) - weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) - - if layer.__class__.__name__ == 'Conv3D': - if layer.data_format == 'channels_first': - # old: (filters, stack_size, ...) - # new: (..., stack_size, filters) - weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) - - if layer.__class__.__name__ == 'GRU': - if len(weights) == 9: - kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) - recurrent_kernel = np.concatenate( - [weights[1], weights[4], weights[7]], axis=-1) - bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) - weights = [kernel, recurrent_kernel, bias] - - if layer.__class__.__name__ == 'LSTM': - if len(weights) == 12: - # old: i, c, f, o - # new: i, f, c, o - kernel = np.concatenate( - [weights[0], weights[6], weights[3], weights[9]], axis=-1) - recurrent_kernel = np.concatenate( - [weights[1], weights[7], weights[4], weights[10]], axis=-1) - bias = np.concatenate( - [weights[2], weights[8], weights[5], weights[11]], axis=-1) - weights = [kernel, recurrent_kernel, bias] - - if layer.__class__.__name__ == 'ConvLSTM2D': - if len(weights) == 12: - kernel = np.concatenate( - [weights[0], weights[6], weights[3], weights[9]], axis=-1) - recurrent_kernel = np.concatenate( - [weights[1], weights[7], weights[4], weights[10]], axis=-1) - bias = np.concatenate( - [weights[2], weights[8], weights[5], weights[11]], axis=-1) - if layer.data_format == 'channels_first': - # old: (filters, stack_size, kernel_rows, kernel_cols) - # new: (kernel_rows, kernel_cols, stack_size, filters) - kernel = np.transpose(kernel, (2, 3, 1, 0)) - recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) - weights = [kernel, recurrent_kernel, bias] - - if layer.__class__.__name__ in ['Model', 'Sequential']: - new_weights = [] - # trainable weights - for sublayer in layer.layers: - num_weights = len(sublayer.trainable_weights) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - - # non-trainable weights - for sublayer in layer.layers: - num_weights = len([ - l for l in sublayer.weights if l not in sublayer.trainable_weights - ]) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - weights = new_weights - - conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] - if layer.__class__.__name__ in conv_layers: - if original_backend == 'theano': - weights[0] = conv_utils.convert_kernel(weights[0]) - if layer.__class__.__name__ == 'ConvLSTM2D': - weights[1] = conv_utils.convert_kernel(weights[1]) - if K.int_shape(layer.weights[0]) != weights[0].shape: - weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) - if layer.__class__.__name__ == 'ConvLSTM2D': - weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) - - # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM - if layer.__class__.__name__ == 'LSTM' and len(weights) == 3: - # Determine if loading a CuDNNLSTM layer from the number of bias weights: - # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) - # if there's no bias weight in the file, skip this conversion - units = weights[1].shape[0] - bias = weights[2] - if len(bias) == units * 8: - # reshape the kernels - kernels = np.split(weights[0], 4, axis=1) - kernels = [ - kernel.reshape(-1).reshape(kernel.shape, order='F') - for kernel in kernels - ] - weights[0] = np.concatenate(kernels, axis=1) - - # transpose the recurrent kernels - recurrent_kernels = np.split(weights[1], 4, axis=1) - recurrent_kernels = [kernel.T for kernel in recurrent_kernels] - weights[1] = np.concatenate(recurrent_kernels, axis=1) - - # split the bias into half and merge - weights[2] = bias[:units * 4] + bias[units * 4:] - - return weights - - -def load_weights_from_hdf5_group(f, layers): - """Implements topological (order-based) weight loading. - - Arguments: - f: A pointer to a HDF5 group. - layers: a list of target layers. - - Raises: - ValueError: in case of mismatch between provided layers - and weights file. - """ - if 'keras_version' in f.attrs: - original_keras_version = f.attrs['keras_version'].decode('utf8') - else: - original_keras_version = '1' - if 'backend' in f.attrs: - original_backend = f.attrs['backend'].decode('utf8') - else: - original_backend = None - - filtered_layers = [] - for layer in layers: - weights = layer.weights - if weights: - filtered_layers.append(layer) - - layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] - filtered_layer_names = [] - for name in layer_names: - g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - if weight_names: - filtered_layer_names.append(name) - layer_names = filtered_layer_names - if len(layer_names) != len(filtered_layers): - raise ValueError('You are trying to load a weight file ' - 'containing ' + str(len(layer_names)) + - ' layers into a model with ' + str(len(filtered_layers)) + - ' layers.') - - # We batch weight value assignments in a single backend call - # which provides a speedup in TensorFlow. - weight_value_tuples = [] - for k, name in enumerate(layer_names): - g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - weight_values = [g[weight_name] for weight_name in weight_names] - layer = filtered_layers[k] - symbolic_weights = layer.weights - weight_values = preprocess_weights_for_loading( - layer, weight_values, original_keras_version, original_backend) - if len(weight_values) != len(symbolic_weights): - raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + - '" in the current model) was found to ' - 'correspond to layer ' + name + ' in the save file. ' - 'However the new layer ' + layer.name + ' expects ' + - str(len(symbolic_weights)) + - ' weights, but the saved weights have ' + - str(len(weight_values)) + ' elements.') - weight_value_tuples += zip(symbolic_weights, weight_values) - K.batch_set_value(weight_value_tuples) - - -def load_weights_from_hdf5_group_by_name(f, layers): - """Implements name-based weight loading. - - (instead of topological weight loading). - - Layers that have no matching name are skipped. - - Arguments: - f: A pointer to a HDF5 group. - layers: a list of target layers. - - Raises: - ValueError: in case of mismatch between provided layers - and weights file. - """ - if 'keras_version' in f.attrs: - original_keras_version = f.attrs['keras_version'].decode('utf8') - else: - original_keras_version = '1' - if 'backend' in f.attrs: - original_backend = f.attrs['backend'].decode('utf8') - else: - original_backend = None - - # New file format. - layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] - - # Reverse index of layer name to list of layers with name. - index = {} - for layer in layers: - if layer.name: - index.setdefault(layer.name, []).append(layer) - - # We batch weight value assignments in a single backend call - # which provides a speedup in TensorFlow. - weight_value_tuples = [] - for k, name in enumerate(layer_names): - g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - weight_values = [g[weight_name] for weight_name in weight_names] - - for layer in index.get(name, []): - symbolic_weights = layer.weights - weight_values = preprocess_weights_for_loading( - layer, weight_values, original_keras_version, original_backend) - if len(weight_values) != len(symbolic_weights): - raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + - '") expects ' + str(len(symbolic_weights)) + - ' weight(s), but the saved weights' + ' have ' + - str(len(weight_values)) + ' element(s).') - # Set values. - for i in range(len(weight_values)): - weight_value_tuples.append((symbolic_weights[i], weight_values[i])) - K.batch_set_value(weight_value_tuples) - - -def shape_type_conversion(fn): - """Decorator that handles tuple/TensorShape conversion. - - Used in `compute_output_shape` and `build`. - - Arguments: - fn: function to wrap. - - Returns: - Wrapped function. - """ - - def wrapper(instance, input_shape): - if input_shape is not None: - if isinstance(input_shape, list): - input_shape = [ - tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] - else: - input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) - output_shape = fn(instance, input_shape) - if output_shape is not None: - if isinstance(output_shape, list): - return [tensor_shape.TensorShape(x) for x in output_shape] - return tensor_shape.TensorShape(output_shape) - - return wrapper - - def _make_node_key(layer_name, node_index): return layer_name + '_ib-' + str(node_index) def _map_graph_network(inputs, outputs): - """Validate a network's topology and gather its layers and nodes. + """Validates a network's topology and gather its layers and nodes. Arguments: inputs: List of input tensors. diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad06ca4fdcd55c12ba3ba192751f2f05aacc7ec --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/saving.py @@ -0,0 +1,844 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Model saving utilities. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +import numpy as np +from six.moves import zip # pylint: disable=redefined-builtin + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import optimizers +from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + +# pylint: disable=g-import-not-at-top +try: + import h5py + HDF5_OBJECT_HEADER_LIMIT = 64512 +except ImportError: + h5py = None + +try: + import yaml +except ImportError: + yaml = None +# pylint: enable=g-import-not-at-top + + +@tf_export('keras.models.save_model') +def save_model(model, filepath, overwrite=True, include_optimizer=True): + """Saves a model to a HDF5 file. + + The saved model contains: + - the model's configuration (topology) + - the model's weights + - the model's optimizer's state (if any) + + Thus the saved model can be reinstantiated in + the exact same state, without any of the code + used for model definition or training. + + Arguments: + model: Keras model instance to be saved. + filepath: String, path where to save the model. + overwrite: Whether we should overwrite any existing + model at the target location, or instead + ask the user with a manual prompt. + include_optimizer: If True, save optimizer's state together. + + Raises: + ImportError: if h5py is not available. + """ + + if h5py is None: + raise ImportError('`save_model` requires h5py.') + + def get_json_type(obj): + """Serializes any object to a JSON-serializable structure. + + Arguments: + obj: the object to serialize + + Returns: + JSON-serializable structure representing `obj`. + + Raises: + TypeError: if `obj` cannot be serialized. + """ + # if obj is a serializable Keras class instance + # e.g. optimizer, layer + if hasattr(obj, 'get_config'): + return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} + + # if obj is any numpy type + if type(obj).__module__ == np.__name__: + if isinstance(obj, np.ndarray): + return {'type': type(obj), 'value': obj.tolist()} + else: + return obj.item() + + # misc functions (e.g. loss function) + if callable(obj): + return obj.__name__ + + # if obj is a python 'type' + if type(obj).__name__ == type.__name__: + return obj.__name__ + + raise TypeError('Not JSON Serializable:', obj) + + from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + + # If file exists and should not be overwritten. + if not overwrite and os.path.isfile(filepath): + proceed = ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + + with h5py.File(filepath, mode='w') as f: + f.attrs['keras_version'] = str(keras_version).encode('utf8') + f.attrs['backend'] = K.backend().encode('utf8') + f.attrs['model_config'] = json.dumps( + { + 'class_name': model.__class__.__name__, + 'config': model.get_config() + }, + default=get_json_type).encode('utf8') + + model_weights_group = f.create_group('model_weights') + model_layers = model.layers + save_weights_to_hdf5_group(model_weights_group, model_layers) + + if include_optimizer and hasattr(model, 'optimizer'): + if isinstance(model.optimizer, optimizers.TFOptimizer): + logging.warning( + 'TensorFlow optimizers do not ' + 'make it possible to access ' + 'optimizer attributes or optimizer state ' + 'after instantiation. ' + 'As a result, we cannot save the optimizer ' + 'as part of the model save file.' + 'You will have to compile your model again after loading it. ' + 'Prefer using a Keras optimizer instead ' + '(see keras.io/optimizers).') + else: + f.attrs['training_config'] = json.dumps( + { + 'optimizer_config': { + 'class_name': model.optimizer.__class__.__name__, + 'config': model.optimizer.get_config() + }, + 'loss': model.loss, + 'metrics': model.metrics, + 'sample_weight_mode': model.sample_weight_mode, + 'loss_weights': model.loss_weights, + }, + default=get_json_type).encode('utf8') + + # Save optimizer weights. + symbolic_weights = getattr(model.optimizer, 'weights') + if symbolic_weights: + optimizer_weights_group = f.create_group('optimizer_weights') + weight_values = K.batch_get_value(symbolic_weights) + weight_names = [] + for w, val in zip(symbolic_weights, weight_values): + name = str(w.name) + weight_names.append(name.encode('utf8')) + optimizer_weights_group.attrs['weight_names'] = weight_names + for name, val in zip(weight_names, weight_values): + param_dset = optimizer_weights_group.create_dataset( + name, val.shape, dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + f.flush() + + +@tf_export('keras.models.load_model') +def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin + """Loads a model saved via `save_model`. + + Arguments: + filepath: String, path to the saved model. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + compile: Boolean, whether to compile the model + after loading. + + Returns: + A Keras model instance. If an optimizer was found + as part of the saved model, the model is already + compiled. Otherwise, the model is uncompiled and + a warning will be displayed. When `compile` is set + to False, the compilation is omitted without any + warning. + + Raises: + ImportError: if h5py is not available. + ValueError: In case of an invalid savefile. + """ + if h5py is None: + raise ImportError('`load_model` requires h5py.') + + if not custom_objects: + custom_objects = {} + + def convert_custom_objects(obj): + """Handles custom object lookup. + + Arguments: + obj: object, dict, or list. + + Returns: + The same structure, where occurrences + of a custom object name have been replaced + with the custom object. + """ + if isinstance(obj, list): + deserialized = [] + for value in obj: + deserialized.append(convert_custom_objects(value)) + return deserialized + if isinstance(obj, dict): + deserialized = {} + for key, value in obj.items(): + deserialized[key] = convert_custom_objects(value) + return deserialized + if obj in custom_objects: + return custom_objects[obj] + return obj + + with h5py.File(filepath, mode='r') as f: + # instantiate model + model_config = f.attrs.get('model_config') + if model_config is None: + raise ValueError('No model found in config file.') + model_config = json.loads(model_config.decode('utf-8')) + model = model_from_config(model_config, custom_objects=custom_objects) + + # set weights + load_weights_from_hdf5_group(f['model_weights'], model.layers) + + # Early return if compilation is not required. + if not compile: + return model + + # instantiate optimizer + training_config = f.attrs.get('training_config') + if training_config is None: + logging.warning('No training configuration found in save file: ' + 'the model was *not* compiled. Compile it manually.') + return model + training_config = json.loads(training_config.decode('utf-8')) + optimizer_config = training_config['optimizer_config'] + optimizer = optimizers.deserialize( + optimizer_config, custom_objects=custom_objects) + + # Recover loss functions and metrics. + loss = convert_custom_objects(training_config['loss']) + metrics = convert_custom_objects(training_config['metrics']) + sample_weight_mode = training_config['sample_weight_mode'] + loss_weights = training_config['loss_weights'] + + # Compile model. + model.compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=sample_weight_mode) + + # Set optimizer weights. + if 'optimizer_weights' in f: + # Build train function (to get weight updates). + model._make_train_function() + optimizer_weights_group = f['optimizer_weights'] + optimizer_weight_names = [ + n.decode('utf8') + for n in optimizer_weights_group.attrs['weight_names'] + ] + optimizer_weight_values = [ + optimizer_weights_group[n] for n in optimizer_weight_names + ] + try: + model.optimizer.set_weights(optimizer_weight_values) + except ValueError: + logging.warning('Error in loading the saved optimizer ' + 'state. As a result, your model is ' + 'starting with a freshly initialized ' + 'optimizer.') + return model + + +@tf_export('keras.models.model_from_config') +def model_from_config(config, custom_objects=None): + """Instantiates a Keras model from its config. + + Arguments: + config: Configuration dictionary. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + + Raises: + TypeError: if `config` is not a dictionary. + """ + if isinstance(config, list): + raise TypeError('`model_from_config` expects a dictionary, not a list. ' + 'Maybe you meant to use ' + '`Sequential.from_config(config)`?') + from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + return deserialize(config, custom_objects=custom_objects) + + +@tf_export('keras.models.model_from_yaml') +def model_from_yaml(yaml_string, custom_objects=None): + """Parses a yaml model configuration file and returns a model instance. + + Arguments: + yaml_string: YAML string encoding a model configuration. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + + Raises: + ImportError: if yaml module is not found. + """ + if yaml is None: + raise ImportError('Requires yaml module installed.') + config = yaml.load(yaml_string) + from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + return deserialize(config, custom_objects=custom_objects) + + +@tf_export('keras.models.model_from_json') +def model_from_json(json_string, custom_objects=None): + """Parses a JSON model configuration file and returns a model instance. + + Arguments: + json_string: JSON string encoding a model configuration. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + """ + config = json.loads(json_string) + from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + return deserialize(config, custom_objects=custom_objects) + + +def preprocess_weights_for_loading(layer, + weights, + original_keras_version=None, + original_backend=None): + """Converts layers weights from Keras 1 format to Keras 2. + + Arguments: + layer: Layer instance. + weights: List of weights values (Numpy arrays). + original_keras_version: Keras version for the weights, as a string. + original_backend: Keras backend the weights were trained with, + as a string. + + Returns: + A list of weights values (Numpy arrays). + """ + if layer.__class__.__name__ == 'Bidirectional': + num_weights_per_layer = len(weights) // 2 + forward_weights = preprocess_weights_for_loading( + layer.forward_layer, weights[:num_weights_per_layer], + original_keras_version, original_backend) + backward_weights = preprocess_weights_for_loading( + layer.backward_layer, weights[num_weights_per_layer:], + original_keras_version, original_backend) + weights = forward_weights + backward_weights + + if original_keras_version == '1': + if layer.__class__.__name__ == 'TimeDistributed': + weights = preprocess_weights_for_loading( + layer.layer, weights, original_keras_version, original_backend) + + if layer.__class__.__name__ == 'Conv1D': + shape = weights[0].shape + # Handle Keras 1.1 format + if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: + # Legacy shape: + # (filters, input_dim, filter_length, 1) + assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], + 1) + weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) + weights[0] = weights[0][:, 0, :, :] + + if layer.__class__.__name__ == 'Conv2D': + if layer.data_format == 'channels_first': + # old: (filters, stack_size, kernel_rows, kernel_cols) + # new: (kernel_rows, kernel_cols, stack_size, filters) + weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) + + if layer.__class__.__name__ == 'Conv2DTranspose': + if layer.data_format == 'channels_last': + # old: (kernel_rows, kernel_cols, stack_size, filters) + # new: (kernel_rows, kernel_cols, filters, stack_size) + weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) + if layer.data_format == 'channels_first': + # old: (filters, stack_size, kernel_rows, kernel_cols) + # new: (kernel_rows, kernel_cols, filters, stack_size) + weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) + + if layer.__class__.__name__ == 'Conv3D': + if layer.data_format == 'channels_first': + # old: (filters, stack_size, ...) + # new: (..., stack_size, filters) + weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) + + if layer.__class__.__name__ == 'GRU': + if len(weights) == 9: + kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) + recurrent_kernel = np.concatenate( + [weights[1], weights[4], weights[7]], axis=-1) + bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) + weights = [kernel, recurrent_kernel, bias] + + if layer.__class__.__name__ == 'LSTM': + if len(weights) == 12: + # old: i, c, f, o + # new: i, f, c, o + kernel = np.concatenate( + [weights[0], weights[6], weights[3], weights[9]], axis=-1) + recurrent_kernel = np.concatenate( + [weights[1], weights[7], weights[4], weights[10]], axis=-1) + bias = np.concatenate( + [weights[2], weights[8], weights[5], weights[11]], axis=-1) + weights = [kernel, recurrent_kernel, bias] + + if layer.__class__.__name__ == 'ConvLSTM2D': + if len(weights) == 12: + kernel = np.concatenate( + [weights[0], weights[6], weights[3], weights[9]], axis=-1) + recurrent_kernel = np.concatenate( + [weights[1], weights[7], weights[4], weights[10]], axis=-1) + bias = np.concatenate( + [weights[2], weights[8], weights[5], weights[11]], axis=-1) + if layer.data_format == 'channels_first': + # old: (filters, stack_size, kernel_rows, kernel_cols) + # new: (kernel_rows, kernel_cols, stack_size, filters) + kernel = np.transpose(kernel, (2, 3, 1, 0)) + recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) + weights = [kernel, recurrent_kernel, bias] + + if layer.__class__.__name__ in ['Model', 'Sequential']: + new_weights = [] + # trainable weights + for sublayer in layer.layers: + num_weights = len(sublayer.trainable_weights) + if num_weights > 0: + new_weights.extend( + preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + + # non-trainable weights + for sublayer in layer.layers: + num_weights = len([ + l for l in sublayer.weights if l not in sublayer.trainable_weights + ]) + if num_weights > 0: + new_weights.extend( + preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + weights = new_weights + + conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] + if layer.__class__.__name__ in conv_layers: + if original_backend == 'theano': + weights[0] = conv_utils.convert_kernel(weights[0]) + if layer.__class__.__name__ == 'ConvLSTM2D': + weights[1] = conv_utils.convert_kernel(weights[1]) + if K.int_shape(layer.weights[0]) != weights[0].shape: + weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) + if layer.__class__.__name__ == 'ConvLSTM2D': + weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + + # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM + if layer.__class__.__name__ == 'LSTM' and len(weights) == 3: + # Determine if loading a CuDNNLSTM layer from the number of bias weights: + # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) + # if there's no bias weight in the file, skip this conversion + units = weights[1].shape[0] + bias = weights[2] + if len(bias) == units * 8: + # reshape the kernels + kernels = np.split(weights[0], 4, axis=1) + kernels = [ + kernel.reshape(-1).reshape(kernel.shape, order='F') + for kernel in kernels + ] + weights[0] = np.concatenate(kernels, axis=1) + + # transpose the recurrent kernels + recurrent_kernels = np.split(weights[1], 4, axis=1) + recurrent_kernels = [kernel.T for kernel in recurrent_kernels] + weights[1] = np.concatenate(recurrent_kernels, axis=1) + + # split the bias into half and merge + weights[2] = bias[:units * 4] + bias[units * 4:] + + return convert_rnn_weights(layer, weights) + + +def convert_rnn_weights(layer, weights): + """Converts weights for RNN layers between native and CuDNN format. + + Input kernels for each gate are transposed and converted between Fortran + and C layout, recurrent kernels are transposed. For LSTM biases are summed/ + split in half, for GRU biases are reshaped. + + Weights can be converted in both directions between `LSTM` and`CuDNNSLTM` + and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not + compatible with `CuDNNGRU`. + + For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made. + + Arguments: + layer: Target layer instance. + weights: List of source weights values (input kernels, recurrent + kernels, [biases]) (Numpy arrays). + + Returns: + A list of converted weights values (Numpy arrays). + + Raises: + ValueError: for incompatible GRU layer/weights or incompatible biases + """ + + def transform_kernels(kernels, func, n_gates): + """Transforms kernel for each gate separately using given function. + + Arguments: + kernels: Stacked array of kernels for individual gates. + func: Function applied to kernel of each gate. + n_gates: Number of gates (4 for LSTM, 3 for GRU). + Returns: + Stacked array of transformed kernels. + """ + return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)]) + + def transpose_input(from_cudnn): + """Makes a function that transforms input kernels from/to CuDNN format. + + It keeps the shape, but changes between the layout (Fortran/C). Eg.: + + ``` + Keras CuDNN + [[0, 1, 2], <---> [[0, 2, 4], + [3, 4, 5]] [1, 3, 5]] + ``` + + It can be passed to `transform_kernels()`. + + Arguments: + from_cudnn: `True` if source weights are in CuDNN format, `False` + if they're in plain Keras format. + Returns: + Function that converts input kernel to the other format. + """ + order = 'F' if from_cudnn else 'C' + + def transform(kernel): + return kernel.T.reshape(kernel.shape, order=order) + + return transform + + target_class = layer.__class__.__name__ + + # convert the weights between CuDNNLSTM and LSTM + if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3: + # determine if we're loading a CuDNNLSTM layer + # from the number of bias weights: + # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) + # if there's no bias weight in the file, skip this conversion + units = weights[1].shape[0] + bias_shape = weights[2].shape + n_gates = 4 + + if bias_shape == (2 * units * n_gates,): + source = 'CuDNNLSTM' + elif bias_shape == (units * n_gates,): + source = 'LSTM' + else: + raise ValueError('Invalid bias shape: ' + str(bias_shape)) + + def convert_lstm_weights(weights, from_cudnn=True): + # Transpose (and reshape) input and recurrent kernels. + kernels = transform_kernels(weights[0], transpose_input(from_cudnn), + n_gates) + recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) + if from_cudnn: # Merge input and recurrent biases into a single set. + biases = np.sum(np.split(weights[2], 2, axis=0), axis=0) + else: + # Split single set of biases evenly to two sets. + biases = np.tile(0.5 * weights[2], 2) + return [kernels, recurrent_kernels, biases] + + if source != target_class: + weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM') + + # TODO(fchollet): add feature after GRU is refactored: + # convert the weights between `CuDNNGRU` and `GRU(reset_after=True)` + return weights + + +def save_weights_to_hdf5_group(f, layers): + from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + + save_attributes_to_hdf5_group( + f, 'layer_names', [layer.name.encode('utf8') for layer in layers]) + f.attrs['backend'] = K.backend().encode('utf8') + f.attrs['keras_version'] = str(keras_version).encode('utf8') + + for layer in layers: + g = f.create_group(layer.name) + symbolic_weights = layer.weights + weight_values = K.batch_get_value(symbolic_weights) + weight_names = [] + for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): + if hasattr(w, 'name') and w.name: + name = str(w.name) + else: + name = 'param_' + str(i) + weight_names.append(name.encode('utf8')) + save_attributes_to_hdf5_group(g, 'weight_names', weight_names) + for name, val in zip(weight_names, weight_values): + param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + + +def load_weights_from_hdf5_group(f, layers): + """Implements topological (order-based) weight loading. + + Arguments: + f: A pointer to a HDF5 group. + layers: a list of target layers. + + Raises: + ValueError: in case of mismatch between provided layers + and weights file. + """ + if 'keras_version' in f.attrs: + original_keras_version = f.attrs['keras_version'].decode('utf8') + else: + original_keras_version = '1' + if 'backend' in f.attrs: + original_backend = f.attrs['backend'].decode('utf8') + else: + original_backend = None + + filtered_layers = [] + for layer in layers: + weights = layer.weights + if weights: + filtered_layers.append(layer) + + layer_names = load_attributes_from_hdf5_group(f, 'layer_names') + filtered_layer_names = [] + for name in layer_names: + g = f[name] + weight_names = load_attributes_from_hdf5_group(g, 'weight_names') + if weight_names: + filtered_layer_names.append(name) + layer_names = filtered_layer_names + if len(layer_names) != len(filtered_layers): + raise ValueError('You are trying to load a weight file ' + 'containing ' + str(len(layer_names)) + + ' layers into a model with ' + str(len(filtered_layers)) + + ' layers.') + + # We batch weight value assignments in a single backend call + # which provides a speedup in TensorFlow. + weight_value_tuples = [] + for k, name in enumerate(layer_names): + g = f[name] + weight_names = load_attributes_from_hdf5_group(g, 'weight_names') + weight_values = [g[weight_name] for weight_name in weight_names] + layer = filtered_layers[k] + symbolic_weights = layer.weights + weight_values = preprocess_weights_for_loading( + layer, weight_values, original_keras_version, original_backend) + if len(weight_values) != len(symbolic_weights): + raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + + '" in the current model) was found to ' + 'correspond to layer ' + name + ' in the save file. ' + 'However the new layer ' + layer.name + ' expects ' + + str(len(symbolic_weights)) + + ' weights, but the saved weights have ' + + str(len(weight_values)) + ' elements.') + weight_value_tuples += zip(symbolic_weights, weight_values) + K.batch_set_value(weight_value_tuples) + + +def load_weights_from_hdf5_group_by_name(f, layers): + """Implements name-based weight loading. + + (instead of topological weight loading). + + Layers that have no matching name are skipped. + + Arguments: + f: A pointer to a HDF5 group. + layers: a list of target layers. + + Raises: + ValueError: in case of mismatch between provided layers + and weights file. + """ + if 'keras_version' in f.attrs: + original_keras_version = f.attrs['keras_version'].decode('utf8') + else: + original_keras_version = '1' + if 'backend' in f.attrs: + original_backend = f.attrs['backend'].decode('utf8') + else: + original_backend = None + + # New file format. + layer_names = load_attributes_from_hdf5_group(f, 'layer_names') + + # Reverse index of layer name to list of layers with name. + index = {} + for layer in layers: + if layer.name: + index.setdefault(layer.name, []).append(layer) + + # We batch weight value assignments in a single backend call + # which provides a speedup in TensorFlow. + weight_value_tuples = [] + for k, name in enumerate(layer_names): + g = f[name] + weight_names = load_attributes_from_hdf5_group(g, 'weight_names') + weight_values = [g[weight_name] for weight_name in weight_names] + + for layer in index.get(name, []): + symbolic_weights = layer.weights + weight_values = preprocess_weights_for_loading( + layer, weight_values, original_keras_version, original_backend) + if len(weight_values) != len(symbolic_weights): + raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + + '") expects ' + str(len(symbolic_weights)) + + ' weight(s), but the saved weights' + ' have ' + + str(len(weight_values)) + ' element(s).') + # Set values. + for i in range(len(weight_values)): + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + K.batch_set_value(weight_value_tuples) + + +def save_attributes_to_hdf5_group(group, name, data): + """Saves attributes (data) of the specified name into the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not + able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. + + Arguments: + group: A pointer to a HDF5 group. + name: A name of the attributes to save. + data: Attributes data to store. + + Raises: + RuntimeError: If any single attribute is too large to be saved. + """ + # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` + # because in that case even chunking the array would not make the saving + # possible. + bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] + + # Expecting this to never be true. + if bad_attributes: + raise RuntimeError('The following attributes cannot be saved to HDF5 ' + 'file because they are larger than %d bytes: %s' % + (HDF5_OBJECT_HEADER_LIMIT, + ', '.join([x for x in bad_attributes]))) + + data_npy = np.asarray(data) + + num_chunks = 1 + chunked_data = np.array_split(data_npy, num_chunks) + + # This will never loop forever thanks to the test above. + while any([x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data]): + num_chunks += 1 + chunked_data = np.array_split(data_npy, num_chunks) + + if num_chunks > 1: + for chunk_id, chunk_data in enumerate(chunked_data): + group.attrs['%s%d' % (name, chunk_id)] = chunk_data + else: + group.attrs[name] = data + + +def load_attributes_from_hdf5_group(group, name): + """Loads attributes of the specified name from the HDF5 group. + + This method deals with an inherent problem + of HDF5 file which is not able to store + data larger than HDF5_OBJECT_HEADER_LIMIT bytes. + + Arguments: + group: A pointer to a HDF5 group. + name: A name of the attributes to load. + + Returns: + data: Attributes data. + """ + if name in group.attrs: + data = [n.decode('utf8') for n in group.attrs[name]] + else: + data = [] + chunk_id = 0 + while '%s%d' % (name, chunk_id) in group.attrs: + data.extend( + [n.decode('utf8') for n in group.attrs['%s%d' % (name, chunk_id)]]) + chunk_id += 1 + return data diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dde090120456f968267e1c572f22eda1bd6ed7c4 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py @@ -0,0 +1,461 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#,============================================================================ +"""Tests for model saving.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test +from tensorflow.python.training import training as training_module + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +class TestWeightSavingAndLoading(test.TestCase): + + def test_weight_loading(self): + with self.test_session(): + a = keras.layers.Input(shape=(2,)) + x = keras.layers.Dense(3)(a) + b = keras.layers.Dense(1)(x) + model = keras.models.Model(a, b) + + x = np.random.random((3, 2)) + ref_y = model.predict(x) + weights = model.get_weights() + model.set_weights(weights) + y = model.predict(x) + self.assertAllClose(ref_y, y) + + with self.assertRaises(ValueError): + model.set_weights(weights[1:]) + with self.assertRaises(ValueError): + model.set_weights(weights[::-1]) + + if h5py is None: + return # Skip rest of test if H5py isn't available. + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + h5_path = os.path.join(temp_dir, 'test.h5') + model.save_weights(h5_path) + model.load_weights(h5_path) + y = model.predict(x) + self.assertAllClose(ref_y, y) + + model.load_weights(h5_path, by_name=True) + y = model.predict(x) + self.assertAllClose(ref_y, y) + + def test_weight_preprocessing(self): + input_dim = 3 + output_dim = 3 + size = 2 + cases = [ + [ + (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))), + [np.random.random((2, 1)), np.random.random((2, 1))], + (None, 3, 2), + ], + [ + (keras.layers.TimeDistributed(keras.layers.Dense(1))), + [np.random.random((2, 1)), np.random.random((1,))], + (None, 3, 2), + ], + [ + (keras.layers.Conv1D(output_dim, size, use_bias=False)), + [np.random.random((output_dim, input_dim, size, 1))], + (None, 4, input_dim), + ], + [ + (keras.layers.Conv2D(output_dim, size, + use_bias=False, data_format='channels_first')), + [np.random.random((output_dim, input_dim, size, size))], + (None, input_dim, 4, 4), + ], + [ + (keras.layers.Conv2DTranspose(output_dim, size, + use_bias=False, + data_format='channels_first')), + [np.random.random((output_dim, input_dim, size, size))], + (None, input_dim, 4, 4), + ], + [ + (keras.layers.Conv2DTranspose(output_dim, size, + use_bias=False, + data_format='channels_last')), + [np.random.random((size, size, input_dim, output_dim))], + (None, 4, 4, input_dim), + ], + [ + (keras.layers.Conv3D(output_dim, size, + use_bias=False, data_format='channels_first')), + [np.random.random((output_dim, input_dim, size, size, size))], + (None, input_dim, 4, 4, 4), + ], + [ + (keras.layers.GRU(output_dim)), + [np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,))], + (None, 4, input_dim), + ], + [ + (keras.layers.LSTM(output_dim)), + [np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,))], + (None, 4, input_dim), + ], + ] + for layer, weights, input_shape in cases: + layer.build(input_shape) + _ = keras.engine.saving.preprocess_weights_for_loading( + layer, weights, original_keras_version='1') + + model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)]) + _ = keras.engine.saving.preprocess_weights_for_loading( + model, model.weights, original_keras_version='1') + + x = keras.Input((2,)) + y = keras.layers.Dense(2)(x) + model = keras.models.Model(x, y) + _ = keras.engine.saving.preprocess_weights_for_loading( + model, model.weights, original_keras_version='1') + + def test_sequential_weight_loading(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + batch_size = 5 + num_classes = 2 + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.add(keras.layers.Dense(num_classes)) + + x = np.random.random((batch_size, input_dim)) + ref_y = model.predict(x) + + model.save_weights(h5_path) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.add(keras.layers.Dense(num_classes)) + model.load_weights(h5_path) + y = model.predict(x) + + self.assertAllClose(y, ref_y) + + +class TestWholeModelSaving(test.TestCase): + + def test_sequential_model_saving(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + new_model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + out2 = new_model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + new_model.train_on_batch(x, y) + out = model.predict(x) + out2 = new_model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + def test_sequential_model_saving_2(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + # test with custom optimizer, loss + + class CustomOp(keras.optimizers.RMSprop): + pass + + def custom_loss(y_true, y_pred): + return keras.losses.mse(y_true, y_pred) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + model = keras.models.load_model( + fname, + custom_objects={'CustomOp': CustomOp, + 'custom_loss': custom_loss}) + os.close(fd) + os.remove(fname) + + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + def test_functional_model_saving(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + def test_saving_without_compilation(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + def test_saving_with_tf_optimizer(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', + optimizer=training_module.AdadeltaOptimizer(0.1), + metrics=['acc']) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + def test_saving_right_after_compilation(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + model._make_train_function() + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + def test_saving_lambda_numpy_array_arguments(self): + if h5py is None: + return # Skip test if models cannot be saved. + + mean = np.random.random((4, 2, 3)) + std = np.abs(np.random.random((4, 2, 3))) + 1e-5 + inputs = keras.layers.Input(shape=(4, 2, 3)) + output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, + arguments={'mu': mean, 'std': std})(inputs) + model = keras.models.Model(inputs, output) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + self.assertAllClose(mean, model.layers[1].arguments['mu']) + self.assertAllClose(std, model.layers[1].arguments['std']) + + def test_saving_model_with_long_layer_names(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + # This layer name will make the `layers_name` HDF5 attribute blow + # out of proportion. Note that it fits into the internal HDF5 + # attribute memory limit on its own but because h5py converts + # the list of layer names into numpy array, which uses the same + # amout of memory for every item, it increases the memory + # requirements substantially. + x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15))) + f = x + for i in range(4): + f = keras.layers.Dense(2, name='dense_%d' % (i,))(f) + model = keras.Model(inputs=[x], outputs=[f]) + model.compile(loss='mse', optimizer='adam', metrics=['acc']) + + x = np.random.random((1, 2)) + y = np.random.random((1, 2)) + model.train_on_batch(x, y) + out = model.predict(x) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + + # Check that the HDF5 files contains chunked array + # of layer names. + with h5py.File(fname, 'r') as h5file: + num_names_arrays = len([attr for attr in h5file['model_weights'].attrs + if attr.startswith('layer_names')]) + # The chunking of layer names array should have happend. + self.assertGreater(num_names_arrays, 0) + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + # Cleanup + os.close(fd) + os.remove(fname) + + def test_saving_model_with_long_weights_names(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + x = keras.Input(shape=(2,), name='nested_model_input') + f = x + for i in range(4): + f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f) + # This layer name will make the `weights_name` + # HDF5 attribute blow out of proportion. + f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**15)))(f) + nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model') + + x = keras.Input(shape=(2,), name='outer_model_input') + f = nested_model(x) + f = keras.layers.Dense(2, name='outer_model_output')(f) + + model = keras.Model(inputs=[x], outputs=[f]) + model.compile(loss='mse', optimizer='adam', metrics=['acc']) + + x = np.random.random((1, 2)) + y = np.random.random((1, 2)) + model.train_on_batch(x, y) + out = model.predict(x) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + + # Check that the HDF5 files contains chunked array + # of weight names. + with h5py.File(fname, 'r') as h5file: + num_weight_arrays = len( + [attr for attr in h5file['model_weights']['nested_model'].attrs + if attr.startswith('weight_names')]) + # The chunking of layer names array should have happend. + self.assertGreater(num_weight_arrays, 0) + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + # Cleanup + os.close(fd) + os.remove(fname) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential.py b/tensorflow/python/keras/_impl/keras/engine/sequential.py new file mode 100644 index 0000000000000000000000000000000000000000..66cef1f5b9cef302117fe1fa67a0cfdf694403f1 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/sequential.py @@ -0,0 +1,287 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Home of the `Sequential` model. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import layers as layer_module +from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras._impl.keras.engine import network +from tensorflow.python.keras._impl.keras.engine.input_layer import Input +from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer +from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + + +@tf_export('keras.models.Sequential', 'keras.Sequential') +class Sequential(Model): + """Linear stack of layers. + + Arguments: + layers: list of layers to add to the model. + + Example: + + ```python + # Optionally, the first layer can receive an `input_shape` argument: + model = Sequential() + model.add(Dense(32, input_shape=(500,))) + # Afterwards, we do automatic shape inference: + model.add(Dense(32)) + + # This is identical to the following: + model = Sequential() + model.add(Dense(32, input_dim=500)) + + # And to the following: + model = Sequential() + model.add(Dense(32, batch_input_shape=(None, 500))) + + # Note that you can also omit the `input_shape` argument: + # In that case the model gets built the first time you call `fit` (or other + # training and evaluation methods). + model = Sequential() + model.add(Dense(32)) + model.add(Dense(32)) + model.compile(optimizer=optimizer, loss=loss) + # This builds the model for the first time: + model.fit(x, y, batch_size=32, epochs=10) + + # Note that when using this delayed-build pattern (no input shape specified), + # the model doesn't have any weights until the first call + # to a training/evaluation method (since it isn't yet built): + model = Sequential() + model.add(Dense(32)) + model.add(Dense(32)) + model.weights # returns [] + + # Whereas if you specify the input shape, the model gets built continuously + # as you are adding layers: + model = Sequential() + model.add(Dense(32, input_shape=(500,))) + model.add(Dense(32)) + model.weights # returns list of length 4 + + When using the delayed-build pattern (no input shape specified), you can + choose to manually build your model by calling `build(batch_input_shape)`: + model = Sequential() + model.add(Dense(32)) + model.add(Dense(32)) + model.build((None, 500)) + model.weights # returns list of length 4 + ``` + """ + + def __init__(self, layers=None, name=None): + super(Sequential, self).__init__(name=name) + + # Add to the model any layers passed to the constructor. + if layers: + for layer in layers: + self.add(layer) + + @property + def layers(self): + # Historically, `sequential.layers` only returns layers that were added + # via `add`, and omits the auto-generated `InputLayer` that comes at the + # bottom of the stack. + if self._layers and isinstance(self._layers[0], InputLayer): + return self._layers[1:] + return self._layers + + def add(self, layer): + """Adds a layer instance on top of the layer stack. + + Arguments: + layer: layer instance. + + Raises: + TypeError: If `layer` is not a layer instance. + ValueError: In case the `layer` argument does not + know its input shape. + ValueError: In case the `layer` argument has + multiple output tensors, or is already connected + somewhere else (forbidden in `Sequential` models). + """ + if not isinstance(layer, (base_layer.Layer, base_layer.TFBaseLayer)): + raise TypeError('The added layer must be ' + 'an instance of class Layer. ' + 'Found: ' + str(layer)) + self.built = False + if not self._layers: + set_inputs = False + # First layer in model: check that it is an input layer. + if not isinstance(layer, InputLayer): + # Create an input tensor and call `layer` on the input tensor. + # First, we need to infer the expected input shape and dtype. + first_layer = layer + if isinstance(layer, (Model, Sequential)): + # We were passed a model as first layer. + # This requires a specific way to figure out the + # input shape and dtype. + if not layer.layers: + raise ValueError('Cannot add an empty model ' + 'to a `Sequential` model.') + # In case of nested models: recover the first layer + # of the deepest model to infer input shape and dtype. + first_layer = layer.layers[0] + while isinstance(first_layer, (Model, Sequential)): + first_layer = first_layer.layers[0] + batch_shape = first_layer._batch_input_shape + dtype = first_layer.dtype + + if hasattr(first_layer, '_batch_input_shape'): + batch_shape = first_layer._batch_input_shape + dtype = first_layer.dtype + # Instantiate the input layer. + x = Input( + batch_shape=batch_shape, + dtype=dtype, + name=layer.name + '_input') + # This will build the current layer + # and create the node connecting the current layer + # to the input layer we just created. + layer(x) + set_inputs = True + else: + # The layer doesn't know about its expected shape. We will have to + # build the model lazily on `fit`/etc. + batch_shape = None + else: + # Corner case where the user passes an InputLayer layer via `add`. + assert len(layer._inbound_nodes[-1].output_tensors) == 1 + set_inputs = True + + if set_inputs: + if len(layer._inbound_nodes[-1].output_tensors) != 1: + raise ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.') + + self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] + self.inputs = network.get_source_inputs(self.outputs[0]) + elif self.outputs: + output_tensor = layer(self.outputs[0]) + if isinstance(output_tensor, list): + raise TypeError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.') + self.outputs = [output_tensor] + if self.inputs: + self.build() + else: + self._layers.append(layer) + + def pop(self): + """Removes the last layer in the model. + + Raises: + TypeError: if there are no layers in the model. + """ + if not self.layers: + raise TypeError('There are no layers in the model.') + + self._layers.pop() + self.built = False + if not self.layers: + self.outputs = None + self.inputs = None + elif self.outputs: + self.layers[-1]._outbound_nodes = [] + self.outputs = [self.layers[-1].output] + self.build() + + def build(self, input_shape=None): + if input_shape and not self.inputs: + batch_shape = tuple(input_shape) + dtype = K.floatx() + x = Input( + batch_shape=batch_shape, dtype=dtype, name=self.name + '_input') + self.inputs = [x] + for layer in self._layers: + x = layer(x) + self.outputs = [x] + + if self.inputs: + self._init_graph_network(self.inputs, self.outputs, name=self.name) + self.built = True + + def predict_proba(self, x, batch_size=32, verbose=0): + """Generates class probability predictions for the input samples. + + The input samples are processed batch by batch. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + batch_size: integer. + verbose: verbosity mode, 0 or 1. + + Returns: + A Numpy array of probability predictions. + """ + preds = self.predict(x, batch_size, verbose) + if preds.min() < 0. or preds.max() > 1.: + logging.warning('Network returning invalid probability values. ' + 'The last layer might not normalize predictions ' + 'into probabilities ' + '(like softmax or sigmoid would).') + return preds + + def predict_classes(self, x, batch_size=32, verbose=0): + """Generate class predictions for the input samples. + + The input samples are processed batch by batch. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + batch_size: integer. + verbose: verbosity mode, 0 or 1. + + Returns: + A numpy array of class predictions. + """ + proba = self.predict(x, batch_size=batch_size, verbose=verbose) + if proba.shape[-1] > 1: + return proba.argmax(axis=-1) + else: + return (proba > 0.5).astype('int32') + + def get_config(self): + config = [] + for layer in self.layers: + config.append({ + 'class_name': layer.__class__.__name__, + 'config': layer.get_config() + }) + return copy.deepcopy(config) + + @classmethod + def from_config(cls, config, custom_objects=None): + model = cls() + for conf in config: + layer = layer_module.deserialize(conf, custom_objects=custom_objects) + model.add(layer) + return model diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a47581df03e0fc1ad38552ba8634862435cd80 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py @@ -0,0 +1,176 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests specific to `Sequential` model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test +from tensorflow.python.training import rmsprop + + +class TestSequential(test.TestCase): + """Most Sequential model API tests are covered in `training_test.py`. + """ + + @tf_test_util.run_in_graph_and_eager_modes() + def test_basic_methods(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_dim=2)) + model.add(keras.layers.Dropout(0.3, name='dp')) + model.add(keras.layers.Dense(2, kernel_regularizer='l2', + kernel_constraint='max_norm')) + self.assertEqual(len(model.layers), 3) + self.assertEqual(len(model.weights), 2 * 2) + self.assertEqual(model.get_layer(name='dp').name, 'dp') + + @tf_test_util.run_in_graph_and_eager_modes() + def test_sequential_pop(self): + num_hidden = 5 + input_dim = 3 + batch_size = 5 + num_classes = 2 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.add(keras.layers.Dense(num_classes)) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + x = np.random.random((batch_size, input_dim)) + y = np.random.random((batch_size, num_classes)) + model.fit(x, y, epochs=1) + model.pop() + self.assertEqual(len(model.layers), 1) + self.assertEqual(model.output_shape, (None, num_hidden)) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + y = np.random.random((batch_size, num_hidden)) + model.fit(x, y, epochs=1) + + # Test popping single-layer model + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.pop() + self.assertEqual(model.layers, []) + self.assertEqual(model.outputs, None) + + # Invalid use case + model = keras.models.Sequential() + with self.assertRaises(TypeError): + model.pop() + + @tf_test_util.run_in_graph_and_eager_modes() + def test_sequential_deferred_build(self): + num_hidden = 5 + input_dim = 3 + batch_size = 5 + num_classes = 2 + + model = keras.models.Sequential() + # We don't specify the input shape. + model.add(keras.layers.Dense(num_hidden)) + model.add(keras.layers.Dense(num_classes)) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + self.assertEqual(len(model.layers), 2) + self.assertEqual(len(model.weights), 0) + self.assertFalse(model.built) + + x = np.random.random((batch_size, input_dim)) + y = np.random.random((batch_size, num_classes)) + model.fit(x, y, epochs=1) + self.assertTrue(model.built) + self.assertEqual(model.inputs[0].get_shape().as_list(), [None, input_dim]) + self.assertEqual(model.outputs[0].get_shape().as_list(), + [None, num_classes]) + self.assertEqual(len(model.weights), 2 * 2) + + @tf_test_util.run_in_graph_and_eager_modes() + def test_invalid_use_cases(self): + # Added objects must be layer instances + with self.assertRaises(TypeError): + model = keras.models.Sequential() + model.add(None) + + # Added layers cannot have multiple outputs + class MyLayer(keras.layers.Layer): + + def call(self, inputs): + return [3 * inputs, 2 * inputs] + + def compute_output_shape(self, input_shape): + return [input_shape, input_shape] + + with self.assertRaises(ValueError): + model = keras.models.Sequential() + model.add(MyLayer(input_shape=(3,))) + with self.assertRaises(TypeError): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_dim=1)) + model.add(MyLayer()) + + @tf_test_util.run_in_graph_and_eager_modes() + def test_nested_sequential_trainability(self): + input_dim = 20 + num_units = 10 + num_classes = 2 + + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,))) + + model = keras.models.Sequential() + model.add(inner_model) + model.add(keras.layers.Dense(num_classes)) + + self.assertEqual(len(model.layers), 2) + + self.assertEqual(len(model.trainable_weights), 4) + inner_model.trainable = False + self.assertEqual(len(model.trainable_weights), 2) + inner_model.trainable = True + self.assertEqual(len(model.trainable_weights), 4) + + def test_sequential_update_disabling(self): + val_a = np.random.random((10, 4)) + val_out = np.random.random((10, 4)) + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.BatchNormalization(input_shape=(4,))) + + model.trainable = False + assert not model.updates + + model.compile('sgd', 'mse') + assert not model.updates + + x1 = model.predict(val_a) + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + self.assertAllClose(x1, x2, atol=1e-7) + + model.trainable = True + model.compile('sgd', 'mse') + assert model.updates + + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + assert np.abs(np.sum(x1 - x2)) > 1e-5 + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py index 139621db6db1d8fee69bcc93a873eaa7e31dd8dc..b50277c8fff917d77694903c989fd02ea98b1711 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import shutil - import numpy as np from tensorflow.python.eager import context @@ -28,7 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras._impl import keras -from tensorflow.python.layers import base as base_layers +from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -39,11 +36,6 @@ try: except ImportError: yaml = None -try: - import h5py # pylint:disable=g-import-not-at-top -except ImportError: - h5py = None - class TopologyConstructionTest(test.TestCase): @@ -84,7 +76,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x2)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - network = keras.engine.topology.Network(x2, y2) + network = keras.engine.Network(x2, y2) self.assertEqual(len(network.updates), 2) self.assertEqual(len(network.get_updates_for(x1)), 0) self.assertEqual(len(network.get_updates_for(x2)), 1) @@ -146,7 +138,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x2)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - network = keras.engine.topology.Network(x2, y2) + network = keras.engine.Network(x2, y2) self.assertEqual(len(network.losses), 2) self.assertEqual(len(network.get_losses_for(x1)), 0) self.assertEqual(len(network.get_losses_for(x2)), 1) @@ -267,7 +259,7 @@ class TopologyConstructionTest(test.TestCase): x = keras.Input(shape=(32,)) dense = keras.layers.Dense(2) y = dense(x) - network = keras.engine.topology.Network(x, y, name='dense_network') + network = keras.engine.Network(x, y, name='dense_network') # test basic attributes self.assertEqual(network.name, 'dense_network') @@ -502,7 +494,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)]) # test get_source_inputs - self.assertListEqual(keras.engine.topology.get_source_inputs(c), [a, b]) + self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b]) # serialization / deserialization json_config = model.to_json() @@ -539,7 +531,9 @@ class TopologyConstructionTest(test.TestCase): e = keras.layers.Input(shape=(32,), name='input_e') f = keras.layers.Input(shape=(32,), name='input_f') + self.assertEqual(len(model.inputs), 2) g, h = model([e, f]) + self.assertEqual(len(model.inputs), 2) self.assertEqual(g.name, 'model/dense_2/BiasAdd:0') self.assertListEqual(g.get_shape().as_list(), c.get_shape().as_list()) @@ -721,7 +715,9 @@ class TopologyConstructionTest(test.TestCase): j = keras.layers.Input(shape=(32,), name='input_j') k = keras.layers.Input(shape=(32,), name='input_k') + self.assertEqual(len(model.inputs), 2) m, n = model([j, k]) + self.assertEqual(len(model.inputs), 2) tf_model = keras.models.Model([j, k], [m, n]) j_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32)) @@ -759,10 +755,20 @@ class TopologyConstructionTest(test.TestCase): def compute_mask(self, inputs, mask=None): return array_ops.ones_like(inputs) - if context.in_graph_mode(): + if context.executing_eagerly(): + a = constant_op.constant([2] * 32) + mask = constant_op.constant([0, 1] * 16) + a._keras_mask = mask + b = MaskedLayer().apply(a) + self.assertTrue(hasattr(b, '_keras_mask')) + self.assertAllEqual( + self.evaluate(array_ops.ones_like(mask)), + self.evaluate(getattr(b, '_keras_mask'))) + self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) + else: x = keras.Input(shape=(32,)) y = MaskedLayer()(x) # pylint: disable=not-callable - network = keras.engine.topology.Network(x, y) + network = keras.engine.Network(x, y) # test callability on Input x_2 = keras.Input(shape=(32,)) @@ -773,15 +779,6 @@ class TopologyConstructionTest(test.TestCase): x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - else: - a = constant_op.constant([2] * 32) - mask = constant_op.constant([0, 1] * 16) - a._keras_mask = mask - b = MaskedLayer().apply(a) - self.assertTrue(hasattr(b, '_keras_mask')) - self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), - self.evaluate(getattr(b, '_keras_mask'))) - self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) def test_activity_regularization_with_model_composition(self): @@ -875,139 +872,12 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(np.min(preds), 0.) # At least one unit was dropped. -class TestSaving(test.TestCase): - - def test_weight_loading(self): - with self.test_session(): - a = keras.layers.Input(shape=(2,)) - x = keras.layers.Dense(3)(a) - b = keras.layers.Dense(1)(x) - model = keras.models.Model(a, b) - - x = np.random.random((3, 2)) - ref_y = model.predict(x) - weights = model.get_weights() - model.set_weights(weights) - y = model.predict(x) - self.assertAllClose(ref_y, y) - - with self.assertRaises(ValueError): - model.set_weights(weights[1:]) - with self.assertRaises(ValueError): - model.set_weights(weights[::-1]) - - if h5py is None: - return # Skip rest of test if H5py isn't available. - - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - - h5_path = os.path.join(temp_dir, 'test.h5') - model.save_weights(h5_path) - model.load_weights(h5_path) - y = model.predict(x) - self.assertAllClose(ref_y, y) - - model.load_weights(h5_path, by_name=True) - y = model.predict(x) - self.assertAllClose(ref_y, y) - - def test_weight_preprocessing(self): - input_dim = 3 - output_dim = 3 - size = 2 - cases = [ - [ - (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))), - [np.random.random((2, 1)), np.random.random((2, 1))], - (None, 3, 2), - ], - [ - (keras.layers.TimeDistributed(keras.layers.Dense(1))), - [np.random.random((2, 1)), np.random.random((1,))], - (None, 3, 2), - ], - [ - (keras.layers.Conv1D(output_dim, size, use_bias=False)), - [np.random.random((output_dim, input_dim, size, 1))], - (None, 4, input_dim), - ], - [ - (keras.layers.Conv2D(output_dim, size, - use_bias=False, data_format='channels_first')), - [np.random.random((output_dim, input_dim, size, size))], - (None, input_dim, 4, 4), - ], - [ - (keras.layers.Conv2DTranspose(output_dim, size, - use_bias=False, - data_format='channels_first')), - [np.random.random((output_dim, input_dim, size, size))], - (None, input_dim, 4, 4), - ], - [ - (keras.layers.Conv2DTranspose(output_dim, size, - use_bias=False, - data_format='channels_last')), - [np.random.random((size, size, input_dim, output_dim))], - (None, 4, 4, input_dim), - ], - [ - (keras.layers.Conv3D(output_dim, size, - use_bias=False, data_format='channels_first')), - [np.random.random((output_dim, input_dim, size, size, size))], - (None, input_dim, 4, 4, 4), - ], - [ - (keras.layers.GRU(output_dim)), - [np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,))], - (None, 4, input_dim), - ], - [ - (keras.layers.LSTM(output_dim)), - [np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,))], - (None, 4, input_dim), - ], - ] - for layer, weights, input_shape in cases: - layer.build(input_shape) - _ = keras.engine.topology.preprocess_weights_for_loading( - layer, weights, original_keras_version='1') - - model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)]) - _ = keras.engine.topology.preprocess_weights_for_loading( - model, model.weights, original_keras_version='1') - - x = keras.Input((2,)) - y = keras.layers.Dense(2)(x) - model = keras.models.Model(x, y) - _ = keras.engine.topology.preprocess_weights_for_loading( - model, model.weights, original_keras_version='1') - - class DeferredModeTest(test.TestCase): def testDeferredTensorAttributes(self): - x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') + x = tf_base_layers._DeferredTensor(shape=(None, 2), + dtype='float32', + name='x') self.assertEqual(str(x), 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') self.assertEqual(repr(x), @@ -1015,23 +885,23 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testSimpleNetworkBuilding(self): - inputs = keras.engine.topology.Input(shape=(32,)) - if context.in_eager_mode(): - self.assertIsInstance(inputs, base_layers._DeferredTensor) + inputs = keras.engine.Input(shape=(32,)) + if context.executing_eagerly(): + self.assertIsInstance(inputs, tf_base_layers._DeferredTensor) self.assertEqual(inputs.dtype.name, 'float32') self.assertEqual(inputs.shape.as_list(), [None, 32]) x = keras.layers.Dense(2)(inputs) - if context.in_eager_mode(): - self.assertIsInstance(x, base_layers._DeferredTensor) + if context.executing_eagerly(): + self.assertIsInstance(x, tf_base_layers._DeferredTensor) self.assertEqual(x.dtype.name, 'float32') self.assertEqual(x.shape.as_list(), [None, 2]) outputs = keras.layers.Dense(4)(x) - network = keras.engine.topology.Network(inputs, outputs) - self.assertIsInstance(network, keras.engine.topology.Network) + network = keras.engine.Network(inputs, outputs) + self.assertIsInstance(network, keras.engine.Network) - if context.in_eager_mode(): + if context.executing_eagerly(): # It should be possible to call such a network on EagerTensors. inputs = constant_op.constant( np.random.random((10, 32)).astype('float32')) @@ -1040,8 +910,8 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testMultiIONetworkbuilding(self): - input_a = keras.engine.topology.Input(shape=(32,)) - input_b = keras.engine.topology.Input(shape=(16,)) + input_a = keras.engine.Input(shape=(32,)) + input_b = keras.engine.Input(shape=(16,)) a = keras.layers.Dense(16)(input_a) class AddLayer(keras.layers.Layer): @@ -1055,8 +925,8 @@ class DeferredModeTest(test.TestCase): c = AddLayer()([a, input_b]) # pylint: disable=not-callable c = keras.layers.Dense(2)(c) - network = keras.engine.topology.Network([input_a, input_b], [a, c]) - if context.in_eager_mode(): + network = keras.engine.Network([input_a, input_b], [a, c]) + if context.executing_eagerly(): a_val = constant_op.constant( np.random.random((10, 32)).astype('float32')) b_val = constant_op.constant( diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index d8ea2fe3db500d3b52d80e46b0cff22a3d1c5915..08288d353efdb233f87c1e3c7c09cd405c1e1688 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -18,500 +18,28 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module from tensorflow.python.keras._impl.keras import optimizers +from tensorflow.python.keras._impl.keras.engine import training_arrays from tensorflow.python.keras._impl.keras.engine import training_eager -from tensorflow.python.keras._impl.keras.engine.topology import Layer -from tensorflow.python.keras._impl.keras.engine.topology import Network -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.engine import training_generator +from tensorflow.python.keras._impl.keras.engine import training_utils +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.network import Network from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays from tensorflow.python.layers.base import _DeferredTensor +from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.util.tf_export import tf_export -try: - from scipy.sparse import issparse # pylint: disable=g-import-not-at-top -except ImportError: - issparse = None - - -def _standardize_input_data(data, - names, - shapes=None, - check_batch_axis=True, - exception_prefix=''): - """Normalizes inputs and targets provided by users. - - Users may pass data as a list of arrays, dictionary of arrays, - or as a single array. We normalize this to an ordered list of - arrays (same order as `names`), while checking that the provided - arrays have shapes that match the network's expectations. - - Arguments: - data: User-provided input data (polymorphic). - names: List of expected array names. - shapes: Optional list of expected array shapes. - check_batch_axis: Boolean; whether to check that - the batch axis of the arrays matches the expected - value found in `shapes`. - exception_prefix: String prefix used for exception formatting. - - Returns: - List of standardized input arrays (one array per model input). - - Raises: - ValueError: in case of improperly formatted user-provided data. - """ - if not names: - if data is not None and hasattr(data, '__len__') and len(data): - raise ValueError('Error when checking model ' + exception_prefix + ': ' - 'expected no data, but got:', data) - return [] - if data is None: - return [None for _ in range(len(names))] - - if isinstance(data, dict): - try: - data = [ - data[x].values - if data[x].__class__.__name__ == 'DataFrame' else data[x] - for x in names - ] - except KeyError as e: - raise ValueError('No data provided for "' + e.args[0] + '". Need data ' - 'for each key in: ' + str(names)) - elif isinstance(data, list): - if isinstance(data[0], list): - data = [np.asarray(d) for d in data] - elif len(names) == 1 and isinstance(data[0], (float, int)): - data = [np.asarray(data)] - else: - data = [ - x.values if x.__class__.__name__ == 'DataFrame' else x for x in data - ] - else: - data = data.values if data.__class__.__name__ == 'DataFrame' else data - data = [data] - data = [ - np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data - ] - - if len(data) != len(names): - if data and hasattr(data[0], 'shape'): - raise ValueError('Error when checking model ' + exception_prefix + - ': the list of Numpy arrays that you are passing to ' - 'your model is not the size the model expected. ' - 'Expected to see ' + str(len(names)) + ' array(s), ' - 'but instead got the following list of ' + - str(len(data)) + ' arrays: ' + str(data)[:200] + '...') - elif len(names) > 1: - raise ValueError( - 'Error when checking model ' + exception_prefix + - ': you are passing a list as input to your model, ' - 'but the model expects a list of ' + str(len(names)) + - ' Numpy arrays instead. The list you passed was: ' + str(data)[:200]) - elif len(data) == 1 and not hasattr(data[0], 'shape'): - raise TypeError('Error when checking model ' + exception_prefix + - ': data should be a Numpy array, or list/dict of ' - 'Numpy arrays. Found: ' + str(data)[:200] + '...') - elif len(names) == 1: - data = [np.asarray(data)] - - # Check shapes compatibility. - if shapes: - for i in range(len(names)): - if shapes[i] is not None: - data_shape = data[i].shape - shape = shapes[i] - if data[i].ndim != len(shape): - raise ValueError('Error when checking ' + exception_prefix + - ': expected ' + names[i] + ' to have ' + - str(len(shape)) + ' dimensions, but got array ' - 'with shape ' + str(data_shape)) - if not check_batch_axis: - data_shape = data_shape[1:] - shape = shape[1:] - for dim, ref_dim in zip(data_shape, shape): - if ref_dim != dim and ref_dim: - raise ValueError( - 'Error when checking ' + exception_prefix + ': expected ' + - names[i] + ' to have shape ' + str(shape) + - ' but got array with shape ' + str(data_shape)) - return data - - -def _standardize_sample_or_class_weights(x_weight, output_names, weight_type): - """Maps `sample_weight` or `class_weight` to model outputs. - - Arguments: - x_weight: User-provided `sample_weight` or `class_weight` argument. - output_names: List of output names (strings) in the model. - weight_type: A string used purely for exception printing. - - Returns: - A list of `sample_weight` or `class_weight` where there are exactly - one element per model output. - - Raises: - ValueError: In case of invalid user-provided argument. - """ - if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test - return [None for _ in output_names] - if len(output_names) == 1: - if isinstance(x_weight, list) and len(x_weight) == 1: - return x_weight - if isinstance(x_weight, dict) and output_names[0] in x_weight: - return [x_weight[output_names[0]]] - else: - return [x_weight] - if isinstance(x_weight, list): - if len(x_weight) != len(output_names): - raise ValueError('Provided `' + weight_type + '` was a list of ' + - str(len(x_weight)) + ' elements, but the model has ' + - str(len(output_names)) + ' outputs. ' - 'You should provide one `' + weight_type + '`' - 'array per model output.') - return x_weight - if isinstance(x_weight, dict): - x_weights = [] - for name in output_names: - x_weights.append(x_weight.get(name)) - return x_weights - else: - raise TypeError( - 'The model has multiple outputs, so `' + weight_type + '` ' - 'should be either a list or a dict. ' - 'Provided `' + weight_type + '` type not understood: ' + str(x_weight)) - - -def _standardize_class_weights(class_weight, output_names): - return _standardize_sample_or_class_weights(class_weight, output_names, - 'class_weight') - - -def _standardize_sample_weights(sample_weight, output_names): - return _standardize_sample_or_class_weights(sample_weight, output_names, - 'sample_weight') - - -def _check_array_lengths(inputs, targets, weights=None): - """Does user input validation for numpy arrays. - - Arguments: - inputs: list of Numpy arrays of inputs. - targets: list of Numpy arrays of targets. - weights: list of Numpy arrays of sample weights. - - Raises: - ValueError: in case of incorrectly formatted data. - """ - - def set_of_lengths(x): - # return a set with the variation between - # different shapes, with None => 0 - if x is None: - return {0} - else: - return set([0 if y is None else y.shape[0] for y in x]) - - set_x = set_of_lengths(inputs) - set_y = set_of_lengths(targets) - set_w = set_of_lengths(weights) - if len(set_x) > 1: - raise ValueError('All input arrays (x) should have ' - 'the same number of samples. Got array shapes: ' + - str([x.shape for x in inputs])) - if len(set_y) > 1: - raise ValueError('All target arrays (y) should have ' - 'the same number of samples. Got array shapes: ' + - str([y.shape for y in targets])) - if set_x and set_y and list(set_x)[0] != list(set_y)[0]: - raise ValueError('Input arrays should have ' - 'the same number of samples as target arrays. ' - 'Found ' + str(list(set_x)[0]) + ' input samples ' - 'and ' + str(list(set_y)[0]) + ' target samples.') - if len(set_w) > 1: - raise ValueError('All sample_weight arrays should have ' - 'the same number of samples. Got array shapes: ' + - str([w.shape for w in weights])) - if set_y and set_w and list(set_y)[0] != list(set_w)[0]: - raise ValueError('Sample_weight arrays should have ' - 'the same number of samples as target arrays. Got ' + - str(list(set_y)[0]) + ' input samples and ' + - str(list(set_w)[0]) + ' target samples.') - - -def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): - """Does validation on the compatibility of targets and loss functions. - - This helps prevent users from using loss functions incorrectly. - - Arguments: - targets: list of Numpy arrays of targets. - loss_fns: list of loss functions. - output_shapes: list of shapes of model outputs. - - Raises: - ValueError: if a loss function or target array - is incompatible with an output. - """ - key_losses = { - losses.mean_squared_error, losses.binary_crossentropy, - losses.categorical_crossentropy - } - for y, loss, shape in zip(targets, loss_fns, output_shapes): - if y is None or loss is None: - continue - if loss is losses.categorical_crossentropy: - if y.shape[-1] == 1: - raise ValueError('You are passing a target array of shape ' + str( - y.shape) + ' while using as loss `categorical_crossentropy`. ' - '`categorical_crossentropy` expects ' - 'targets to be binary matrices (1s and 0s) ' - 'of shape (samples, classes). ' - 'If your targets are integer classes, ' - 'you can convert them to the expected format via:\n' - '```\n' - 'from keras.utils import to_categorical\n' - 'y_binary = to_categorical(y_int)\n' - '```\n' - '\n' - 'Alternatively, you can use the loss function ' - '`sparse_categorical_crossentropy` instead, ' - 'which does expect integer targets.') - if loss in key_losses: - for target_dim, out_dim in zip(y.shape[1:], shape[1:]): - if out_dim is not None and target_dim != out_dim: - raise ValueError('A target array with shape ' + str(y.shape) + - ' was passed for an output of shape ' + str(shape) + - ' while using as loss `' + loss.__name__ + '`. ' - 'This loss expects ' - 'targets to have the same shape ' - 'as the output.') - - -def _collect_metrics(metrics, output_names): - """Maps metric functions to model outputs. - - Arguments: - metrics: a list or dict of metric functions. - output_names: a list of the names (strings) of model outputs. - - Returns: - A list (one entry per model output) of lists of metric functions. - For instance, if the model has 2 outputs, and for the first output - we want to compute "binary_accuracy" and "binary_crossentropy", - and just "binary_accuracy" for the second output, - the list would look like: - `[[binary_accuracy, binary_crossentropy], [binary_accuracy]]` - - Raises: - TypeError: if an incorrect type is passed for the `metrics` argument. - """ - if not metrics: - return [[] for _ in output_names] - if isinstance(metrics, list): - # we then apply all metrics to all outputs. - return [copy.copy(metrics) for _ in output_names] - elif isinstance(metrics, dict): - nested_metrics = [] - for name in output_names: - output_metrics = metrics.get(name, []) - if not isinstance(output_metrics, list): - output_metrics = [output_metrics] - nested_metrics.append(output_metrics) - return nested_metrics - else: - raise TypeError('Type of `metrics` argument not understood. ' - 'Expected a list or dictionary, found: ' + str(metrics)) - - -def _batch_shuffle(index_array, batch_size): - """Shuffles an array in a batch-wise fashion. - - Useful for shuffling HDF5 arrays - (where one cannot access arbitrary indices). - - Arguments: - index_array: array of indices to be shuffled. - batch_size: integer. - - Returns: - The `index_array` array, shuffled in a batch-wise fashion. - """ - batch_count = int(len(index_array) / batch_size) - # to reshape we need to be cleanly divisible by batch size - # we stash extra items and reappend them after shuffling - last_batch = index_array[batch_count * batch_size:] - index_array = index_array[:batch_count * batch_size] - index_array = index_array.reshape((batch_count, batch_size)) - np.random.shuffle(index_array) - index_array = index_array.flatten() - return np.append(index_array, last_batch) - - -def _weighted_masked_objective(fn): - """Adds support for masking and sample-weighting to an objective function. - - It transforms an objective function `fn(y_true, y_pred)` - into a sample-weighted, cost-masked objective function - `fn(y_true, y_pred, weights, mask)`. - - Arguments: - fn: The objective function to wrap, - with signature `fn(y_true, y_pred)`. - - Returns: - A function with signature `fn(y_true, y_pred, weights, mask)`. - """ - if fn is None: - return None - - def weighted(y_true, y_pred, weights, mask=None): - """Wrapper function. - - Arguments: - y_true: `y_true` argument of `fn`. - y_pred: `y_pred` argument of `fn`. - weights: Weights tensor. - mask: Mask tensor. - - Returns: - Scalar tensor. - """ - # score_array has ndim >= 2 - score_array = fn(y_true, y_pred) - if mask is not None: - # Cast the mask to floatX to avoid float64 upcasting in theano - mask = K.cast(mask, K.floatx()) - # mask should have the same shape as score_array - score_array *= mask - # the loss per batch should be proportional - # to the number of unmasked samples. - score_array /= K.mean(mask) - - # apply sample weighting - if weights is not None: - # reduce score_array to same ndim as weight array - ndim = K.ndim(score_array) - weight_ndim = K.ndim(weights) - score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim))) - score_array *= weights - score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx())) - return K.mean(score_array) - - return weighted - - -def _standardize_weights(y, - sample_weight=None, - class_weight=None, - sample_weight_mode=None): - """Performs sample weight validation and standardization. - - Everything gets normalized to a single sample-wise (or timestep-wise) - weight array. - - Arguments: - y: Numpy array of model targets to be weighted. - sample_weight: User-provided `sample_weight` argument. - class_weight: User-provided `class_weight` argument. - sample_weight_mode: One of `None` or `"temporal"`. - `"temporal"` indicated that we expect 2D weight data - that will be applied to the last 2 dimensions of - the targets (i.e. we are weighting timesteps, not samples). - - Returns: - A numpy array of target weights, one entry per sample to weight. - - Raises: - ValueError: In case of invalid user-provided arguments. - """ - if sample_weight_mode is not None: - if sample_weight_mode != 'temporal': - raise ValueError('"sample_weight_mode ' - 'should be None or "temporal". ' - 'Found: ' + str(sample_weight_mode)) - if len(y.shape) < 3: - raise ValueError('Found a sample_weight array for ' - 'an input with shape ' + str(y.shape) + '. ' - 'Timestep-wise sample weighting (use of ' - 'sample_weight_mode="temporal") is restricted to ' - 'outputs that are at least 3D, i.e. that have ' - 'a time dimension.') - if sample_weight is not None and len(sample_weight.shape) != 2: - raise ValueError('Found a sample_weight array with shape ' + - str(sample_weight.shape) + '. ' - 'In order to use timestep-wise sample weighting, ' - 'you should pass a 2D sample_weight array.') - else: - if sample_weight is not None and len(sample_weight.shape) != 1: - raise ValueError('Found a sample_weight array with shape ' + - str(sample_weight.shape) + '. ' - 'In order to use timestep-wise sample weights, ' - 'you should specify ' - 'sample_weight_mode="temporal" ' - 'in compile(). If you just mean to use ' - 'sample-wise weights, make sure your ' - 'sample_weight array is 1D.') - - if sample_weight is not None: - if len(sample_weight.shape) > len(y.shape): - raise ValueError( - 'Found a sample_weight with shape' + str(sample_weight.shape) + '.' - 'Expected sample_weight with rank ' - 'less than or equal to ' + str(len(y.shape))) - - if y.shape[:sample_weight.ndim] != sample_weight.shape: - raise ValueError( - 'Found a sample_weight array with shape ' + str(sample_weight.shape) + - ' for an input with shape ' + str(y.shape) + '. ' - 'sample_weight cannot be broadcast.') - return sample_weight - elif isinstance(class_weight, dict): - if len(y.shape) > 2: - raise ValueError('`class_weight` not supported for ' - '3+ dimensional targets.') - if y.shape[1] > 1: - y_classes = np.argmax(y, axis=1) - elif y.shape[1] == 1: - y_classes = np.reshape(y, y.shape[0]) - else: - y_classes = y - - weights = np.asarray( - [class_weight[cls] for cls in y_classes if cls in class_weight]) - - if len(weights) != len(y_classes): - # subtract the sets to pick all missing classes - existing_classes = set(y_classes) - existing_class_weight = set(class_weight.keys()) - raise ValueError('`class_weight` must contain all classes in the data.' - ' The classes %s exist in the data but not in ' - '`class_weight`.' % - (existing_classes - existing_class_weight)) - return weights - else: - if sample_weight_mode is None: - return np.ones((y.shape[0],), dtype=K.floatx()) - else: - return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx()) - @tf_export('keras.models.Model', 'keras.Model') class Model(Network): @@ -634,7 +162,7 @@ class Model(Network): `optimizer`, `loss`, `metrics` or `sample_weight_mode`. """ loss = loss or {} - if context.in_eager_mode() and not isinstance( + if context.executing_eagerly() and not isinstance( optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): raise ValueError('Only TF native optimizers are supported in Eager mode.') @@ -642,13 +170,13 @@ class Model(Network): self.loss = loss self.metrics = metrics or [] self.loss_weights = loss_weights - if context.in_eager_mode() and sample_weight_mode is not None: + if context.executing_eagerly() and sample_weight_mode is not None: raise ValueError('sample_weight_mode is not supported in Eager mode.') self.sample_weight_mode = sample_weight_mode - if context.in_eager_mode() and weighted_metrics is not None: + if context.executing_eagerly() and weighted_metrics is not None: raise ValueError('weighted_metrics is not supported in Eager mode.') self.weighted_metrics = weighted_metrics - if context.in_eager_mode() and target_tensors is not None: + if context.executing_eagerly() and target_tensors is not None: raise ValueError('target_tensors is not supported in Eager mode.') self.target_tensors = target_tensors @@ -688,7 +216,8 @@ class Model(Network): loss_functions = [loss_function for _ in range(len(self.outputs))] self.loss_functions = loss_functions - weighted_losses = [_weighted_masked_objective(fn) for fn in loss_functions] + weighted_losses = [training_utils.weighted_masked_objective(fn) + for fn in loss_functions] skip_target_indices = [] skip_target_weighing_indices = [] self._feed_outputs = [] @@ -701,7 +230,7 @@ class Model(Network): skip_target_weighing_indices.append(i) # Prepare output masks. - if context.in_graph_mode(): + if not context.executing_eagerly(): masks = self.compute_mask(self.inputs, mask=None) if masks is None: masks = [None for _ in self.outputs] @@ -735,9 +264,9 @@ class Model(Network): self.loss_weights_list = loss_weights_list # initialization for Eager mode execution - if context.in_eager_mode(): + if context.executing_eagerly(): if target_tensors is not None: - raise ValueError('target_tensors are not currently supported in Eager' + raise ValueError('target_tensors are not currently supported in Eager ' 'mode.') self.total_loss = None self.metrics_tensors = [] @@ -745,7 +274,8 @@ class Model(Network): for i in range(len(self.outputs)): if len(self.outputs) > 1: self.metrics_names.append(self.output_names[i] + '_loss') - self.nested_metrics = _collect_metrics(metrics, self.output_names) + self.nested_metrics = training_utils.collect_metrics(metrics, + self.output_names) self._feed_sample_weight_modes = [] for i in range(len(self.outputs)): self._feed_sample_weight_modes.append(None) @@ -862,12 +392,12 @@ class Model(Network): sample_weights.append(None) else: if sample_weight_mode == 'temporal': - sample_weights.append( - K.placeholder(ndim=2, name=name + '_sample_weights')) + sample_weights.append(array_ops.placeholder_with_default( + [[1.]], shape=[None, None], name=name + '_sample_weights')) sample_weight_modes.append('temporal') else: - sample_weights.append( - K.placeholder(ndim=1, name=name + '_sample_weights')) + sample_weights.append(array_ops.placeholder_with_default( + [1.], shape=[None], name=name + '_sample_weights')) sample_weight_modes.append(None) self.sample_weight_modes = sample_weight_modes self._feed_sample_weight_modes = [] @@ -915,9 +445,9 @@ class Model(Network): # List of same size as output_names. # contains tuples (metrics for output, names of metrics). - nested_metrics = _collect_metrics(metrics, self.output_names) - nested_weighted_metrics = _collect_metrics(weighted_metrics, - self.output_names) + nested_metrics = training_utils.collect_metrics(metrics, self.output_names) + nested_weighted_metrics = training_utils.collect_metrics(weighted_metrics, + self.output_names) self.metrics_updates = [] self.stateful_metric_names = [] with K.name_scope('metrics'): @@ -963,11 +493,13 @@ class Model(Network): suffix = 'acc' elif metric in ('crossentropy', 'ce'): suffix = 'ce' - weighted_metric_fn = _weighted_masked_objective(metric_fn) + weighted_metric_fn = training_utils.weighted_masked_objective( + metric_fn) metric_name = metric_name_prefix + suffix else: metric_fn = metrics_module.get(metric) - weighted_metric_fn = _weighted_masked_objective(metric_fn) + weighted_metric_fn = training_utils.weighted_masked_objective( + metric_fn) # Get metric name as string if hasattr(metric_fn, 'name'): metric_name = metric_fn.name @@ -1105,451 +637,6 @@ class Model(Network): name='predict_function', **kwargs) - def _check_num_samples(self, - ins, - batch_size=None, - steps=None, - steps_name='steps'): - """Determine the number of samples provided for training and evaluation. - - The number of samples is not defined when running with `steps`, - in which case the number of samples is set to `None`. - - Arguments: - ins: List of tensors to be fed to the Keras function. - batch_size: Integer batch size or `None` if not defined. - steps: Total number of steps (batches of samples) - before declaring `_predict_loop` finished. - Ignored with the default value of `None`. - steps_name: The public API's parameter name for `steps`. - - Raises: - ValueError: when `steps` is `None` and the attribute `ins.shape` - does not exist. Also raises ValueError when `steps` is not `None` - and `batch_size` is not `None` because they are mutually - exclusive. - - Returns: - When steps is `None`, returns the number of samples to be - processed based on the size of the first dimension of the - first input numpy array. When steps is not `None` and - `batch_size` is `None`, returns `None`. - - Raises: - ValueError: In case of invalid arguments. - """ - if steps is not None: - num_samples = None - if batch_size is not None: - raise ValueError( - 'If ' + steps_name + ' is set, the `batch_size` must be None.') - elif ins and hasattr(ins[0], 'shape'): - num_samples = ins[0].shape[0] - else: - raise ValueError( - 'Either the input data should have ' - 'a defined shape, or ' + steps_name + ' should be specified.') - return num_samples - - def _fit_loop(self, - f, - ins, - out_labels=None, - batch_size=None, - epochs=100, - verbose=1, - callbacks=None, - val_f=None, - val_ins=None, - shuffle=True, - callback_metrics=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None): - """Abstract fit function for `f(ins)`. - - Assume that f returns a list, labeled by out_labels. - - Arguments: - f: Keras function returning a list of tensors - ins: List of tensors to be fed to `f` - out_labels: List of strings, display names of - the outputs of `f` - batch_size: Integer batch size or None if unknown. - epochs: Number of times to iterate over the data - verbose: Verbosity mode, 0, 1 or 2 - callbacks: List of callbacks to be called during training - val_f: Keras function to call for validation - val_ins: List of tensors to be fed to `val_f` - shuffle: Whether to shuffle the data at the beginning of each epoch - callback_metrics: List of strings, the display names of the metrics - passed to the callbacks. They should be the - concatenation of list the display names of the outputs of - `f` and the list of display names of the outputs of `f_val`. - initial_epoch: Epoch at which to start training - (useful for resuming a previous training run) - steps_per_epoch: Total number of steps (batches of samples) - before declaring one epoch finished and starting the - next epoch. Ignored with the default value of `None`. - validation_steps: Number of steps to run validation for - (only if doing validation from data tensors). - Ignored with the default value of `None`. - - Returns: - `History` object. - - Raises: - ValueError: in case of invalid arguments. - """ - do_validation = False - if val_f and val_ins: - do_validation = True - if verbose and ins and hasattr(ins[0], 'shape') and hasattr( - val_ins[0], 'shape'): - print('Train on %d samples, validate on %d samples' % - (ins[0].shape[0], val_ins[0].shape[0])) - if validation_steps: - do_validation = True - if steps_per_epoch is None: - raise ValueError('Can only use `validation_steps` ' - 'when doing step-wise ' - 'training, i.e. `steps_per_epoch` ' - 'must be set.') - - num_train_samples = self._check_num_samples( - ins, batch_size, steps_per_epoch, 'steps_per_epoch') - if num_train_samples is not None: - index_array = np.arange(num_train_samples) - - self.history = cbks.History() - all_callbacks = [cbks.BaseLogger( - stateful_metrics=self.stateful_metric_names)] - if verbose: - if steps_per_epoch is not None: - count_mode = 'steps' - else: - count_mode = 'samples' - all_callbacks.append( - cbks.ProgbarLogger( - count_mode, stateful_metrics=self.stateful_metric_names)) - all_callbacks += (callbacks or []) + [self.history] - callbacks = cbks.CallbackList(all_callbacks) - out_labels = out_labels or [] - - # it's possible to callback a different model than self - # (used by Sequential models) - if hasattr(self, 'callback_model') and self.callback_model: - callback_model = self.callback_model - else: - callback_model = self - - callbacks.set_model(callback_model) - - callbacks.set_params({ - 'batch_size': batch_size, - 'epochs': epochs, - 'steps': steps_per_epoch, - 'samples': num_train_samples, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics or [], - }) - callbacks.on_train_begin() - callback_model.stop_training = False - for cbk in callbacks: - cbk.validation_data = val_ins - - # To prevent a slowdown, we find beforehand the arrays that need conversion. - feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights - indices_for_conversion_to_dense = [] - for i in range(len(feed)): - if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): - indices_for_conversion_to_dense.append(i) - - for epoch in range(initial_epoch, epochs): - # Reset stateful metrics - for m in self.metrics: - if isinstance(m, Layer): - m.reset_states() - # Update callbacks - callbacks.on_epoch_begin(epoch) - epoch_logs = {} - if steps_per_epoch is not None: - for step_index in range(steps_per_epoch): - batch_logs = {} - batch_logs['batch'] = step_index - batch_logs['size'] = 1 - callbacks.on_batch_begin(step_index, batch_logs) - outs = f(ins) - - if not isinstance(outs, list): - outs = [outs] - for l, o in zip(out_labels, outs): - batch_logs[l] = o - - callbacks.on_batch_end(step_index, batch_logs) - if callback_model.stop_training: - break - - if do_validation: - val_outs = self._test_loop( - val_f, - val_ins, - batch_size=batch_size, - steps=validation_steps, - verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o - else: - if shuffle == 'batch': - index_array = _batch_shuffle(index_array, batch_size) - elif shuffle: - np.random.shuffle(index_array) - - batches = make_batches(num_train_samples, batch_size) - - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - try: - if isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] - else: - ins_batch = slice_arrays(ins, batch_ids) - except TypeError: - raise TypeError('TypeError while preparing batch. ' - 'If using HDF5 input data, ' - 'pass shuffle="batch".') - batch_logs = {} - batch_logs['batch'] = batch_index - batch_logs['size'] = len(batch_ids) - callbacks.on_batch_begin(batch_index, batch_logs) - for i in indices_for_conversion_to_dense: - ins_batch[i] = ins_batch[i].toarray() - - outs = f(ins_batch) - if not isinstance(outs, list): - outs = [outs] - for l, o in zip(out_labels, outs): - batch_logs[l] = o - - callbacks.on_batch_end(batch_index, batch_logs) - if callback_model.stop_training: - break - - if batch_index == len(batches) - 1: # Last batch. - if do_validation: - val_outs = self._test_loop( - val_f, val_ins, batch_size=batch_size, verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o - callbacks.on_epoch_end(epoch, epoch_logs) - if callback_model.stop_training: - break - callbacks.on_train_end() - return self.history - - def _predict_loop(self, f, ins, batch_size=32, verbose=0, steps=None): - """Abstract method to loop over some data in batches. - - Arguments: - f: Keras function returning a list of tensors. - ins: list of tensors to be fed to `f`. - batch_size: integer batch size. - verbose: verbosity mode. - steps: Total number of steps (batches of samples) - before declaring `_predict_loop` finished. - Ignored with the default value of `None`. - - Returns: - Array of predictions (if the model has a single output) - or list of arrays of predictions - (if the model has multiple outputs). - """ - if hasattr(self, 'metrics'): - for m in self.metrics: - if isinstance(m, Layer): - m.reset_states() - - num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') - if verbose == 1: - if steps is not None: - progbar = Progbar(target=steps, - stateful_metrics=self.stateful_metric_names) - else: - progbar = Progbar(target=num_samples, - stateful_metrics=self.stateful_metric_names) - - indices_for_conversion_to_dense = [] - for i in range(len(self._feed_inputs)): - if (issparse is not None and issparse(ins[i]) and - not K.is_sparse(self._feed_inputs[i])): - indices_for_conversion_to_dense.append(i) - - if steps is not None: - # Step-based predictions. - # Since we do not know how many samples - # we will see, we cannot pre-allocate - # the returned Numpy arrays. - # Instead, we store one array per batch seen - # and concatenate them upon returning. - unconcatenated_outs = [] - for step in range(steps): - batch_outs = f(ins) - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if step == 0: - for batch_out in batch_outs: - unconcatenated_outs.append([]) - for i, batch_out in enumerate(batch_outs): - unconcatenated_outs[i].append(batch_out) - if verbose == 1: - progbar.update(step + 1) - if len(unconcatenated_outs) == 1: - return np.concatenate(unconcatenated_outs[0], axis=0) - return [ - np.concatenate(unconcatenated_outs[i], axis=0) - for i in range(len(unconcatenated_outs)) - ] - else: - # Sample-based predictions. - outs = [] - batches = make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - if ins and isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] - else: - ins_batch = slice_arrays(ins, batch_ids) - for i in indices_for_conversion_to_dense: - ins_batch[i] = ins_batch[i].toarray() - - batch_outs = f(ins_batch) - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if batch_index == 0: - # Pre-allocate the results arrays. - for batch_out in batch_outs: - shape = (num_samples,) + batch_out.shape[1:] - outs.append(np.zeros(shape, dtype=batch_out.dtype)) - for i, batch_out in enumerate(batch_outs): - outs[i][batch_start:batch_end] = batch_out - if verbose == 1: - progbar.update(batch_end) - if len(outs) == 1: - return outs[0] - return outs - - def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None): - """Abstract method to loop over some data in batches. - - Arguments: - f: Keras function returning a list of tensors. - ins: list of tensors to be fed to `f`. - batch_size: integer batch size or `None`. - verbose: verbosity mode. - steps: Total number of steps (batches of samples) - before declaring predictions finished. - Ignored with the default value of `None`. - - Returns: - Scalar loss (if the model has a single output and no metrics) - or list of scalars (if the model has multiple outputs - and/or metrics). The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - """ - if hasattr(self, 'metrics'): - for m in self.metrics: - if isinstance(m, Layer): - m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(self.metrics_names) - if str(name) in self.stateful_metric_names - ] - else: - stateful_metric_indices = [] - - num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') - outs = [] - if verbose == 1: - if steps is not None: - progbar = Progbar(target=steps) - else: - progbar = Progbar(target=num_samples) - - # To prevent a slowdown, we find beforehand the arrays that need conversion. - feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights - indices_for_conversion_to_dense = [] - for i in range(len(feed)): - if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): - indices_for_conversion_to_dense.append(i) - - if steps is not None: - for step in range(steps): - batch_outs = f(ins) - if isinstance(batch_outs, list): - if step == 0: - for _ in enumerate(batch_outs): - outs.append(0.) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out - else: - if step == 0: - outs.append(0.) - outs[0] += batch_outs - if verbose == 1: - progbar.update(step + 1) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= steps - else: - batches = make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - if isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] - else: - ins_batch = slice_arrays(ins, batch_ids) - for i in indices_for_conversion_to_dense: - ins_batch[i] = ins_batch[i].toarray() - - batch_outs = f(ins_batch) - - if isinstance(batch_outs, list): - if batch_index == 0: - for batch_out in enumerate(batch_outs): - outs.append(0.) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out * len(batch_ids) - else: - if batch_index == 0: - outs.append(0.) - outs[0] += batch_outs * len(batch_ids) - if verbose == 1: - progbar.update(batch_end) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= num_samples - if len(outs) == 1: - return outs[0] - return outs - def _standardize_user_data(self, x, y=None, @@ -1651,13 +738,13 @@ class Model(Network): 'TensorFlow tensors. ' 'You passed: x=' + str(x) + '; y=' + str(y)) - if context.in_graph_mode(): + if context.executing_eagerly(): + target_tensors = None + else: # Handle target tensors if any passed. if not isinstance(y, (list, tuple)): y = [y] target_tensors = [v for v in y if tensor_util.is_tensor(v)] - else: - target_tensors = None self.compile(optimizer=self.optimizer, loss=self.loss, metrics=self.metrics, @@ -1674,7 +761,7 @@ class Model(Network): # What follows is input validation and standardization to list format, # in the case where all inputs are value arrays. - if context.in_eager_mode(): + if context.executing_eagerly(): # In eager mode, do not do shape validation. feed_input_names = self.input_names feed_input_shapes = None @@ -1689,7 +776,7 @@ class Model(Network): feed_input_shapes = self._feed_input_shapes # Standardize the inputs. - x = _standardize_input_data( + x = training_utils.standardize_input_data( x, feed_input_names, feed_input_shapes, @@ -1697,7 +784,7 @@ class Model(Network): exception_prefix='input') if y is not None: - if context.in_eager_mode(): + if context.executing_eagerly(): feed_output_names = self.output_names feed_output_shapes = None # Sample weighting not supported in this case. @@ -1728,7 +815,7 @@ class Model(Network): feed_output_shapes.append(output_shape) # Standardize the outputs. - y = _standardize_input_data( + y = training_utils.standardize_input_data( y, feed_output_names, feed_output_shapes, @@ -1737,21 +824,21 @@ class Model(Network): # Generate sample-wise weight values given the `sample_weight` and # `class_weight` arguments. - sample_weights = _standardize_sample_weights(sample_weight, - feed_output_names) - class_weights = _standardize_class_weights(class_weight, - feed_output_names) + sample_weights = training_utils.standardize_sample_weights( + sample_weight, feed_output_names) + class_weights = training_utils.standardize_class_weights( + class_weight, feed_output_names) sample_weights = [ - _standardize_weights(ref, sw, cw, mode) + training_utils.standardize_weights(ref, sw, cw, mode) for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, feed_sample_weight_modes) ] # Check that all arrays have the same length. - _check_array_lengths(x, y, sample_weights) - if self._is_graph_network and not context.in_eager_mode(): + training_utils.check_array_lengths(x, y, sample_weights) + if self._is_graph_network and not context.executing_eagerly(): # Additional checks to avoid users mistakenly using improper loss fns. - _check_loss_and_target_compatibility(y, self._feed_loss_fns, - feed_output_shapes) + training_utils.check_loss_and_target_compatibility( + y, self._feed_loss_fns, feed_output_shapes) else: y = [] sample_weights = [] @@ -1787,11 +874,27 @@ class Model(Network): whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). """ - if context.in_eager_mode(): + if self.__class__.__name__ == 'Sequential': + # Note: we can't test whether the model is `Sequential` via `isinstance` + # since `Sequential` depends on `Model`. + if isinstance(inputs, list): + assert len(inputs) == 1 + inputs = inputs[0] + self.build(input_shape=(None,) + inputs.shape[1:]) + elif context.executing_eagerly(): self._eager_set_inputs(inputs) else: self._symbolic_set_inputs(inputs, training=training) + def _set_scope(self, scope=None): + """Modify the Layer scope creation logic to create ResourceVariables.""" + super(Model, self)._set_scope(scope=scope) + # Subclassed Models create ResourceVariables by default. This makes it + # easier to use Models in an eager/graph agnostic way (since eager execution + # always uses ResourceVariables). + if not self._is_graph_network: + self._scope.set_use_resource(True) + def _eager_set_inputs(self, inputs): """Set model's input and output specs based on the input data received. @@ -1807,7 +910,7 @@ class Model(Network): Raises: ValueError: If the model's inputs are already set. """ - assert context.in_eager_mode() + assert context.executing_eagerly() if self.inputs: raise ValueError('Model inputs are already set.') # On-the-fly setting of model inputs/outputs as DeferredTensors, @@ -1836,14 +939,17 @@ class Model(Network): 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] self.built = True - def _symbolic_set_inputs(self, inputs, training=None): - """Set model's inputs based on the input data received from the user. + def _symbolic_set_inputs(self, inputs, outputs=None, training=None): + """Set model's inputs and output specs based. This is to be used for Model subclasses, which do not know at instantiation time what their inputs look like. Args: inputs: Argument `x` (input data) passed by the user upon first model use. + outputs: None, a data tensor, or a list of data tensors. If None, the + outputs will be determined by invoking self.call(), otherwise the + provided value will be used. training: Boolean or None. Only relevant in symbolic mode. Specifies whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). @@ -1851,7 +957,7 @@ class Model(Network): Raises: ValueError: If the model's inputs are already set. """ - assert context.in_graph_mode() + assert not context.executing_eagerly() if self.inputs: raise ValueError('Model inputs are already set.') @@ -1893,17 +999,18 @@ class Model(Network): self._feed_input_names.append(name) self._feed_input_shapes.append(K.int_shape(v)) - # Obtain symbolic outputs by calling the model. - if len(self.inputs) == 1: - if self._expects_training_arg: - outputs = self.call(self.inputs[0], training=training) - else: - outputs = self.call(self.inputs[0]) - else: - if self._expects_training_arg: - outputs = self.call(self.inputs, training=training) + if outputs is None: + # Obtain symbolic outputs by calling the model. + if len(self.inputs) == 1: + if self._expects_training_arg: + outputs = self.call(self.inputs[0], training=training) + else: + outputs = self.call(self.inputs[0]) else: - outputs = self.call(self.inputs) + if self._expects_training_arg: + outputs = self.call(self.inputs, training=training) + else: + outputs = self.call(self.inputs) if isinstance(outputs, (list, tuple)): outputs = list(outputs) else: @@ -2049,10 +1156,7 @@ class Model(Network): class_weight=class_weight, batch_size=batch_size) # Prepare validation data. - do_validation = False - val_ins = [] if validation_data: - do_validation = True if len(validation_data) == 2: val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence val_sample_weight = None @@ -2070,13 +1174,8 @@ class Model(Network): val_y, sample_weight=val_sample_weight, batch_size=batch_size) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_ins = val_x + val_y + val_sample_weights + [0.] - else: - val_ins = val_x + val_y + val_sample_weights elif validation_split and 0. < validation_split < 1.: - do_validation = True if hasattr(x[0], 'shape'): split_at = int(x[0].shape[0] * (1. - validation_split)) else: @@ -2085,77 +1184,44 @@ class Model(Network): y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at)) sample_weights, val_sample_weights = (slice_arrays( sample_weights, 0, split_at), slice_arrays(sample_weights, split_at)) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_ins = val_x + val_y + val_sample_weights + [0.] - else: - val_ins = val_x + val_y + val_sample_weights - elif validation_steps: - do_validation = True - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_ins = [0.] - - # Prepare input arrays and training function. - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [1.] + val_x = [] + val_y = [] + val_sample_weights = [] else: - ins = x + y + sample_weights - - # Prepare display labels. - out_labels = self.metrics_names - - if context.in_eager_mode(): - if do_validation: - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] - else: - callback_metrics = copy.copy(out_labels) + val_x = None + val_y = None + val_sample_weights = None + if context.executing_eagerly(): return training_eager.fit_loop( self, - ins, - out_labels=out_labels, + inputs=x, + targets=y, + sample_weights=sample_weights, batch_size=batch_size, epochs=epochs, verbose=verbose, callbacks=callbacks, - val_ins=val_ins, + val_inputs=val_x, + val_targets=val_y, + val_sample_weights=val_sample_weights, shuffle=shuffle, - callback_metrics=callback_metrics, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) else: - self._make_train_function() - f = self.train_function - - if do_validation: - if context.in_graph_mode(): - self._make_test_function() - val_f = self.test_function - else: - val_f = None - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] - else: - val_f = None - callback_metrics = copy.copy(out_labels) - - # Delegate logic to `_fit_loop`. - return self._fit_loop( - f, - ins, - out_labels=out_labels, + return training_arrays.fit_loop( + self, x, y, + sample_weights=sample_weights, batch_size=batch_size, epochs=epochs, verbose=verbose, callbacks=callbacks, - val_f=val_f, - val_ins=val_ins, + val_inputs=val_x, + val_targets=val_y, + val_sample_weights=val_sample_weights, shuffle=shuffle, - callback_metrics=callback_metrics, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) @@ -2229,20 +1295,15 @@ class Model(Network): y, sample_weight=sample_weight, batch_size=batch_size) - # Prepare inputs, delegate logic to `_test_loop`. - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [0.] - else: - ins = x + y + sample_weights - if context.in_eager_mode(): + if context.executing_eagerly(): return training_eager.test_loop( - self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + self, inputs=x, targets=y, sample_weights=sample_weights, + batch_size=batch_size, verbose=verbose, steps=steps) else: - self._make_test_function() - f = self.test_function - return self._test_loop( - f, ins, batch_size=batch_size, verbose=verbose, steps=steps) + return training_arrays.test_loop( + self, inputs=x, targets=y, sample_weights=sample_weights, + batch_size=batch_size, verbose=verbose, steps=steps) def predict(self, x, batch_size=None, verbose=0, steps=None): """Generates output predictions for the input samples. @@ -2276,21 +1337,12 @@ class Model(Network): 'argument.') x, _, _ = self._standardize_user_data(x) - # Prepare inputs, delegate logic to `_predict_loop`. - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + [0.] - else: - ins = x - - if context.in_eager_mode(): + if context.executing_eagerly(): return training_eager.predict_loop( - self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + self, x, batch_size=batch_size, verbose=verbose, steps=steps) else: - self._make_predict_function() - f = self.predict_function - - return self._predict_loop( - f, ins, batch_size=batch_size, verbose=verbose, steps=steps) + return training_arrays.predict_loop( + self, x, batch_size=batch_size, verbose=verbose, steps=steps) def train_on_batch(self, x, y, sample_weight=None, class_weight=None): """Runs a single gradient update on a single batch of data. @@ -2327,20 +1379,24 @@ class Model(Network): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. + Raises: + ValueError: In case of invalid user-provided arguments. """ x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, class_weight=class_weight) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [1.] - else: - ins = x + y + sample_weights - if context.in_eager_mode(): - outputs = training_eager.train_on_batch(self, ins) + if context.executing_eagerly(): + outputs = training_eager.train_on_batch( + self, x, y, sample_weights=sample_weights) else: + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + y + sample_weights + [1] + else: + ins = x + y + sample_weights + self._make_train_function() outputs = self.train_function(ins) @@ -2377,18 +1433,19 @@ class Model(Network): the display labels for the scalar outputs. Raises: - ValueError: in case of invalid arguments. + ValueError: In case of invalid user-provided arguments. """ x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [0.] - else: - ins = x + y + sample_weights - if context.in_eager_mode(): - outputs = training_eager.test_on_batch(self, ins) + if context.executing_eagerly(): + outputs = training_eager.test_on_batch( + self, x, y, sample_weights=sample_weights) else: + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + y + sample_weights + [0] + else: + ins = x + y + sample_weights self._make_test_function() outputs = self.test_function(ins) @@ -2408,26 +1465,19 @@ class Model(Network): """ x, _, _ = self._standardize_user_data(x) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + [0.] - else: - ins = x - - if context.in_eager_mode(): - ins_batch_converted = [] - for ib in ins: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + if context.executing_eagerly(): + inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x] + return self(inputs) # pylint: disable=not-callable - eager_model_inputs = [] - for i in range(len(self.inputs)): - eager_model_inputs.append(ins_batch_converted[i]) - - outs = self(eager_model_inputs) # pylint: disable=not-callable - return outs + if not context.executing_eagerly(): + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + [0] + else: + ins = x - if context.in_graph_mode(): self._make_predict_function() outputs = self.predict_function(ins) + if len(outputs) == 1: return outputs[0] return outputs @@ -2499,20 +1549,19 @@ class Model(Network): max_queue_size: Integer. Maximum size for the generator queue. If unspecified, `max_queue_size` will default to 10. workers: Integer. Maximum number of processes to spin up - when using process based threading. + when using process-based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: Boolean. If True, use process based threading. - If unspecified, `workers` will default to False. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. - shuffle: Whether to shuffle the order of the batches at + use_multiprocessing: Boolean. + If `True`, use process-based threading. + If unspecified, `use_multiprocessing` will default to `False`. + Note that because this implementation relies on multiprocessing, + you should not pass non-picklable arguments to the generator + as they can't be passed easily to children processes. + shuffle: Boolean. Whether to shuffle the order of the batches at the beginning of each epoch. Only used with instances - of `Sequence` (keras.utils.Sequence). + of `Sequence` (`keras.utils.Sequence`). + Has no effect when `steps_per_epoch` is not `None`. initial_epoch: Epoch at which to start training (useful for resuming a previous training run) @@ -2539,217 +1588,25 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ - if not self._is_graph_network: + if not self.built and not self._is_graph_network: raise NotImplementedError( - '`fit_generator` is not yet enabled for Model subclasses') - - wait_time = 0.01 # in seconds - epoch = initial_epoch - - do_validation = bool(validation_data) - self._make_train_function() - if do_validation: - self._make_test_function() - - is_sequence = isinstance(generator, Sequence) - if not is_sequence and use_multiprocessing and workers > 1: - logging.warning( - UserWarning('Using a generator with `use_multiprocessing=True`' - ' and multiple workers may duplicate your data.' - ' Please consider using the`keras.utils.Sequence' - ' class.')) - if steps_per_epoch is None: - if is_sequence: - steps_per_epoch = len(generator) - else: - raise ValueError('`steps_per_epoch=None` is only valid for a' - ' generator based on the `keras.utils.Sequence`' - ' class. Please specify `steps_per_epoch` or use' - ' the `keras.utils.Sequence` class.') - - # python 2 has 'next', 3 has '__next__' - # avoid any explicit version checks - val_gen = ( - hasattr(validation_data, 'next') or - hasattr(validation_data, '__next__') or - isinstance(validation_data, Sequence)) - if (val_gen and not isinstance(validation_data, Sequence) and - not validation_steps): - raise ValueError('`validation_steps=None` is only valid for a' - ' generator based on the `keras.utils.Sequence`' - ' class. Please specify `validation_steps` or use' - ' the `keras.utils.Sequence` class.') - - # Prepare display labels. - out_labels = self.metrics_names - callback_metrics = out_labels + ['val_%s' % n for n in out_labels] - - # prepare callbacks - self.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history] - if verbose: - callbacks += [cbks.ProgbarLogger(count_mode='steps')] - callbacks = cbks.CallbackList(callbacks) - - # it's possible to callback a different model than self: - if hasattr(self, 'callback_model') and self.callback_model: - callback_model = self.callback_model - else: - callback_model = self - callbacks.set_model(callback_model) - callbacks.set_params({ - 'epochs': epochs, - 'steps': steps_per_epoch, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics, - }) - callbacks.on_train_begin() - - enqueuer = None - val_enqueuer = None - - try: - if do_validation: - if val_gen: - if workers > 0: - if isinstance(validation_data, Sequence): - val_enqueuer = OrderedEnqueuer( - validation_data, use_multiprocessing=use_multiprocessing) - if validation_steps is None: - validation_steps = len(validation_data) - else: - val_enqueuer = GeneratorEnqueuer( - validation_data, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) - val_enqueuer.start(workers=workers, max_queue_size=max_queue_size) - validation_generator = val_enqueuer.get() - else: - validation_generator = validation_data - else: - if len(validation_data) == 2: - val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence - val_sample_weight = None - elif len(validation_data) == 3: - val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence - else: - raise ValueError( - '`validation_data` should be a tuple ' - '`(val_x, val_y, val_sample_weight)` ' - 'or `(val_x, val_y)`. Found: ' + str(validation_data)) - val_x, val_y, val_sample_weights = self._standardize_user_data( - val_x, val_y, val_sample_weight) - val_data = val_x + val_y + val_sample_weights - if self.uses_learning_phase and not isinstance( - K.learning_phase(), int): - val_data += [0.] - for cbk in callbacks: - cbk.validation_data = val_data - - if workers > 0: - if is_sequence: - enqueuer = OrderedEnqueuer( - generator, - use_multiprocessing=use_multiprocessing, - shuffle=shuffle) - else: - enqueuer = GeneratorEnqueuer( - generator, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) - enqueuer.start(workers=workers, max_queue_size=max_queue_size) - output_generator = enqueuer.get() - else: - output_generator = generator - - callback_model.stop_training = False - # Construct epoch logs. - epoch_logs = {} - while epoch < epochs: - callbacks.on_epoch_begin(epoch) - steps_done = 0 - batch_index = 0 - while steps_done < steps_per_epoch: - generator_output = next(output_generator) - - if not hasattr(generator_output, '__len__'): - raise ValueError('Output of generator should be ' - 'a tuple `(x, y, sample_weight)` ' - 'or `(x, y)`. Found: ' + str(generator_output)) - - if len(generator_output) == 2: - x, y = generator_output - sample_weight = None - elif len(generator_output) == 3: - x, y, sample_weight = generator_output - else: - raise ValueError('Output of generator should be ' - 'a tuple `(x, y, sample_weight)` ' - 'or `(x, y)`. Found: ' + str(generator_output)) - # build batch logs - batch_logs = {} - if isinstance(x, list): - batch_size = x[0].shape[0] - elif isinstance(x, dict): - batch_size = list(x.values())[0].shape[0] - else: - batch_size = x.shape[0] - batch_logs['batch'] = batch_index - batch_logs['size'] = batch_size - callbacks.on_batch_begin(batch_index, batch_logs) - - outs = self.train_on_batch( - x, y, sample_weight=sample_weight, class_weight=class_weight) - - if not isinstance(outs, list): - outs = [outs] - for l, o in zip(out_labels, outs): - batch_logs[l] = o - - callbacks.on_batch_end(batch_index, batch_logs) - - batch_index += 1 - steps_done += 1 - - # Epoch finished. - if steps_done >= steps_per_epoch and do_validation: - if val_gen: - val_outs = self.evaluate_generator( - validation_generator, validation_steps, workers=0) - else: - # No need for try/except because - # data has already been validated. - val_outs = self.evaluate( - val_x, - val_y, - batch_size=batch_size, - sample_weight=val_sample_weights, - verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o - - if callback_model.stop_training: - break - - callbacks.on_epoch_end(epoch, epoch_logs) - epoch += 1 - if callback_model.stop_training: - break - - finally: - try: - if enqueuer is not None: - enqueuer.stop() - finally: - if val_enqueuer is not None: - val_enqueuer.stop() - - callbacks.on_train_end() - return self.history + '`fit_generator` is not yet enabled for unbuilt Model subclasses') + + return training_generator.fit_generator( + self, + generator, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_data=validation_data, + validation_steps=validation_steps, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle, + initial_epoch=initial_epoch) def evaluate_generator(self, generator, @@ -2774,16 +1631,15 @@ class Model(Network): the `len(generator)` as a number of steps. max_queue_size: maximum size for the generator queue workers: Integer. Maximum number of processes to spin up - when using process based threading. + when using process-based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: if True, use process based threading. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. + use_multiprocessing: Boolean. + If `True`, use process-based threading. + If unspecified, `use_multiprocessing` will default to `False`. + Note that because this implementation relies on multiprocessing, + you should not pass non-picklable arguments to the generator + as they can't be passed easily to children processes. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -2798,91 +1654,18 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ - if not self._is_graph_network: + if not self.built and not self._is_graph_network: raise NotImplementedError( - '`evaluate_generator` is not yet enabled for Model subclasses') - - self._make_test_function() - - steps_done = 0 - wait_time = 0.01 - all_outs = [] - batch_sizes = [] - is_sequence = isinstance(generator, Sequence) - if not is_sequence and use_multiprocessing and workers > 1: - logging.warning( - UserWarning('Using a generator with `use_multiprocessing=True`' - ' and multiple workers may duplicate your data.' - ' Please consider using the`keras.utils.Sequence' - ' class.')) - if steps is None: - if is_sequence: - steps = len(generator) - else: - raise ValueError('`steps=None` is only valid for a generator' - ' based on the `keras.utils.Sequence` class.' - ' Please specify `steps` or use the' - ' `keras.utils.Sequence` class.') - enqueuer = None - - try: - if workers > 0: - if is_sequence: - enqueuer = OrderedEnqueuer( - generator, use_multiprocessing=use_multiprocessing) - else: - enqueuer = GeneratorEnqueuer( - generator, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) - enqueuer.start(workers=workers, max_queue_size=max_queue_size) - output_generator = enqueuer.get() - else: - output_generator = generator - - while steps_done < steps: - generator_output = next(output_generator) - if not hasattr(generator_output, '__len__'): - raise ValueError('Output of generator should be a tuple ' - '(x, y, sample_weight) ' - 'or (x, y). Found: ' + str(generator_output)) - if len(generator_output) == 2: - x, y = generator_output - sample_weight = None - elif len(generator_output) == 3: - x, y, sample_weight = generator_output - else: - raise ValueError('Output of generator should be a tuple ' - '(x, y, sample_weight) ' - 'or (x, y). Found: ' + str(generator_output)) - outs = self.test_on_batch(x, y, sample_weight=sample_weight) - - if isinstance(x, list): - batch_size = x[0].shape[0] - elif isinstance(x, dict): - batch_size = list(x.values())[0].shape[0] - else: - batch_size = x.shape[0] - if batch_size == 0: - raise ValueError('Received an empty batch. ' - 'Batches should at least contain one item.') - all_outs.append(outs) + '`evaluate_generator` is not yet enabled for ' + 'unbuilt Model subclasses') - steps_done += 1 - batch_sizes.append(batch_size) - - finally: - if enqueuer is not None: - enqueuer.stop() - - if not isinstance(outs, list): - return np.average(np.asarray(all_outs), weights=batch_sizes) - else: - averages = [] - for i in range(len(outs)): - averages.append( - np.average([out[i] for out in all_outs], weights=batch_sizes)) - return averages + return training_generator.evaluate_generator( + self, + generator, + steps=steps, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) def predict_generator(self, generator, @@ -2907,16 +1690,15 @@ class Model(Network): the `len(generator)` as a number of steps. max_queue_size: Maximum size for the generator queue. workers: Integer. Maximum number of processes to spin up - when using process based threading. + when using process-based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: If `True`, use process based threading. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. + use_multiprocessing: Boolean. + If `True`, use process-based threading. + If unspecified, `use_multiprocessing` will default to `False`. + Note that because this implementation relies on multiprocessing, + you should not pass non-picklable arguments to the generator + as they can't be passed easily to children processes. verbose: verbosity mode, 0 or 1. Returns: @@ -2926,92 +1708,15 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ - if not self._is_graph_network: + if not self.built and not self._is_graph_network: raise NotImplementedError( - '`predict_generator` is not yet enabled for Model subclasses') - - self._make_predict_function() - - steps_done = 0 - wait_time = 0.01 - all_outs = [] - is_sequence = isinstance(generator, Sequence) - if not is_sequence and use_multiprocessing and workers > 1: - logging.warning( - UserWarning('Using a generator with `use_multiprocessing=True`' - ' and multiple workers may duplicate your data.' - ' Please consider using the`keras.utils.Sequence' - ' class.')) - if steps is None: - if is_sequence: - steps = len(generator) - else: - raise ValueError('`steps=None` is only valid for a generator' - ' based on the `keras.utils.Sequence` class.' - ' Please specify `steps` or use the' - ' `keras.utils.Sequence` class.') - enqueuer = None - - try: - if workers > 0: - if is_sequence: - enqueuer = OrderedEnqueuer( - generator, use_multiprocessing=use_multiprocessing) - else: - enqueuer = GeneratorEnqueuer( - generator, - use_multiprocessing=use_multiprocessing, - wait_time=wait_time) - enqueuer.start(workers=workers, max_queue_size=max_queue_size) - output_generator = enqueuer.get() - else: - output_generator = generator - - if verbose == 1: - progbar = Progbar(target=steps) - - while steps_done < steps: - generator_output = next(output_generator) - if isinstance(generator_output, tuple): - # Compatibility with the generators - # used for training. - if len(generator_output) == 2: - x, _ = generator_output - elif len(generator_output) == 3: - x, _, _ = generator_output - else: - raise ValueError('Output of generator should be ' - 'a tuple `(x, y, sample_weight)` ' - 'or `(x, y)`. Found: ' + str(generator_output)) - else: - # Assumes a generator that only - # yields inputs (not targets and sample weights). - x = generator_output - - outs = self.predict_on_batch(x) - if not isinstance(outs, list): - outs = [outs] - - if not all_outs: - for out in outs: - all_outs.append([]) - - for i, out in enumerate(outs): - all_outs[i].append(out) - steps_done += 1 - if verbose == 1: - progbar.update(steps_done) - - finally: - if enqueuer is not None: - enqueuer.stop() - - if len(all_outs) == 1: - if steps_done == 1: - return all_outs[0][0] - else: - return np.concatenate(all_outs[0]) - if steps_done == 1: - return [out[0] for out in all_outs] - else: - return [np.concatenate(out) for out in all_outs] + '`predict_generator` is not yet enabled for unbuilt Model subclasses') + + return training_generator.predict_generator( + self, + generator, + steps=steps, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + verbose=verbose) diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py new file mode 100644 index 0000000000000000000000000000000000000000..18116e3a14d6b1365f1a9db1a23243cd07763a62 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py @@ -0,0 +1,488 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Part of the Keras training engine related to plain array data. +""" +# pylint: disable=protected-access +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import numpy as np + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import callbacks as cbks +from tensorflow.python.keras._impl.keras.engine import training_utils +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches +from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays + +try: + from scipy.sparse import issparse # pylint: disable=g-import-not-at-top +except ImportError: + issparse = None + + +def fit_loop(model, + inputs, + targets, + sample_weights=None, + batch_size=None, + epochs=100, + verbose=1, + callbacks=None, + val_inputs=None, + val_targets=None, + val_sample_weights=None, + shuffle=True, + callback_metrics=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None): + """Abstract fit function for arrays of data. + + Arguments: + model: Keras Model instance. + inputs: List of input arrays. + targets: List of target arrays. + sample_weights: Optional list of sample weight arrays. + batch_size: Integer batch size or None if unknown. + epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training + val_inputs: List of input arrays. + val_targets: List of target arrays. + val_sample_weights: Optional list of sample weight arrays. + shuffle: Whether to shuffle the data at the beginning of each epoch + callback_metrics: List of strings, the display names of the metrics + passed to the callbacks. They should be the + concatenation of list the display names of the outputs of + `f` and the list of display names of the outputs of `f_val`. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run) + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. Ignored with the default value of `None`. + validation_steps: Number of steps to run validation for + (only if doing validation from data tensors). + Ignored with the default value of `None`. + + Returns: + `History` object. + + Raises: + ValueError: in case of invalid arguments. + """ + model._make_train_function() + f = model.train_function + + sample_weights = sample_weights or [] + val_sample_weights = val_sample_weights or [] + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = inputs + targets + sample_weights + [1] + if val_inputs: + val_ins = val_inputs + val_targets + val_sample_weights + [1] + else: + ins = inputs + targets + sample_weights + if val_inputs: + val_ins = val_inputs + val_targets + val_sample_weights + if not val_inputs: + val_ins = [] + + do_validation = False + if val_inputs: + do_validation = True + if verbose and inputs and hasattr(inputs[0], 'shape') and hasattr( + val_inputs[0], 'shape'): + print('Train on %d samples, validate on %d samples' % + (inputs[0].shape[0], val_inputs[0].shape[0])) + if validation_steps: + do_validation = True + if steps_per_epoch is None: + raise ValueError('Can only use `validation_steps` ' + 'when doing step-wise ' + 'training, i.e. `steps_per_epoch` ' + 'must be set.') + + out_labels = model.metrics_names + if do_validation: + callback_metrics = copy.copy(out_labels) + [ + 'val_' + n for n in out_labels + ] + else: + callback_metrics = copy.copy(out_labels) + + num_train_samples = training_utils.check_num_samples( + ins, batch_size, steps_per_epoch, 'steps_per_epoch') + if num_train_samples is not None: + index_array = np.arange(num_train_samples) + + model.history = cbks.History() + all_callbacks = [cbks.BaseLogger( + stateful_metrics=model.stateful_metric_names)] + if verbose: + if steps_per_epoch is not None: + count_mode = 'steps' + else: + count_mode = 'samples' + all_callbacks.append( + cbks.ProgbarLogger( + count_mode, stateful_metrics=model.stateful_metric_names)) + all_callbacks += (callbacks or []) + [model.history] + callbacks = cbks.CallbackList(all_callbacks) + out_labels = out_labels or [] + + # it's possible to callback a different model than self + # (used by Sequential models) + if hasattr(model, 'callback_model') and model.callback_model: + callback_model = model.callback_model + else: + callback_model = model + + callbacks.set_model(callback_model) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + callbacks.on_train_begin() + callback_model.stop_training = False + for cbk in callbacks: + cbk.validation_data = val_ins + + # To prevent a slowdown, we find beforehand the arrays that need conversion. + feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights + indices_for_conversion_to_dense = [] + for i in range(len(feed)): + if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): + indices_for_conversion_to_dense.append(i) + + for epoch in range(initial_epoch, epochs): + # Reset stateful metrics + for m in model.metrics: + if isinstance(m, Layer): + m.reset_states() + # Update callbacks + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + if steps_per_epoch is not None: + for step_index in range(steps_per_epoch): + batch_logs = {} + batch_logs['batch'] = step_index + batch_logs['size'] = 1 + callbacks.on_batch_begin(step_index, batch_logs) + outs = f(ins) + + if not isinstance(outs, list): + outs = [outs] + for l, o in zip(out_labels, outs): + batch_logs[l] = o + + callbacks.on_batch_end(step_index, batch_logs) + if callback_model.stop_training: + break + + if do_validation: + val_outs = test_loop( + model, + val_inputs, + val_targets, + sample_weights=val_sample_weights, + batch_size=batch_size, + steps=validation_steps, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + else: + if shuffle == 'batch': + index_array = training_utils.batch_shuffle(index_array, batch_size) + elif shuffle: + np.random.shuffle(index_array) + + batches = make_batches(num_train_samples, batch_size) + + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + try: + if isinstance(ins[-1], int): + # Do not slice the training phase flag. + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = slice_arrays(ins, batch_ids) + except TypeError: + raise TypeError('TypeError while preparing batch. ' + 'If using HDF5 input data, ' + 'pass shuffle="batch".') + batch_logs = {} + batch_logs['batch'] = batch_index + batch_logs['size'] = len(batch_ids) + callbacks.on_batch_begin(batch_index, batch_logs) + for i in indices_for_conversion_to_dense: + ins_batch[i] = ins_batch[i].toarray() + + outs = f(ins_batch) + if not isinstance(outs, list): + outs = [outs] + for l, o in zip(out_labels, outs): + batch_logs[l] = o + + callbacks.on_batch_end(batch_index, batch_logs) + if callback_model.stop_training: + break + + if batch_index == len(batches) - 1: # Last batch. + if do_validation: + val_outs = test_loop( + model, + val_inputs, + val_targets, + sample_weights=val_sample_weights, + batch_size=batch_size, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + callbacks.on_epoch_end(epoch, epoch_logs) + if callback_model.stop_training: + break + callbacks.on_train_end() + return model.history + + +def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: Keras Model instance. + inputs: list of tensors to be fed to `f`. + batch_size: integer batch size. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring `_predict_loop` finished. + Ignored with the default value of `None`. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions + (if the model has multiple outputs). + """ + model._make_predict_function() + f = model.predict_function + + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = inputs + [0] + else: + ins = inputs + + num_samples = training_utils.check_num_samples( + inputs, batch_size, steps, 'steps') + if verbose == 1: + if steps is not None: + progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) + + indices_for_conversion_to_dense = [] + for i in range(len(model._feed_inputs)): + if (issparse is not None and issparse(inputs[i]) and + not K.is_sparse(model._feed_inputs[i])): + indices_for_conversion_to_dense.append(i) + + if steps is not None: + # Step-based predictions. + # Since we do not know how many samples + # we will see, we cannot pre-allocate + # the returned Numpy arrays. + # Instead, we store one array per batch seen + # and concatenate them upon returning. + unconcatenated_outs = [] + for step in range(steps): + batch_outs = f(ins) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if step == 0: + for batch_out in batch_outs: + unconcatenated_outs.append([]) + for i, batch_out in enumerate(batch_outs): + unconcatenated_outs[i].append(batch_out) + if verbose == 1: + progbar.update(step + 1) + if len(unconcatenated_outs) == 1: + return np.concatenate(unconcatenated_outs[0], axis=0) + return [ + np.concatenate(unconcatenated_outs[i], axis=0) + for i in range(len(unconcatenated_outs)) + ] + else: + # Sample-based predictions. + outs = [] + batches = make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if ins and isinstance(ins[-1], int): + # Do not slice the training phase flag. + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = slice_arrays(ins, batch_ids) + for i in indices_for_conversion_to_dense: + ins_batch[i] = ins_batch[i].toarray() + + batch_outs = f(ins_batch) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if batch_index == 0: + # Pre-allocate the results arrays. + for batch_out in batch_outs: + shape = (num_samples,) + batch_out.shape[1:] + outs.append(np.zeros(shape, dtype=batch_out.dtype)) + for i, batch_out in enumerate(batch_outs): + outs[i][batch_start:batch_end] = batch_out + if verbose == 1: + progbar.update(batch_end) + if len(outs) == 1: + return outs[0] + return outs + + +def test_loop(model, inputs, targets, + sample_weights=None, + batch_size=None, + verbose=0, + steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: Keras Model instance. + inputs: List of input arrays. + targets: List of target arrays. + sample_weights: Optional list of sample weight arrays. + batch_size: integer batch size or `None`. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring predictions finished. + Ignored with the default value of `None`. + + Returns: + Scalar loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + """ + model._make_test_function() + f = model.test_function + + sample_weights = sample_weights or [] + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = inputs + targets + sample_weights + [0] + else: + ins = inputs + targets + sample_weights + + if hasattr(model, 'metrics'): + for m in model.metrics: + if isinstance(m, Layer): + m.reset_states() + stateful_metric_indices = [ + i for i, name in enumerate(model.metrics_names) + if str(name) in model.stateful_metric_names + ] + else: + stateful_metric_indices = [] + + num_samples = training_utils.check_num_samples( + ins, batch_size, steps, 'steps') + outs = [] + if verbose == 1: + if steps is not None: + progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) + + # To prevent a slowdown, we find beforehand the arrays that need conversion. + feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights + indices_for_conversion_to_dense = [] + for i in range(len(feed)): + if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): + indices_for_conversion_to_dense.append(i) + + if steps is not None: + for step in range(steps): + batch_outs = f(ins) + if isinstance(batch_outs, list): + if step == 0: + for _ in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out + else: + if step == 0: + outs.append(0.) + outs[0] += batch_outs + if verbose == 1: + progbar.update(step + 1) + for i in range(len(outs)): + if i not in stateful_metric_indices: + outs[i] /= steps + else: + batches = make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if isinstance(ins[-1], int): + # Do not slice the training phase flag. + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = slice_arrays(ins, batch_ids) + for i in indices_for_conversion_to_dense: + ins_batch[i] = ins_batch[i].toarray() + + batch_outs = f(ins_batch) + + if isinstance(batch_outs, list): + if batch_index == 0: + for batch_out in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out * len(batch_ids) + else: + if batch_index == 0: + outs.append(0.) + outs[0] += batch_outs * len(batch_ids) + if verbose == 1: + progbar.update(batch_end) + for i in range(len(outs)): + if i not in stateful_metric_indices: + outs[i] /= num_samples + if len(outs) == 1: + return outs[0] + return outs diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py index 282dd0dc0dbd440455cfade6952eda669aeaf2df..67858a578c5c95b3099e1e6713f3287748fc861f 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py @@ -12,20 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras training and evaluation routines. +"""Keras training and evaluation routines for eager execution. """ # pylint: disable=protected-access from __future__ import absolute_import from __future__ import division from __future__ import print_function + +import copy + import numpy as np + from tensorflow.python.eager.backprop import GradientTape from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util -from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import backend from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module +from tensorflow.python.keras._impl.keras.engine import training_utils from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays @@ -55,7 +60,7 @@ def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None): def _eager_loss_fn(outputs, targets, loss_fn, output_name): - with K.name_scope(output_name + '_loss'): + with backend.name_scope(output_name + '_loss'): loss = loss_fn(targets, outputs) return loss @@ -83,7 +88,7 @@ def _eager_metrics_fn(model, outputs, targets): output_metrics = model.nested_metrics[i] for nested_output_metric in output_metrics: metric_name, metric_fn = _get_metrics_info( - nested_output_metric, K.int_shape(model.outputs[i]), + nested_output_metric, backend.int_shape(model.outputs[i]), model.loss_functions[i]) if len(model.output_names) > 1: @@ -91,23 +96,23 @@ def _eager_metrics_fn(model, outputs, targets): if metric_name not in model.metrics_names: model.metrics_names.append(metric_name) - with K.name_scope(metric_name): + with backend.name_scope(metric_name): metric_result = metric_fn(outputs[i], targets[i]) metric_names.append(metric_name) - metric_results.append(K.mean(metric_result)) + metric_results.append(backend.mean(metric_result)) return metric_names, metric_results -def _model_loss(model, inputs, targets, training=False): +def _model_loss(model, inputs, targets, sample_weights=None, training=False): """Calculates the loss for a given model. Arguments: - model: The model on which metrics are being calculated. - inputs: The inputs of the given model. This is typically the mini batch of - data that is fed to the model. - targets: The predictions or targets of the given model. - training: Whether the model should be run in inference or training mode. + model: The model on which metrics are being calculated. + inputs: List of input arrays. + targets: List of target arrays. + sample_weights: Optional list of sample weight arrays. + training: Whether the model should be run in inference or training mode. Returns: Returns the model output, total loss and loss value calculated using the @@ -132,33 +137,22 @@ def _model_loss(model, inputs, targets, training=False): targets = [targets] loss_metrics = [] - with K.name_scope('loss'): + with backend.name_scope('loss'): for i, loss_fn in enumerate(model.loss_functions): - # compute the loss - output_loss = _eager_loss_fn(outs[i], targets[i], loss_fn, - model.output_names[i]) - loss_metrics.append(K.mean(output_loss)) + if sample_weights: + weights = sample_weights[i] + else: + weights = None + # TODO(fchollet): support masking; in practice `_keras_mask` is never + # set in this context currently. mask = outs[i]._keras_mask - # adapted from weighted_loss_fn - if mask is not None: - # mask should have the same shape as output_loss - output_loss *= mask - # the loss per batch should be proportional - # to the number of unmasked samples. - output_loss /= K.mean(mask) - - # adapted from weighted_loss_fn - # apply sample weighting - if model.sample_weights: - # reduce score_array to same ndim as weight array - ndim = K.ndim(output_loss) - weight_ndim = K.ndim(model.sample_weights) - output_loss = K.mean(output_loss, axis=list(range(weight_ndim, ndim))) - output_loss *= model.sample_weights - output_loss /= K.mean(K.cast(K.not_equal(model.sample_weights, 0), - K.floatx())) - output_loss = K.mean(output_loss) + + weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn) + with backend.name_scope(model.output_names[i] + '_loss'): + output_loss = weighted_masked_fn( + outs[i], targets[i], weights, mask=mask) + loss_metrics.append(backend.mean(output_loss)) loss_weight = model.loss_weights_list[i] if total_loss is None: @@ -166,7 +160,7 @@ def _model_loss(model, inputs, targets, training=False): else: total_loss += loss_weight * output_loss - total_loss = K.mean(total_loss) + total_loss = backend.mean(total_loss) # Add regularization losses custom_losses = [] for layer in model.layers: @@ -179,16 +173,20 @@ def _model_loss(model, inputs, targets, training=False): return outs, total_loss, loss_metrics -def _process_single_batch(eager_model_inputs, eager_model_outputs, model, +def _process_single_batch(model, + inputs, + targets, + sample_weights=None, training=False): """Calculate the loss and gradient for one input batch. The model weights are updated if training is set to True. Arguments: - eager_model_inputs: Input batch data. - eager_model_outputs: Output batch data. model: Model whose loss has to be calculated. + inputs: List of input arrays. + targets: List of target arrays. + sample_weights: Optional list of sample weight arrays. training: The boolean represents if the weights of the model are updated. 'fit' methods will set this to True while 'evaluate' methods will set this to False. @@ -199,81 +197,81 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model, Raises: ValueError: If the model has no loss to optimize. """ - K.set_learning_phase(training) - with GradientTape() as tape: - outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, - eager_model_outputs, - training=training) - if loss is None: - raise ValueError('The model cannot be run ' - 'because it has no loss to optimize.') - if training: - if not model._collected_trainable_weights: - logging.warning('The list of trainable weights is empty. Make sure that ' - 'you are not setting model.trainable to False before ' - 'compiling the model.') - else: - grads = tape.gradient(loss, model._collected_trainable_weights) - model.optimizer.apply_gradients(zip(grads, - model._collected_trainable_weights)) - return outs, loss, loss_metrics + with backend.learning_phase_scope(1 if training else 0): + with GradientTape() as tape: + outs, loss, loss_metrics = _model_loss(model, inputs, targets, + sample_weights=sample_weights, + training=training) + if loss is None: + raise ValueError('The model cannot be run ' + 'because it has no loss to optimize.') + if training: + if not model._collected_trainable_weights: + logging.warning('The list of trainable weights is empty. Make sure that' + ' you are not setting model.trainable to False before ' + 'compiling the model.') + else: + grads = tape.gradient(loss, model._collected_trainable_weights) + model.optimizer.apply_gradients(zip(grads, + model._collected_trainable_weights)) + return outs, loss, loss_metrics -def train_on_batch(model, ins): +def train_on_batch(model, inputs, targets, sample_weights=None): """Calculates the loss and gradient updates for one input batch. Arguments: - model: Given model on which loss and gradients are calculated. - ins: Input and output batch numpy arrays. + model: Model whose loss has to be calculated. + inputs: Input batch data. + targets: Target batch data. + sample_weights: Sample weight batch data. Returns: total loss and the loss associated with each output. """ - ins_batch_converted = [] - for ib in ins: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) - eager_model_inputs = [] - eager_model_outputs = [] - for i in range(len(model.inputs)): - eager_model_inputs.append(ins_batch_converted[i]) - for i in range(len(model.inputs), len(ins_batch_converted)): - eager_model_outputs.append(ins_batch_converted[i]) + inputs = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs] + targets = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets] + sample_weights = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None for val in sample_weights] outs, loss, _ = _process_single_batch( - eager_model_inputs, eager_model_outputs, model, training=True) + model, inputs, targets, sample_weights=sample_weights, training=True) if not isinstance(outs, list): outs = [outs] _, metrics_results = _eager_metrics_fn( - model, outs, eager_model_outputs) + model, outs, targets) if not isinstance(loss, list): loss = [loss] return loss + metrics_results -def test_on_batch(model, ins): +def test_on_batch(model, inputs, targets, sample_weights=None): """Calculates the loss for one input batch. Arguments: - model: Given model on which loss is calculated. - ins: Input and output batch numpy arrays. + model: Model whose loss has to be calculated. + inputs: Input batch data. + targets: Target batch data. + sample_weights: Sample weight batch data. Returns: total loss, loss and metrics associated with each output. """ - ins_batch_converted = [] - for ib in ins: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) - eager_model_inputs = [] - eager_model_outputs = [] - for i in range(len(model.inputs)): - eager_model_inputs.append(ins_batch_converted[i]) - for i in range(len(model.inputs), len(ins_batch_converted)): - eager_model_outputs.append(ins_batch_converted[i]) + inputs = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs] + targets = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets] + sample_weights = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None for val in sample_weights] outs, loss, loss_metrics = _process_single_batch( - eager_model_inputs, eager_model_outputs, model, training=False) + model, inputs, targets, sample_weights=sample_weights, training=False) if not isinstance(outs, list): outs = [outs] metric_names, metrics_results = _eager_metrics_fn( - model, outs, eager_model_outputs) + model, outs, targets) model.metrics_names.append(metric_names) if not isinstance(loss, list): loss = [loss] @@ -282,32 +280,35 @@ def test_on_batch(model, ins): def fit_loop( model, - ins, - out_labels=None, + inputs, + targets, + sample_weights=None, + val_inputs=None, + val_targets=None, + val_sample_weights=None, batch_size=None, epochs=100, verbose=1, callbacks=None, - val_ins=None, shuffle=True, callback_metrics=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None): - """Abstract fit function for `f(ins)`. - - Assume that f returns a list, labeled by out_labels. + """Abstract fit function for eager execution. Arguments: model: Instance of the model that is being executed in Eager mode. - ins: List of tensors to be fed to `f` - out_labels: List of strings, display names of - the outputs of `f` + inputs: List of input arrays. + targets: List of target arrays. + sample_weights: Optional list of sample weight arrays. + val_inputs: Input data for validation. + val_targets: Target data for validation. + val_sample_weights: Sample weight data for validation. batch_size: Integer batch size or None if unknown. epochs: Number of times to iterate over the data verbose: Verbosity mode, 0, 1 or 2 callbacks: List of callbacks to be called during training - val_ins: List of tensors to be fed to `val_f` shuffle: Whether to shuffle the data at the beginning of each epoch callback_metrics: List of strings, the display names of the metrics passed to the callbacks. They should be the @@ -328,165 +329,196 @@ def fit_loop( ValueError: In case of invalid argument values. """ # Required for Eager mode - K.set_learning_phase(True) - - do_validation = False - if val_ins: - do_validation = True - if (verbose and ins and hasattr(ins[0], 'shape') and - hasattr(val_ins[0], 'shape')): - print('Train on %d samples, validate on %d samples' % - (ins[0].shape[0], val_ins[0].shape[0])) - if validation_steps: - if steps_per_epoch is None: - raise ValueError('Can only use `validation_steps` when doing step-wise ' - 'training, i.e. `steps_per_epoch` must be set.') - do_validation = True - - num_train_samples = model._check_num_samples( - ins, batch_size, steps_per_epoch, 'steps_per_epoch') - - if num_train_samples is not None: - index_array = np.arange(num_train_samples) - - model.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] - if verbose: - if steps_per_epoch is not None: - count_mode = 'steps' + with backend.learning_phase_scope(1): + do_validation = False + if val_inputs: + do_validation = True + if (verbose and inputs and hasattr(inputs[0], 'shape') and + hasattr(val_inputs[0], 'shape')): + print('Train on %d samples, validate on %d samples' % + (inputs[0].shape[0], val_inputs[0].shape[0])) + if validation_steps: + if steps_per_epoch is None: + raise ValueError('Can only use `validation_steps` when doing step-wise ' + 'training, i.e. `steps_per_epoch` must be set.') + do_validation = True + + out_labels = model.metrics_names + if do_validation: + callback_metrics = copy.copy(out_labels) + [ + 'val_' + n for n in out_labels + ] else: - count_mode = 'samples' - callbacks += [cbks.ProgbarLogger(count_mode)] - callbacks = cbks.CallbackList(callbacks) - out_labels = out_labels or [] - - # it's possible to callback a different model than self - # (used by Sequential models) - if hasattr(model, 'callback_model') and model.callback_model: - callback_model = model.callback_model - else: - callback_model = model - - callbacks.set_model(callback_model) - - callbacks.set_params({ - 'batch_size': batch_size, - 'epochs': epochs, - 'steps': steps_per_epoch, - 'samples': num_train_samples, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics or [], - }) - callbacks.on_train_begin() - callback_model.stop_training = False - for cbk in callbacks: - cbk.validation_data = val_ins - - for epoch in range(initial_epoch, epochs): - callbacks.on_epoch_begin(epoch) - epoch_logs = {} - if shuffle == 'batch': - index_array = model._batch_shuffle(index_array, batch_size) - elif shuffle: - np.random.shuffle(index_array) - - batches = make_batches(num_train_samples, batch_size) + callback_metrics = copy.copy(out_labels) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - try: - if isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] - else: - ins_batch = slice_arrays(ins, batch_ids) - except TypeError: - raise TypeError('TypeError while preparing batch. ' - 'If using HDF5 input data, ' - 'pass shuffle="batch".') - batch_logs = {} - batch_logs['batch'] = batch_index - batch_logs['size'] = len(batch_ids) - - callbacks.on_batch_begin(batch_index, batch_logs) - - ins_batch_converted = [] - for ib in ins_batch: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) - eager_model_inputs = [] - eager_model_outputs = [] - for i in range(len(model.inputs)): - eager_model_inputs.append(ins_batch_converted[i]) - - for i in range(len(model.inputs), len(ins_batch_converted)): - eager_model_outputs.append(ins_batch_converted[i]) - - outs, loss, loss_metrics = _process_single_batch(eager_model_inputs, - eager_model_outputs, - model, - training=True) - - if not isinstance(outs, list): - outs = [outs] - - for l, o in zip(out_labels, outs): - batch_logs[l] = o - # Required for Eager mode - metrics_names, metrics_results = _eager_metrics_fn(model, outs, - eager_model_outputs) - batch_logs['loss'] = tensor_util.constant_value(K.mean(loss)) - - # TODO(anjalisridhar): Move this to compile to avoid duplicate code. - # In graph mode we set the metric names in compile. However in - # Eager mode we calculate the metrics for each batch in fit_loop. - # We could calculate the metric names and functions in compile. - # This would avoid setting the callback parameters separately. - # We need to do this for the first iteration alone - for m in metrics_names: - if m not in callback_metrics: - callback_metrics.append(m) - - callbacks.set_params({ - 'batch_size': batch_size, - 'epochs': epochs, - 'steps': steps_per_epoch, - 'samples': num_train_samples, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics or [], - }) - - for k, v in zip(model.metrics_names, - [K.mean(loss)] + loss_metrics + metrics_results): - batch_logs[k] = tensor_util.constant_value(v) - - callbacks.on_batch_end(batch_index, batch_logs) + if sample_weights: + feed_data = inputs + targets + sample_weights + else: + feed_data = inputs + targets + num_train_samples = training_utils.check_num_samples( + feed_data, + batch_size=batch_size, + steps=steps_per_epoch, + steps_name='steps_per_epoch') + + if num_train_samples is not None: + index_array = np.arange(num_train_samples) + + model.history = cbks.History() + callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] + if verbose: + if steps_per_epoch is not None: + count_mode = 'steps' + else: + count_mode = 'samples' + callbacks += [cbks.ProgbarLogger(count_mode)] + callbacks = cbks.CallbackList(callbacks) + + # it's possible to callback a different model than self + # (used by Sequential models) + if hasattr(model, 'callback_model') and model.callback_model: + callback_model = model.callback_model + else: + callback_model = model + + callbacks.set_model(callback_model) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + callbacks.on_train_begin() + callback_model.stop_training = False + for cbk in callbacks: + if not val_inputs: + cbk.validation_data = [] + elif val_sample_weights: + cbk.validation_data = val_inputs + val_targets + val_sample_weights + else: + cbk.validation_data = val_inputs + val_targets + + for epoch in range(initial_epoch, epochs): + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + if shuffle == 'batch': + index_array = model._batch_shuffle(index_array, batch_size) + elif shuffle: + np.random.shuffle(index_array) + + batches = make_batches(num_train_samples, batch_size) + + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + try: + inputs_batch = slice_arrays(inputs, batch_ids) + targets_batch = slice_arrays(targets, batch_ids) + if sample_weights: + sample_weights_batch = slice_arrays(sample_weights, batch_ids) + else: + sample_weights_batch = None + except TypeError: + raise TypeError('TypeError while preparing batch. ' + 'If using HDF5 input data, ' + 'pass shuffle="batch".') + batch_logs = {} + batch_logs['batch'] = batch_index + batch_logs['size'] = len(batch_ids) + + callbacks.on_batch_begin(batch_index, batch_logs) + + inputs_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in inputs_batch] + targets_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in targets_batch] + if sample_weights: + sample_weights_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None + for val in sample_weights_batch] + + outs, loss, loss_metrics = _process_single_batch( + model, + inputs_batch, + targets_batch, + sample_weights=sample_weights_batch, + training=True) + + if not isinstance(outs, list): + outs = [outs] + + for l, o in zip(out_labels, outs): + batch_logs[l] = o + # Required for Eager mode + metrics_names, metrics_results = _eager_metrics_fn( + model, outs, targets_batch) + batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss)) + + # TODO(anjalisridhar): Move this to compile to avoid duplicate code. + # In graph mode we set the metric names in compile. However in + # Eager mode we calculate the metrics for each batch in fit_loop. + # We could calculate the metric names and functions in compile. + # This would avoid setting the callback parameters separately. + # We need to do this for the first iteration alone + for m in metrics_names: + if m not in callback_metrics: + callback_metrics.append(m) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + + for k, v in zip(model.metrics_names, + [backend.mean(loss)] + loss_metrics + metrics_results): + batch_logs[k] = tensor_util.constant_value(v) + + callbacks.on_batch_end(batch_index, batch_logs) + if callback_model.stop_training: + break + + if batch_index == len(batches) - 1: # Last batch. + if do_validation: + val_outs = test_loop( + model, val_inputs, val_targets, + sample_weights=val_sample_weights, + batch_size=batch_size, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + callbacks.on_epoch_end(epoch, epoch_logs) if callback_model.stop_training: break + callbacks.on_train_end() + return model.history + - if batch_index == len(batches) - 1: # Last batch. - if do_validation: - val_outs = test_loop( - model, val_ins, batch_size=batch_size, verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o - callbacks.on_epoch_end(epoch, epoch_logs) - if callback_model.stop_training: - break - callbacks.on_train_end() - return model.history - - -def test_loop(model, ins, batch_size=None, verbose=0, steps=None): +def test_loop(model, inputs, targets, + sample_weights=None, + batch_size=None, + verbose=0, + steps=None): """Abstract method to loop over some data in batches. Arguments: model: Model instance that is being evaluated in Eager mode. - ins: list of tensors to be fed to `f`. + inputs: List of input arrays. + targets: List of target arrays. + sample_weights: Optional list of sample weight arrays. batch_size: integer batch size or `None`. verbose: verbosity mode. steps: Total number of steps (batches of samples) @@ -499,69 +531,79 @@ def test_loop(model, ins, batch_size=None, verbose=0, steps=None): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ - K.set_learning_phase(False) - num_samples = model._check_num_samples(ins, batch_size, steps, 'steps') - outs = [] - if verbose == 1: - progbar = Progbar(target=num_samples) - batches = make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - if isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] - else: - ins_batch = slice_arrays(ins, batch_ids) - - ins_batch_converted = [] - for ib in ins_batch: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) - - eager_model_inputs = [] - eager_model_outputs = [] - for i in range(len(model.inputs)): - eager_model_inputs.append(ins_batch_converted[i]) - - for i in range(len(model.inputs), len(ins_batch_converted)): - eager_model_outputs.append(ins_batch_converted[i]) - - loss_outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, - eager_model_outputs, - training=False) - _, metrics_results = _eager_metrics_fn(model, loss_outs, - eager_model_outputs) - batch_outs = [] - for _, v in zip(model.metrics_names, - [K.mean(loss)] + loss_metrics + metrics_results): - batch_outs.append(tensor_util.constant_value(v)) - - if isinstance(batch_outs, list): - if batch_index == 0: - for batch_out in enumerate(batch_outs): + with backend.learning_phase_scope(0): + feed_data = inputs + targets + if sample_weights: + feed_data += sample_weights + num_samples = training_utils.check_num_samples( + feed_data, batch_size=batch_size, steps=steps, steps_name='steps') + outs = [] + if verbose == 1: + progbar = Progbar(target=num_samples) + batches = make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + inputs_batch = slice_arrays(inputs, batch_ids) + targets_batch = slice_arrays(targets, batch_ids) + if sample_weights: + sample_weights_batch = slice_arrays(sample_weights, batch_ids) + else: + sample_weights_batch = None + + inputs_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in inputs_batch] + targets_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in targets_batch] + if sample_weights: + sample_weights_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None + for val in sample_weights_batch] + + loss_outs, loss, loss_metrics = _model_loss( + model, + inputs_batch, + targets_batch, + sample_weights=sample_weights_batch, + training=False) + _, metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch) + batch_outs = [] + for _, v in zip(model.metrics_names, + [backend.mean(loss)] + loss_metrics + metrics_results): + batch_outs.append(tensor_util.constant_value(v)) + + if isinstance(batch_outs, list): + if batch_index == 0: + for batch_out in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out * len(batch_ids) + else: + if batch_index == 0: outs.append(0.) - for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out * len(batch_ids) - else: - if batch_index == 0: - outs.append(0.) - outs[0] += batch_outs * len(batch_ids) + outs[0] += batch_outs * len(batch_ids) - if verbose == 1: - progbar.update(batch_end) - for i in range(len(outs)): - outs[i] /= num_samples - if len(outs) == 1: - return outs[0] - return outs + if verbose == 1: + progbar.update(batch_end) + for i in range(len(outs)): + outs[i] /= num_samples + if len(outs) == 1: + return outs[0] + return outs -def predict_loop(model, ins, batch_size=32, verbose=0, steps=None): +def predict_loop(model, inputs, + batch_size=32, + verbose=0, + steps=None): """Abstract method to loop over some data in batches. Arguments: model: - ins: list of tensors to be fed to `f`. + inputs: List of input arrays. batch_size: integer batch size. verbose: verbosity mode. steps: Total number of steps (batches of samples) @@ -573,57 +615,50 @@ def predict_loop(model, ins, batch_size=32, verbose=0, steps=None): or list of arrays of predictions (if the model has multiple outputs). """ - K.set_learning_phase(False) - num_samples = model._check_num_samples(ins, batch_size, steps, 'steps') - if verbose == 1: - if steps is not None: - progbar = Progbar(target=steps) - else: - progbar = Progbar(target=num_samples) - - outs = [] - batches = make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - if ins and isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] - else: - ins_batch = slice_arrays(ins, batch_ids) + with backend.learning_phase_scope(0): + num_samples = training_utils.check_num_samples( + inputs, batch_size, steps, 'steps') + if verbose == 1: + if steps is not None: + progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) - ins_batch_converted = [] - for ib in ins_batch: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + outs = [] + batches = make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + inputs_batch = slice_arrays(inputs, batch_ids) - eager_model_inputs = [] - for i in range(len(model.inputs)): - eager_model_inputs.append(ins_batch_converted[i]) + inputs_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in inputs_batch] - if len(eager_model_inputs) == 1: - if model._expects_training_arg: - batch_outs = model.call(eager_model_inputs[0], training=False) - else: - batch_outs = model.call(eager_model_inputs[0]) - else: - if model._expects_training_arg: - batch_outs = model.call(eager_model_inputs, training=False) + if len(inputs_batch) == 1: + if model._expects_training_arg: + batch_outs = model.call(inputs_batch[0], training=False) + else: + batch_outs = model.call(inputs_batch[0]) else: - batch_outs = model.call(eager_model_inputs) - - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if batch_index == 0: - # Pre-allocate the results arrays. - for batch_out in batch_outs: - dims = batch_out.shape[1:].dims - dims_list = [d.value for d in dims] - shape = (num_samples,) + tuple(dims_list) - outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) - for i, batch_out in enumerate(batch_outs): - outs[i][batch_start:batch_end] = batch_out - if verbose == 1: - progbar.update(batch_end) - if len(outs) == 1: - return outs[0] - return outs + if model._expects_training_arg: + batch_outs = model.call(inputs_batch, training=False) + else: + batch_outs = model.call(inputs_batch) + + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if batch_index == 0: + # Pre-allocate the results arrays. + for batch_out in batch_outs: + dims = batch_out.shape[1:].dims + dims_list = [d.value for d in dims] + shape = (num_samples,) + tuple(dims_list) + outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) + for i, batch_out in enumerate(batch_outs): + outs[i][batch_start:batch_end] = batch_out + if verbose == 1: + progbar.update(batch_end) + if len(outs) == 1: + return outs[0] + return outs diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py index 3d94b7537ff354351d3f4adde6e0819ce66ea377..8848b393d5e602e564cb357c32a937eaabd68203 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py @@ -24,9 +24,7 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -316,10 +314,9 @@ class LossWeightingTest(test.TestCase): def test_class_weights(self): num_classes = 5 batch_size = 5 - epochs = 5 weighted_class = 3 - train_samples = 3000 - test_samples = 3000 + train_samples = 300 + test_samples = 300 input_dim = 5 model = keras.models.Sequential() @@ -344,16 +341,16 @@ class LossWeightingTest(test.TestCase): test_ids = np.where(int_y_test == np.array(weighted_class))[0] class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. + class_weight[weighted_class] = 4. sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 2. + sample_weight[int_y_train == weighted_class] = 4. model.fit( x_train, y_train, batch_size=batch_size, - epochs=epochs // 3, + epochs=2, verbose=0, class_weight=class_weight, validation_data=(x_train, y_train, sample_weight)) @@ -361,14 +358,14 @@ class LossWeightingTest(test.TestCase): x_train, y_train, batch_size=batch_size, - epochs=epochs // 2, + epochs=2, verbose=0, class_weight=class_weight) model.fit( x_train, y_train, batch_size=batch_size, - epochs=epochs // 2, + epochs=2, verbose=0, class_weight=class_weight, validation_split=0.1) @@ -383,10 +380,9 @@ class LossWeightingTest(test.TestCase): def test_sample_weights(self): num_classes = 5 batch_size = 5 - epochs = 5 weighted_class = 3 - train_samples = 3000 - test_samples = 3000 + train_samples = 300 + test_samples = 300 input_dim = 5 model = keras.models.Sequential() @@ -407,23 +403,23 @@ class LossWeightingTest(test.TestCase): y_train = keras.utils.to_categorical(y_train, num_classes) class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. + class_weight[weighted_class] = 4. sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 2. + sample_weight[int_y_train == weighted_class] = 4. model.fit( x_train, y_train, batch_size=batch_size, - epochs=epochs // 3, + epochs=2, verbose=0, sample_weight=sample_weight) model.fit( x_train, y_train, batch_size=batch_size, - epochs=epochs // 3, + epochs=2, verbose=0, sample_weight=sample_weight, validation_split=0.1) @@ -536,215 +532,6 @@ class LossWeightingTest(test.TestCase): model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) -class TestDynamicTrainability(test.TestCase): - - def test_trainable_warning(self): - x = np.random.random((5, 3)) - y = np.random.random((5, 2)) - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_dim=3)) - model.trainable = False - model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') - model.trainable = True - with test.mock.patch.object(logging, 'warning') as mock_log: - model.train_on_batch(x, y) - self.assertRegexpMatches(str(mock_log.call_args), - 'trainable weights is empty') - - def test_trainable_argument(self): - x = np.random.random((5, 3)) - y = np.random.random((5, 2)) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_dim=3, trainable=False)) - model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') - out = model.predict(x) - with test.mock.patch.object(logging, 'warning') as mock_log: - model.train_on_batch(x, y) - self.assertRegexpMatches(str(mock_log.call_args), - 'trainable weights is empty') - out_2 = model.predict(x) - self.assertAllClose(out, out_2) - - # test with nesting - inputs = keras.layers.Input(shape=(3,)) - output = model(inputs) - model = keras.models.Model(inputs, output) - model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') - out = model.predict(x) - with test.mock.patch.object(logging, 'warning') as mock_log: - model.train_on_batch(x, y) - self.assertRegexpMatches(str(mock_log.call_args), - 'trainable weights is empty') - out_2 = model.predict(x) - self.assertAllClose(out, out_2) - - def test_layer_trainability_switch(self): - # with constructor argument, in Sequential - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, trainable=False, input_dim=1)) - self.assertListEqual(model.trainable_weights, []) - - # by setting the `trainable` argument, in Sequential - model = keras.models.Sequential() - layer = keras.layers.Dense(2, input_dim=1) - model.add(layer) - self.assertListEqual(model.trainable_weights, layer.trainable_weights) - layer.trainable = False - self.assertListEqual(model.trainable_weights, []) - - # with constructor argument, in Model - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2, trainable=False)(x) - model = keras.models.Model(x, y) - self.assertListEqual(model.trainable_weights, []) - - # by setting the `trainable` argument, in Model - x = keras.layers.Input(shape=(1,)) - layer = keras.layers.Dense(2) - y = layer(x) - model = keras.models.Model(x, y) - self.assertListEqual(model.trainable_weights, layer.trainable_weights) - layer.trainable = False - self.assertListEqual(model.trainable_weights, []) - - def test_model_trainability_switch(self): - # a non-trainable model has no trainable weights - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2)(x) - model = keras.models.Model(x, y) - model.trainable = False - self.assertListEqual(model.trainable_weights, []) - - # same for Sequential - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_dim=1)) - model.trainable = False - self.assertListEqual(model.trainable_weights, []) - - def test_nested_model_trainability(self): - - # a Sequential inside a Model - inner_model = keras.models.Sequential() - inner_model.add(keras.layers.Dense(2, input_dim=1)) - - x = keras.layers.Input(shape=(1,)) - y = inner_model(x) - outer_model = keras.models.Model(x, y) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - # a Sequential inside a Sequential - inner_model = keras.models.Sequential() - inner_model.add(keras.layers.Dense(2, input_dim=1)) - outer_model = keras.models.Sequential() - outer_model.add(inner_model) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - # a Model inside a Model - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2)(x) - inner_model = keras.models.Model(x, y) - x = keras.layers.Input(shape=(1,)) - y = inner_model(x) - outer_model = keras.models.Model(x, y) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - # a Model inside a Sequential - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2)(x) - inner_model = keras.models.Model(x, y) - outer_model = keras.models.Sequential() - outer_model.add(inner_model) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - -class TestTrainingUtils(test.TestCase): - - def test_check_array_lengths(self): - keras.engine.training._check_array_lengths(None, None, None) - a_np = np.random.random((4, 3, 3)) - keras.engine.training._check_array_lengths(a_np, a_np, a_np) - keras.engine.training._check_array_lengths( - [a_np, a_np], [a_np, a_np], [a_np, a_np]) - keras.engine.training._check_array_lengths([None], [None], [None]) - - b_np = np.random.random((3, 4)) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, None, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, a_np, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [None], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [b_np], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], None, [b_np]) - - def test_slice_arrays(self): - input_a = np.random.random((10, 3)) - slice_arrays(None) - slice_arrays(input_a, 0) - slice_arrays(input_a, 0, 1) - slice_arrays(input_a, stop=2) - input_a = [None, [1, 1], None, [1, 1]] - slice_arrays(input_a, 0) - slice_arrays(input_a, 0, 1) - slice_arrays(input_a, stop=2) - input_a = [None] - slice_arrays(input_a, 0) - slice_arrays(input_a, 0, 1) - slice_arrays(input_a, stop=2) - input_a = None - slice_arrays(input_a, 0) - slice_arrays(input_a, 0, 1) - slice_arrays(input_a, stop=2) - - def test_fit_with_BatchNorm(self): - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_dim=4)) - model.add(keras.layers.BatchNormalization()) - model.add(keras.layers.Activation('tanh')) - model.add(keras.layers.Dropout(0.2)) - - input_a_np = np.random.random((10, 4)) - output_b_np = np.random.random((10, 10)) - - model.compile(loss='binary_crossentropy', optimizer=RMSPropOptimizer(0.001)) - model.fit(input_a_np, output_b_np, epochs=1, batch_size=5, verbose=0) - - def test_fit_with_regularization(self): - model = keras.models.Sequential() - with self.assertRaises(ValueError): - model.add( - keras.layers.Dense(4, input_dim=3, - kernel_regularizer=keras.regularizers.l2(0.01), - activity_regularizer=keras.regularizers.l1(0.01))) - - if __name__ == '__main__': # Bazel sets these environment variables to very long paths. # Tempfile uses them to create long paths, and in turn multiprocessing diff --git a/tensorflow/python/keras/_impl/keras/engine/training_generator.py b/tensorflow/python/keras/_impl/keras/engine/training_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..58b5bc39c10ea06f680eb030e14ecd19a3888588 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_generator.py @@ -0,0 +1,436 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Part of the Keras training engine related to Python generators of array data. +""" +# pylint: disable=protected-access +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import callbacks as cbks +from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.platform import tf_logging as logging + + +def fit_generator(model, + generator, + steps_per_epoch=None, + epochs=1, + verbose=1, + callbacks=None, + validation_data=None, + validation_steps=None, + class_weight=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + shuffle=True, + initial_epoch=0): + """See docstring for `Model.fit_generator`.""" + wait_time = 0.01 # in seconds + epoch = initial_epoch + + do_validation = bool(validation_data) + model._make_train_function() + if do_validation: + model._make_test_function() + + is_sequence = isinstance(generator, Sequence) + if not is_sequence and use_multiprocessing and workers > 1: + logging.warning( + UserWarning('Using a generator with `use_multiprocessing=True`' + ' and multiple workers may duplicate your data.' + ' Please consider using the`keras.utils.Sequence' + ' class.')) + if steps_per_epoch is None: + if is_sequence: + steps_per_epoch = len(generator) + else: + raise ValueError('`steps_per_epoch=None` is only valid for a' + ' generator based on the `keras.utils.Sequence`' + ' class. Please specify `steps_per_epoch` or use' + ' the `keras.utils.Sequence` class.') + + # python 2 has 'next', 3 has '__next__' + # avoid any explicit version checks + val_gen = ( + hasattr(validation_data, 'next') or + hasattr(validation_data, '__next__') or + isinstance(validation_data, Sequence)) + if (val_gen and not isinstance(validation_data, Sequence) and + not validation_steps): + raise ValueError('`validation_steps=None` is only valid for a' + ' generator based on the `keras.utils.Sequence`' + ' class. Please specify `validation_steps` or use' + ' the `keras.utils.Sequence` class.') + + # Prepare display labels. + out_labels = model.metrics_names + callback_metrics = out_labels + ['val_%s' % n for n in out_labels] + + # prepare callbacks + model.history = cbks.History() + callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] + if verbose: + callbacks += [cbks.ProgbarLogger(count_mode='steps')] + callbacks = cbks.CallbackList(callbacks) + + # it's possible to callback a different model than self: + if hasattr(model, 'callback_model') and model.callback_model: + callback_model = model.callback_model + else: + callback_model = model + callbacks.set_model(callback_model) + callbacks.set_params({ + 'epochs': epochs, + 'steps': steps_per_epoch, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics, + }) + callbacks.on_train_begin() + + enqueuer = None + val_enqueuer = None + + try: + if do_validation and not val_gen: + # Prepare data for validation + if len(validation_data) == 2: + val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence + val_sample_weight = None + elif len(validation_data) == 3: + val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence + else: + raise ValueError( + '`validation_data` should be a tuple ' + '`(val_x, val_y, val_sample_weight)` ' + 'or `(val_x, val_y)`. Found: ' + str(validation_data)) + val_x, val_y, val_sample_weights = model._standardize_user_data( + val_x, val_y, val_sample_weight) + val_data = val_x + val_y + val_sample_weights + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + val_data += [0.] + for cbk in callbacks: + cbk.validation_data = val_data + + if workers > 0: + if is_sequence: + enqueuer = OrderedEnqueuer( + generator, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle) + else: + enqueuer = GeneratorEnqueuer( + generator, + use_multiprocessing=use_multiprocessing, + wait_time=wait_time) + enqueuer.start(workers=workers, max_queue_size=max_queue_size) + output_generator = enqueuer.get() + else: + if is_sequence: + output_generator = iter(generator) + else: + output_generator = generator + + callback_model.stop_training = False + # Construct epoch logs. + epoch_logs = {} + while epoch < epochs: + callbacks.on_epoch_begin(epoch) + steps_done = 0 + batch_index = 0 + while steps_done < steps_per_epoch: + generator_output = next(output_generator) + + if not hasattr(generator_output, '__len__'): + raise ValueError('Output of generator should be ' + 'a tuple `(x, y, sample_weight)` ' + 'or `(x, y)`. Found: ' + str(generator_output)) + + if len(generator_output) == 2: + x, y = generator_output + sample_weight = None + elif len(generator_output) == 3: + x, y, sample_weight = generator_output + else: + raise ValueError('Output of generator should be ' + 'a tuple `(x, y, sample_weight)` ' + 'or `(x, y)`. Found: ' + str(generator_output)) + # build batch logs + batch_logs = {} + if isinstance(x, list): + batch_size = x[0].shape[0] + elif isinstance(x, dict): + batch_size = list(x.values())[0].shape[0] + else: + batch_size = x.shape[0] + batch_logs['batch'] = batch_index + batch_logs['size'] = batch_size + callbacks.on_batch_begin(batch_index, batch_logs) + + outs = model.train_on_batch( + x, y, sample_weight=sample_weight, class_weight=class_weight) + + if not isinstance(outs, list): + outs = [outs] + for l, o in zip(out_labels, outs): + batch_logs[l] = o + + callbacks.on_batch_end(batch_index, batch_logs) + + batch_index += 1 + steps_done += 1 + + # Epoch finished. + if steps_done >= steps_per_epoch and do_validation: + if val_gen: + val_outs = evaluate_generator( + model, + validation_data, + validation_steps, + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=max_queue_size) + else: + # No need for try/except because + # data has already been validated. + val_outs = model.evaluate( + val_x, + val_y, + batch_size=batch_size, + sample_weight=val_sample_weights, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + + if callback_model.stop_training: + break + + callbacks.on_epoch_end(epoch, epoch_logs) + epoch += 1 + if callback_model.stop_training: + break + + finally: + try: + if enqueuer is not None: + enqueuer.stop() + finally: + if val_enqueuer is not None: + val_enqueuer.stop() + + callbacks.on_train_end() + return model.history + + +def evaluate_generator(model, + generator, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False): + """See docstring for `Model.evaluate_generator`.""" + model._make_test_function() + + steps_done = 0 + wait_time = 0.01 + all_outs = [] + batch_sizes = [] + is_sequence = isinstance(generator, Sequence) + if not is_sequence and use_multiprocessing and workers > 1: + logging.warning( + UserWarning('Using a generator with `use_multiprocessing=True`' + ' and multiple workers may duplicate your data.' + ' Please consider using the`keras.utils.Sequence' + ' class.')) + if steps is None: + if is_sequence: + steps = len(generator) + else: + raise ValueError('`steps=None` is only valid for a generator' + ' based on the `keras.utils.Sequence` class.' + ' Please specify `steps` or use the' + ' `keras.utils.Sequence` class.') + enqueuer = None + + try: + if workers > 0: + if is_sequence: + enqueuer = OrderedEnqueuer( + generator, use_multiprocessing=use_multiprocessing) + else: + enqueuer = GeneratorEnqueuer( + generator, + use_multiprocessing=use_multiprocessing, + wait_time=wait_time) + enqueuer.start(workers=workers, max_queue_size=max_queue_size) + output_generator = enqueuer.get() + else: + if is_sequence: + output_generator = iter(generator) + else: + output_generator = generator + + while steps_done < steps: + generator_output = next(output_generator) + if not hasattr(generator_output, '__len__'): + raise ValueError('Output of generator should be a tuple ' + '(x, y, sample_weight) ' + 'or (x, y). Found: ' + str(generator_output)) + if len(generator_output) == 2: + x, y = generator_output + sample_weight = None + elif len(generator_output) == 3: + x, y, sample_weight = generator_output + else: + raise ValueError('Output of generator should be a tuple ' + '(x, y, sample_weight) ' + 'or (x, y). Found: ' + str(generator_output)) + outs = model.test_on_batch(x, y, sample_weight=sample_weight) + + if isinstance(x, list): + batch_size = x[0].shape[0] + elif isinstance(x, dict): + batch_size = list(x.values())[0].shape[0] + else: + batch_size = x.shape[0] + if batch_size == 0: + raise ValueError('Received an empty batch. ' + 'Batches should at least contain one item.') + all_outs.append(outs) + + steps_done += 1 + batch_sizes.append(batch_size) + + finally: + if enqueuer is not None: + enqueuer.stop() + + if not isinstance(outs, list): + return np.average(np.asarray(all_outs), weights=batch_sizes) + else: + averages = [] + for i in range(len(outs)): + averages.append( + np.average([out[i] for out in all_outs], weights=batch_sizes)) + return averages + + +def predict_generator(model, + generator, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + verbose=0): + """See docstring for `Model.predict_generator`.""" + model._make_predict_function() + + steps_done = 0 + wait_time = 0.01 + all_outs = [] + is_sequence = isinstance(generator, Sequence) + if not is_sequence and use_multiprocessing and workers > 1: + logging.warning( + UserWarning('Using a generator with `use_multiprocessing=True`' + ' and multiple workers may duplicate your data.' + ' Please consider using the`keras.utils.Sequence' + ' class.')) + if steps is None: + if is_sequence: + steps = len(generator) + else: + raise ValueError('`steps=None` is only valid for a generator' + ' based on the `keras.utils.Sequence` class.' + ' Please specify `steps` or use the' + ' `keras.utils.Sequence` class.') + enqueuer = None + + try: + if workers > 0: + if is_sequence: + enqueuer = OrderedEnqueuer( + generator, use_multiprocessing=use_multiprocessing) + else: + enqueuer = GeneratorEnqueuer( + generator, + use_multiprocessing=use_multiprocessing, + wait_time=wait_time) + enqueuer.start(workers=workers, max_queue_size=max_queue_size) + output_generator = enqueuer.get() + else: + if is_sequence: + output_generator = iter(generator) + else: + output_generator = generator + + if verbose == 1: + progbar = Progbar(target=steps) + + while steps_done < steps: + generator_output = next(output_generator) + if isinstance(generator_output, tuple): + # Compatibility with the generators + # used for training. + if len(generator_output) == 2: + x, _ = generator_output + elif len(generator_output) == 3: + x, _, _ = generator_output + else: + raise ValueError('Output of generator should be ' + 'a tuple `(x, y, sample_weight)` ' + 'or `(x, y)`. Found: ' + str(generator_output)) + else: + # Assumes a generator that only + # yields inputs (not targets and sample weights). + x = generator_output + + outs = model.predict_on_batch(x) + if not isinstance(outs, list): + outs = [outs] + + if not all_outs: + for out in outs: + all_outs.append([]) + + for i, out in enumerate(outs): + all_outs[i].append(out) + steps_done += 1 + if verbose == 1: + progbar.update(steps_done) + + finally: + if enqueuer is not None: + enqueuer.stop() + + if len(all_outs) == 1: + if steps_done == 1: + return all_outs[0][0] + else: + return np.concatenate(all_outs[0]) + if steps_done == 1: + return [out[0] for out in all_outs] + else: + return [np.concatenate(out) for out in all_outs] diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index 9651eb9f14f1275dc79c8d3b1fb54690772086a1..fd91dbba52ff7d152335514085ef3b057ae5eec4 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -25,7 +25,7 @@ import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.keras._impl.keras.engine.training import _weighted_masked_objective +from tensorflow.python.keras._impl.keras.engine.training_utils import weighted_masked_objective from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays from tensorflow.python.platform import test @@ -340,20 +340,21 @@ class TrainingTest(test.TestCase): if scipy_sparse is None: return - test_inputs = [ - scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)] - test_outputs = [ - scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)] - in1 = keras.layers.Input(shape=(3,)) - in2 = keras.layers.Input(shape=(3,)) - out1 = keras.layers.Dropout(0.5, name='dropout')(in1) - out2 = keras.layers.Dense(4, name='dense_1')(in2) - model = keras.Model([in1, in2], [out1, out2]) - model.predict(test_inputs, batch_size=2) - model.compile('rmsprop', 'mse') - model.fit(test_inputs, test_outputs, - epochs=1, batch_size=2, validation_split=0.5) - model.evaluate(test_inputs, test_outputs, batch_size=2) + with self.test_session(): + test_inputs = [ + scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)] + test_outputs = [ + scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)] + in1 = keras.layers.Input(shape=(3,)) + in2 = keras.layers.Input(shape=(3,)) + out1 = keras.layers.Dropout(0.5, name='dropout')(in1) + out2 = keras.layers.Dense(4, name='dense_1')(in2) + model = keras.Model([in1, in2], [out1, out2]) + model.predict(test_inputs, batch_size=2) + model.compile('rmsprop', 'mse') + model.fit(test_inputs, test_outputs, + epochs=1, batch_size=2, validation_split=0.5) + model.evaluate(test_inputs, test_outputs, batch_size=2) def test_that_trainable_disables_updates(self): val_a = np.random.random((10, 4)) @@ -705,7 +706,7 @@ class LossMaskingTest(test.TestCase): def test_loss_masking(self): with self.test_session(): - weighted_loss = _weighted_masked_objective(keras.losses.get('mae')) + weighted_loss = weighted_masked_objective(keras.losses.get('mae')) shape = (3, 4, 2) x = np.arange(24).reshape(shape) y = 2 * x @@ -876,9 +877,9 @@ class TestGeneratorMethods(test.TestCase): def custom_generator(): batch_size = 10 - n_samples = 50 + num_samples = 50 while True: - batch_index = np.random.randint(0, n_samples - batch_size) + batch_index = np.random.randint(0, num_samples - batch_size) start = batch_index end = start + batch_size x = arr_data[start: end] @@ -957,9 +958,9 @@ class TestGeneratorMethods(test.TestCase): def custom_generator(): batch_size = 10 - n_samples = 50 + num_samples = 50 while True: - batch_index = np.random.randint(0, n_samples - batch_size) + batch_index = np.random.randint(0, num_samples - batch_size) start = batch_index end = start + batch_size x = arr_data[start: end] @@ -1033,28 +1034,66 @@ class TestGeneratorMethods(test.TestCase): max_queue_size=10, use_multiprocessing=False) + def test_training_with_sequences(self): + + class DummySequence(keras.utils.Sequence): + + def __getitem__(self, idx): + return np.zeros([10, 2]), np.ones([10]) + + def __len__(self): + return 10 + + arr_data = np.random.random((50, 2)) + arr_labels = np.random.random((50,)) + arr_sample_weights = np.random.random((50,)) + + def custom_generator(): + batch_size = 10 + num_samples = 50 + while True: + batch_index = np.random.randint(0, num_samples - batch_size) + start = batch_index + end = start + batch_size + x = arr_data[start: end] + y = arr_labels[start: end] + w = arr_sample_weights[start: end] + yield x, y, w + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(2,))) + model.compile(loss='mse', optimizer='sgd') + + model.fit_generator(DummySequence(), + steps_per_epoch=10, + validation_data=custom_generator(), + validation_steps=1, + max_queue_size=10, + workers=0, + use_multiprocessing=True) + model.fit_generator(DummySequence(), + steps_per_epoch=10, + validation_data=custom_generator(), + validation_steps=1, + max_queue_size=10, + workers=0, + use_multiprocessing=False) + class TestTrainingUtils(test.TestCase): def test_check_array_lengths(self): - keras.engine.training._check_array_lengths(None, None, None) + keras.engine.training_utils.check_array_lengths(None, None, None) a_np = np.random.random((4, 3, 3)) - keras.engine.training._check_array_lengths(a_np, a_np, a_np) - keras.engine.training._check_array_lengths( + keras.engine.training_utils.check_array_lengths(a_np, a_np, a_np) + keras.engine.training_utils.check_array_lengths( [a_np, a_np], [a_np, a_np], [a_np, a_np]) - keras.engine.training._check_array_lengths([None], [None], [None]) + keras.engine.training_utils.check_array_lengths([None], [None], [None]) b_np = np.random.random((3, 4)) with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, None, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, a_np, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [None], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [b_np], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], None, [b_np]) + keras.engine.training_utils.check_array_lengths([a_np], [b_np], None) def test_slice_arrays(self): input_a = np.random.random((10, 3)) diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..105638ce1087e8668b49b6653a847667e8f9157d --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py @@ -0,0 +1,534 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Training-related utilities. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import numpy as np + +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import losses + + +def check_num_samples(ins, + batch_size=None, + steps=None, + steps_name='steps'): + """Determine the number of samples provided for training and evaluation. + + The number of samples is not defined when running with `steps`, + in which case the number of samples is set to `None`. + + Arguments: + ins: List of tensors to be fed to the Keras function. + batch_size: Integer batch size or `None` if not defined. + steps: Total number of steps (batches of samples) + before declaring `_predict_loop` finished. + Ignored with the default value of `None`. + steps_name: The public API's parameter name for `steps`. + + Raises: + ValueError: when `steps` is `None` and the attribute `ins.shape` + does not exist. Also raises ValueError when `steps` is not `None` + and `batch_size` is not `None` because they are mutually + exclusive. + + Returns: + When steps is `None`, returns the number of samples to be + processed based on the size of the first dimension of the + first input numpy array. When steps is not `None` and + `batch_size` is `None`, returns `None`. + + Raises: + ValueError: In case of invalid arguments. + """ + if steps is not None: + num_samples = None + if batch_size is not None: + raise ValueError( + 'If ' + steps_name + ' is set, the `batch_size` must be None.') + elif ins and hasattr(ins[0], 'shape'): + num_samples = ins[0].shape[0] + else: + raise ValueError( + 'Either the input data should have ' + 'a defined shape, or ' + steps_name + ' should be specified.') + return num_samples + + +def standardize_input_data(data, + names, + shapes=None, + check_batch_axis=True, + exception_prefix=''): + """Normalizes inputs and targets provided by users. + + Users may pass data as a list of arrays, dictionary of arrays, + or as a single array. We normalize this to an ordered list of + arrays (same order as `names`), while checking that the provided + arrays have shapes that match the network's expectations. + + Arguments: + data: User-provided input data (polymorphic). + names: List of expected array names. + shapes: Optional list of expected array shapes. + check_batch_axis: Boolean; whether to check that + the batch axis of the arrays matches the expected + value found in `shapes`. + exception_prefix: String prefix used for exception formatting. + + Returns: + List of standardized input arrays (one array per model input). + + Raises: + ValueError: in case of improperly formatted user-provided data. + """ + if not names: + if data is not None and hasattr(data, '__len__') and len(data): + raise ValueError('Error when checking model ' + exception_prefix + ': ' + 'expected no data, but got:', data) + return [] + if data is None: + return [None for _ in range(len(names))] + + if isinstance(data, dict): + try: + data = [ + data[x].values + if data[x].__class__.__name__ == 'DataFrame' else data[x] + for x in names + ] + except KeyError as e: + raise ValueError('No data provided for "' + e.args[0] + '". Need data ' + 'for each key in: ' + str(names)) + elif isinstance(data, list): + if isinstance(data[0], list): + data = [np.asarray(d) for d in data] + elif len(names) == 1 and isinstance(data[0], (float, int)): + data = [np.asarray(data)] + else: + data = [ + x.values if x.__class__.__name__ == 'DataFrame' else x for x in data + ] + else: + data = data.values if data.__class__.__name__ == 'DataFrame' else data + data = [data] + data = [ + np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data + ] + + if len(data) != len(names): + if data and hasattr(data[0], 'shape'): + raise ValueError('Error when checking model ' + exception_prefix + + ': the list of Numpy arrays that you are passing to ' + 'your model is not the size the model expected. ' + 'Expected to see ' + str(len(names)) + ' array(s), ' + 'but instead got the following list of ' + + str(len(data)) + ' arrays: ' + str(data)[:200] + '...') + elif len(names) > 1: + raise ValueError( + 'Error when checking model ' + exception_prefix + + ': you are passing a list as input to your model, ' + 'but the model expects a list of ' + str(len(names)) + + ' Numpy arrays instead. The list you passed was: ' + str(data)[:200]) + elif len(data) == 1 and not hasattr(data[0], 'shape'): + raise TypeError('Error when checking model ' + exception_prefix + + ': data should be a Numpy array, or list/dict of ' + 'Numpy arrays. Found: ' + str(data)[:200] + '...') + elif len(names) == 1: + data = [np.asarray(data)] + + # Check shapes compatibility. + if shapes: + for i in range(len(names)): + if shapes[i] is not None: + data_shape = data[i].shape + shape = shapes[i] + if data[i].ndim != len(shape): + raise ValueError('Error when checking ' + exception_prefix + + ': expected ' + names[i] + ' to have ' + + str(len(shape)) + ' dimensions, but got array ' + 'with shape ' + str(data_shape)) + if not check_batch_axis: + data_shape = data_shape[1:] + shape = shape[1:] + for dim, ref_dim in zip(data_shape, shape): + if ref_dim != dim and ref_dim: + raise ValueError( + 'Error when checking ' + exception_prefix + ': expected ' + + names[i] + ' to have shape ' + str(shape) + + ' but got array with shape ' + str(data_shape)) + return data + + +def standardize_sample_or_class_weights(x_weight, output_names, weight_type): + """Maps `sample_weight` or `class_weight` to model outputs. + + Arguments: + x_weight: User-provided `sample_weight` or `class_weight` argument. + output_names: List of output names (strings) in the model. + weight_type: A string used purely for exception printing. + + Returns: + A list of `sample_weight` or `class_weight` where there are exactly + one element per model output. + + Raises: + ValueError: In case of invalid user-provided argument. + """ + if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test + return [None for _ in output_names] + if len(output_names) == 1: + if isinstance(x_weight, list) and len(x_weight) == 1: + return x_weight + if isinstance(x_weight, dict) and output_names[0] in x_weight: + return [x_weight[output_names[0]]] + else: + return [x_weight] + if isinstance(x_weight, list): + if len(x_weight) != len(output_names): + raise ValueError('Provided `' + weight_type + '` was a list of ' + + str(len(x_weight)) + ' elements, but the model has ' + + str(len(output_names)) + ' outputs. ' + 'You should provide one `' + weight_type + '`' + 'array per model output.') + return x_weight + if isinstance(x_weight, dict): + x_weights = [] + for name in output_names: + x_weights.append(x_weight.get(name)) + return x_weights + else: + raise TypeError( + 'The model has multiple outputs, so `' + weight_type + '` ' + 'should be either a list or a dict. ' + 'Provided `' + weight_type + '` type not understood: ' + str(x_weight)) + + +def standardize_class_weights(class_weight, output_names): + return standardize_sample_or_class_weights(class_weight, output_names, + 'class_weight') + + +def standardize_sample_weights(sample_weight, output_names): + return standardize_sample_or_class_weights(sample_weight, output_names, + 'sample_weight') + + +def check_array_lengths(inputs, targets, weights=None): + """Does user input validation for numpy arrays. + + Arguments: + inputs: list of Numpy arrays of inputs. + targets: list of Numpy arrays of targets. + weights: list of Numpy arrays of sample weights. + + Raises: + ValueError: in case of incorrectly formatted data. + """ + + def set_of_lengths(x): + # return a set with the variation between + # different shapes, with None => 0 + if x is None: + return {} + else: + return set([y.shape[0] for y in x if y is not None]) + + set_x = set_of_lengths(inputs) + set_y = set_of_lengths(targets) + set_w = set_of_lengths(weights) + if len(set_x) > 1: + raise ValueError('All input arrays (x) should have ' + 'the same number of samples. Got array shapes: ' + + str([x.shape for x in inputs])) + if len(set_y) > 1: + raise ValueError('All target arrays (y) should have ' + 'the same number of samples. Got array shapes: ' + + str([y.shape for y in targets])) + if set_x and set_y and list(set_x)[0] != list(set_y)[0]: + raise ValueError('Input arrays should have ' + 'the same number of samples as target arrays. ' + 'Found ' + str(list(set_x)[0]) + ' input samples ' + 'and ' + str(list(set_y)[0]) + ' target samples.') + if len(set_w) > 1: + raise ValueError('All sample_weight arrays should have ' + 'the same number of samples. Got array shapes: ' + + str([w.shape for w in weights])) + if set_y and set_w and list(set_y)[0] != list(set_w)[0]: + raise ValueError('Sample_weight arrays should have ' + 'the same number of samples as target arrays. Got ' + + str(list(set_y)[0]) + ' input samples and ' + + str(list(set_w)[0]) + ' target samples.') + + +def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): + """Does validation on the compatibility of targets and loss functions. + + This helps prevent users from using loss functions incorrectly. This check + is purely for UX purposes. + + Arguments: + targets: list of Numpy arrays of targets. + loss_fns: list of loss functions. + output_shapes: list of shapes of model outputs. + + Raises: + ValueError: if a loss function or target array + is incompatible with an output. + """ + key_losses = { + losses.mean_squared_error, losses.binary_crossentropy, + losses.categorical_crossentropy + } + for y, loss, shape in zip(targets, loss_fns, output_shapes): + if y is None or loss is None or tensor_util.is_tensor(y): + continue + if loss is losses.categorical_crossentropy: + if y.shape[-1] == 1: + raise ValueError('You are passing a target array of shape ' + str( + y.shape) + ' while using as loss `categorical_crossentropy`. ' + '`categorical_crossentropy` expects ' + 'targets to be binary matrices (1s and 0s) ' + 'of shape (samples, classes). ' + 'If your targets are integer classes, ' + 'you can convert them to the expected format via:\n' + '```\n' + 'from keras.utils import to_categorical\n' + 'y_binary = to_categorical(y_int)\n' + '```\n' + '\n' + 'Alternatively, you can use the loss function ' + '`sparse_categorical_crossentropy` instead, ' + 'which does expect integer targets.') + if loss in key_losses: + for target_dim, out_dim in zip(y.shape[1:], shape[1:]): + if out_dim is not None and target_dim != out_dim: + raise ValueError('A target array with shape ' + str(y.shape) + + ' was passed for an output of shape ' + str(shape) + + ' while using as loss `' + loss.__name__ + '`. ' + 'This loss expects ' + 'targets to have the same shape ' + 'as the output.') + + +def collect_metrics(metrics, output_names): + """Maps metric functions to model outputs. + + Arguments: + metrics: a list or dict of metric functions. + output_names: a list of the names (strings) of model outputs. + + Returns: + A list (one entry per model output) of lists of metric functions. + For instance, if the model has 2 outputs, and for the first output + we want to compute "binary_accuracy" and "binary_crossentropy", + and just "binary_accuracy" for the second output, + the list would look like: + `[[binary_accuracy, binary_crossentropy], [binary_accuracy]]` + + Raises: + TypeError: if an incorrect type is passed for the `metrics` argument. + """ + if not metrics: + return [[] for _ in output_names] + if isinstance(metrics, list): + # we then apply all metrics to all outputs. + return [copy.copy(metrics) for _ in output_names] + elif isinstance(metrics, dict): + nested_metrics = [] + for name in output_names: + output_metrics = metrics.get(name, []) + if not isinstance(output_metrics, list): + output_metrics = [output_metrics] + nested_metrics.append(output_metrics) + return nested_metrics + else: + raise TypeError('Type of `metrics` argument not understood. ' + 'Expected a list or dictionary, found: ' + str(metrics)) + + +def batch_shuffle(index_array, batch_size): + """Shuffles an array in a batch-wise fashion. + + Useful for shuffling HDF5 arrays + (where one cannot access arbitrary indices). + + Arguments: + index_array: array of indices to be shuffled. + batch_size: integer. + + Returns: + The `index_array` array, shuffled in a batch-wise fashion. + """ + batch_count = int(len(index_array) / batch_size) + # to reshape we need to be cleanly divisible by batch size + # we stash extra items and reappend them after shuffling + last_batch = index_array[batch_count * batch_size:] + index_array = index_array[:batch_count * batch_size] + index_array = index_array.reshape((batch_count, batch_size)) + np.random.shuffle(index_array) + index_array = index_array.flatten() + return np.append(index_array, last_batch) + + +def weighted_masked_objective(fn): + """Adds support for masking and sample-weighting to an objective function. + + It transforms an objective function `fn(y_true, y_pred)` + into a sample-weighted, cost-masked objective function + `fn(y_true, y_pred, weights, mask)`. + + Arguments: + fn: The objective function to wrap, + with signature `fn(y_true, y_pred)`. + + Returns: + A function with signature `fn(y_true, y_pred, weights, mask)`. + """ + if fn is None: + return None + + def weighted(y_true, y_pred, weights, mask=None): + """Wrapper function. + + Arguments: + y_true: `y_true` argument of `fn`. + y_pred: `y_pred` argument of `fn`. + weights: Weights tensor. + mask: Mask tensor. + + Returns: + Scalar tensor. + """ + # score_array has ndim >= 2 + score_array = fn(y_true, y_pred) + if mask is not None: + # Cast the mask to floatX to avoid float64 upcasting in theano + mask = K.cast(mask, K.floatx()) + # mask should have the same shape as score_array + score_array *= mask + # the loss per batch should be proportional + # to the number of unmasked samples. + score_array /= K.mean(mask) + + # apply sample weighting + if weights is not None: + # reduce score_array to same ndim as weight array + ndim = K.ndim(score_array) + weight_ndim = K.ndim(weights) + score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim))) + score_array *= weights + score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx())) + return K.mean(score_array) + + return weighted + + +def standardize_weights(y, + sample_weight=None, + class_weight=None, + sample_weight_mode=None): + """Performs sample weight validation and standardization. + + Everything gets normalized to a single sample-wise (or timestep-wise) + weight array. + + Arguments: + y: Numpy array of model targets to be weighted. + sample_weight: User-provided `sample_weight` argument. + class_weight: User-provided `class_weight` argument. + sample_weight_mode: One of `None` or `"temporal"`. + `"temporal"` indicated that we expect 2D weight data + that will be applied to the last 2 dimensions of + the targets (i.e. we are weighting timesteps, not samples). + + Returns: + A numpy array of target weights, one entry per sample to weight. + + Raises: + ValueError: In case of invalid user-provided arguments. + """ + if sample_weight_mode is not None: + if sample_weight_mode != 'temporal': + raise ValueError('"sample_weight_mode ' + 'should be None or "temporal". ' + 'Found: ' + str(sample_weight_mode)) + if len(y.shape) < 3: + raise ValueError('Found a sample_weight array for ' + 'an input with shape ' + str(y.shape) + '. ' + 'Timestep-wise sample weighting (use of ' + 'sample_weight_mode="temporal") is restricted to ' + 'outputs that are at least 3D, i.e. that have ' + 'a time dimension.') + if sample_weight is not None and len(sample_weight.shape) != 2: + raise ValueError('Found a sample_weight array with shape ' + + str(sample_weight.shape) + '. ' + 'In order to use timestep-wise sample weighting, ' + 'you should pass a 2D sample_weight array.') + else: + if sample_weight is not None and len(sample_weight.shape) != 1: + raise ValueError('Found a sample_weight array with shape ' + + str(sample_weight.shape) + '. ' + 'In order to use timestep-wise sample weights, ' + 'you should specify ' + 'sample_weight_mode="temporal" ' + 'in compile(). If you just mean to use ' + 'sample-wise weights, make sure your ' + 'sample_weight array is 1D.') + + if sample_weight is not None: + if len(sample_weight.shape) > len(y.shape): + raise ValueError( + 'Found a sample_weight with shape' + str(sample_weight.shape) + '.' + 'Expected sample_weight with rank ' + 'less than or equal to ' + str(len(y.shape))) + + if y.shape[:sample_weight.ndim] != sample_weight.shape: + raise ValueError( + 'Found a sample_weight array with shape ' + str(sample_weight.shape) + + ' for an input with shape ' + str(y.shape) + '. ' + 'sample_weight cannot be broadcast.') + return sample_weight + elif isinstance(class_weight, dict): + if len(y.shape) > 2: + raise ValueError('`class_weight` not supported for ' + '3+ dimensional targets.') + if y.shape[1] > 1: + y_classes = np.argmax(y, axis=1) + elif y.shape[1] == 1: + y_classes = np.reshape(y, y.shape[0]) + else: + y_classes = y + + weights = np.asarray( + [class_weight[cls] for cls in y_classes if cls in class_weight]) + + if len(weights) != len(y_classes): + # subtract the sets to pick all missing classes + existing_classes = set(y_classes) + existing_class_weight = set(class_weight.keys()) + raise ValueError('`class_weight` must contain all classes in the data.' + ' The classes %s exist in the data but not in ' + '`class_weight`.' % + (existing_classes - existing_class_weight)) + return weights + else: + return None diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index 0bf5bd41dc915fbecbce4c3a6191e925612dbebb..8426d84df964092435b10c9e28e1843df7e423f4 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -25,11 +25,15 @@ from tensorflow.python.client import session from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import export as export_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import models +from tensorflow.python.keras._impl.keras import optimizers +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.network import Network from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_module @@ -50,36 +54,174 @@ def _cast_tensor_to_floatx(x): return math_ops.cast(x, K.floatx()) -def _create_ordered_io(keras_model, estimator_io_dict, is_input=True): +def _create_ordered_io(keras_model, estimator_io, is_input=True): """Create a list of tensors from IO dictionary based on Keras IO order. Args: - keras_model: an instance of compiled keras model. - estimator_io_dict: features or labels dictionary from model_fn. + keras_model: An instance of compiled keras model. + estimator_io: The features or labels (dict or plain array) from model_fn. is_input: True if dictionary is for inputs. Returns: - a list of tensors based on Keras IO order. + A list of tensors based on Keras IO order. Raises: ValueError: if dictionary keys cannot be found in Keras model input_names or output_names. """ - if is_input: - keras_io_names = keras_model.input_names + if isinstance(estimator_io, (list, tuple)): + # Case currently not supported by most built-in input_fn, + # but it's good to have for sanity + return [_cast_tensor_to_floatx(x) for x in estimator_io] + elif isinstance(estimator_io, dict): + if is_input: + if keras_model._is_graph_network: + keras_io_names = keras_model.input_names + else: + keras_io_names = [ + 'input_%d' % i for i in range(1, len(estimator_io) + 1)] + else: + if keras_model._is_graph_network: + keras_io_names = keras_model.output_names + else: + keras_io_names = [ + 'output_%d' % i for i in range(1, len(estimator_io) + 1)] + + for key in estimator_io: + if key not in keras_io_names: + raise ValueError( + 'Cannot find %s with name "%s" in Keras Model. ' + 'It needs to match one ' + 'of the following: %s' % ('input' if is_input else 'output', key, + ', '.join(keras_io_names))) + tensors = [_cast_tensor_to_floatx(estimator_io[io_name]) + for io_name in keras_io_names] + return tensors else: - keras_io_names = keras_model.output_names + # Plain array. + return _cast_tensor_to_floatx(estimator_io) - for key in estimator_io_dict: - if key not in keras_io_names: - raise ValueError( - 'Cannot find %s with name "%s" in Keras Model. It needs to match ' - 'one of the following: %s' % ('input' if is_input else 'output', key, - ', '.join(keras_io_names))) - tensors = [] - for io_name in keras_io_names: - tensors.append(_cast_tensor_to_floatx(estimator_io_dict[io_name])) - return tensors + +def _in_place_subclassed_model_reset(model): + """Substitute for model cloning that works for subclassed models. + + Subclassed models cannot be cloned because their topology is not serializable. + To "instantiate" an identical model in a new TF graph, we reuse the original + model object, but we clear its state. + + After calling this function on a model intance, you can use the model instance + as if it were a model clone (in particular you can use it in a new graph). + + This method clears the state of the input model. It is thus destructive. + However the original state can be restored fully by calling + `_in_place_subclassed_model_state_restoration`. + + Args: + model: Instance of a Keras model created via subclassing. + + Raises: + ValueError: In case the model uses a subclassed model as inner layer. + """ + assert not model._is_graph_network # Only makes sense for subclassed networks + # Retrieve all layers tracked by the model as well as their attribute names + attributes_cache = {} + for name in dir(model): + try: + value = getattr(model, name) + except (AttributeError, ValueError, TypeError): + continue + if isinstance(value, Layer): + attributes_cache[name] = value + assert value in model._layers + elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'): + # Handle case: list/tuple of layers (also tracked by the Network API). + if value and all(isinstance(val, Layer) for val in value): + raise ValueError('We do not support the use of list-of-layers ' + 'attributes in subclassed models used with ' + '`model_to_estimator` at this time. Found list ' + 'model: %s' % name) + + # Replace layers on the model with fresh layers + layers_to_names = {value: key for key, value in attributes_cache.items()} + original_layers = model._layers[:] + model._layers = [] + for layer in original_layers: # We preserve layer order. + config = layer.get_config() + # This will not work for nested subclassed models used as layers. + # This would be theoretically possible to support, but would add complexity. + # Only do it if users complain. + if isinstance(layer, Network) and not layer._is_graph_network: + raise ValueError('We do not support the use of nested subclassed models ' + 'in `model_to_estimator` at this time. Found nested ' + 'model: %s' % layer) + fresh_layer = layer.__class__.from_config(config) + name = layers_to_names[layer] + setattr(model, name, fresh_layer) + + # Cache original model build attributes (in addition to layers) + if (not hasattr(model, '_original_attributes_cache') or + model._original_attributes_cache is None): + if model.built: + attributes_to_cache = [ + 'inputs', + 'outputs', + '_feed_outputs', + '_feed_output_names', + '_feed_output_shapes', + '_feed_loss_fns', + 'loss_weights_list', + 'targets', + '_feed_targets', + 'sample_weight_modes', + 'weighted_metrics', + 'metrics_names', + 'metrics_tensors', + 'metrics_updates', + 'stateful_metric_names', + 'total_loss', + 'sample_weights', + '_feed_sample_weights', + 'train_function', + 'test_function', + 'predict_function', + '_collected_trainable_weights', + '_feed_inputs', + '_feed_input_names', + '_feed_input_shapes', + 'optimizer', + ] + for name in attributes_to_cache: + attributes_cache[name] = getattr(model, name) + model._original_attributes_cache = attributes_cache + + # Reset built state + model.built = False + model.inputs = None + model.outputs = None + + +def _in_place_subclassed_model_state_restoration(model): + """Restores the original state of a model after it was "reset". + + This undoes this action of `_in_place_subclassed_model_reset`. + + Args: + model: Instance of a Keras model created via subclassing, on which + `_in_place_subclassed_model_reset` was previously called. + """ + assert not model._is_graph_network + # Restore layers and build attributes + if (hasattr(model, '_original_attributes_cache') and + model._original_attributes_cache is not None): + model._layers = [] + for name, value in model._original_attributes_cache.items(): + setattr(model, name, value) + model._original_attributes_cache = None + else: + # Restore to the state of a never-called model. + model.built = False + model.inputs = None + model.outputs = None def _clone_and_build_model(mode, @@ -93,8 +235,8 @@ def _clone_and_build_model(mode, mode: training mode. keras_model: an instance of compiled keras model. custom_objects: Dictionary for custom objects. - features: - labels: + features: Dict of tensors. + labels: Dict of tensors, or single tensor instance. Returns: The newly built model. @@ -102,33 +244,49 @@ def _clone_and_build_model(mode, # Set to True during training, False for inference. K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) - # Clone keras model. - input_tensors = None if features is None else _create_ordered_io( - keras_model, features) - if custom_objects: - with CustomObjectScope(custom_objects): + # Get list of inputs. + if features is None: + input_tensors = None + else: + input_tensors = _create_ordered_io(keras_model, + estimator_io=features, + is_input=True) + # Get list of outputs. + if labels is None: + target_tensors = None + elif isinstance(labels, dict): + target_tensors = _create_ordered_io(keras_model, + estimator_io=labels, + is_input=False) + else: + target_tensors = [ + _cast_tensor_to_floatx( + sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels)) + ] + + if keras_model._is_graph_network: + if custom_objects: + with CustomObjectScope(custom_objects): + model = models.clone_model(keras_model, input_tensors=input_tensors) + else: model = models.clone_model(keras_model, input_tensors=input_tensors) else: - model = models.clone_model(keras_model, input_tensors=input_tensors) + model = keras_model + _in_place_subclassed_model_reset(model) + if input_tensors is not None: + model._set_inputs(input_tensors) # Compile/Build model - if mode is model_fn_lib.ModeKeys.PREDICT and not model.built: - model.build() + if mode is model_fn_lib.ModeKeys.PREDICT: + if isinstance(model, models.Sequential): + model.build() else: - optimizer_config = keras_model.optimizer.get_config() - optimizer = keras_model.optimizer.__class__.from_config(optimizer_config) - optimizer.iterations = training_util.get_or_create_global_step() - - # Get list of outputs. - if labels is None: - target_tensors = None - elif isinstance(labels, dict): - target_tensors = _create_ordered_io(keras_model, labels, is_input=False) + if isinstance(keras_model.optimizer, optimizers.TFOptimizer): + optimizer = keras_model.optimizer else: - target_tensors = [ - _cast_tensor_to_floatx( - sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels)) - ] + optimizer_config = keras_model.optimizer.get_config() + optimizer = keras_model.optimizer.__class__.from_config(optimizer_config) + optimizer.iterations = training_util.get_or_create_global_step() model.compile( optimizer, @@ -138,9 +296,6 @@ def _clone_and_build_model(mode, sample_weight_mode=keras_model.sample_weight_mode, weighted_metrics=keras_model.weighted_metrics, target_tensors=target_tensors) - - if isinstance(model, models.Sequential): - model = model.model return model @@ -168,10 +323,14 @@ def _create_keras_model_fn(keras_model, custom_objects=None): # Set loss and metric only during train and evaluate. if mode is not model_fn_lib.ModeKeys.PREDICT: - model._make_train_function() # pylint: disable=protected-access + if mode is model_fn_lib.ModeKeys.TRAIN: + model._make_train_function() # pylint: disable=protected-access + else: + model._make_test_function() # pylint: disable=protected-access loss = model.total_loss if model.metrics: + # TODO(fchollet): support stateful metrics eval_metric_ops = {} # When each metric maps to an output if isinstance(model.metrics, dict): @@ -195,6 +354,10 @@ def _create_keras_model_fn(keras_model, custom_objects=None): if mode is model_fn_lib.ModeKeys.TRAIN: train_op = model.train_function.updates_op + if not model._is_graph_network: + # Reset model state to original state, + # to avoid `model_fn` being destructive for the initial model argument. + _in_place_subclassed_model_state_restoration(keras_model) return model_fn_lib.EstimatorSpec( mode=mode, predictions=predictions, @@ -230,8 +393,6 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, training_util.create_global_step() model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, custom_objects) - if isinstance(model, models.Sequential): - model = model.model # save to checkpoint with session.Session(config=estimator._session_config) as sess: model.set_weights(keras_weights) @@ -274,10 +435,11 @@ def model_to_estimator(keras_model=None, """ if (not keras_model) and (not keras_model_path): raise ValueError( - 'Either keras_model or keras_model_path needs to be provided.') + 'Either `keras_model` or `keras_model_path` needs to be provided.') if keras_model and keras_model_path: raise ValueError( - 'Please specity either keras_model or keras_model_path but not both.') + 'Please specity either `keras_model` or `keras_model_path`, ' + 'but not both.') if not keras_model: if keras_model_path.startswith( @@ -288,18 +450,42 @@ def model_to_estimator(keras_model=None, logging.info('Loading models from %s', keras_model_path) keras_model = models.load_model(keras_model_path) else: - logging.info('Using the Keras model from memory.') + logging.info('Using the Keras model provided.') keras_model = keras_model - if not hasattr(keras_model, 'optimizer'): + if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer: raise ValueError( - 'Given keras model has not been compiled yet. Please compile first ' - 'before creating the estimator.') + 'The given keras model has not been compiled yet. Please compile first ' + 'before calling `model_to_estimator`.') + + if isinstance(config, dict): + config = run_config_lib.RunConfig(**config) - keras_weights = keras_model.get_weights() keras_model_fn = _create_keras_model_fn(keras_model, custom_objects) - est = estimator_lib.Estimator( + estimator = estimator_lib.Estimator( keras_model_fn, model_dir=model_dir, config=config) - # TODO(yifeif): move checkpoint initialization to scaffold.init_fn - _save_first_checkpoint(keras_model, est, custom_objects, keras_weights) - return est + + # Pass the config into keras backend's default session. + with session.Session(config=estimator._session_config) as sess: + K.set_session(sess) + + keras_weights = keras_model.get_weights() + if keras_model._is_graph_network: + # TODO(yifeif): move checkpoint initialization to scaffold.init_fn + _save_first_checkpoint(keras_model, + estimator, + custom_objects, + keras_weights) + elif keras_model.built: + logging.warning('You are creating an Estimator from a Keras model ' + 'manually subclassed from `Model`, that was ' + 'already called on some inputs (and thus already had ' + 'weights). We are currently unable to preserve ' + 'the model\'s state (its weights) ' + 'as part of the estimator ' + 'in this case. Be warned that the estimator ' + 'has been created using ' + 'a freshly initialized version of your model.\n' + 'Note that this doesn\'t affect the state of the ' + 'model instance you passed as `keras_model` argument.') + return estimator diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py index 88dd14b856a4ee9dfbee61d6fd1bdb96af24b50c..e076dc25b16900636313f0ddd85a61b8d917fc91 100644 --- a/tensorflow/python/keras/_impl/keras/estimator_test.py +++ b/tensorflow/python/keras/_impl/keras/estimator_test.py @@ -24,6 +24,7 @@ import tempfile import numpy as np +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import test_util @@ -33,6 +34,7 @@ from tensorflow.python.keras._impl.keras.applications import mobilenet from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import rmsprop try: @@ -63,12 +65,42 @@ def simple_functional_model(): return model -def get_resource_for_simple_model(is_sequential=True, is_evaluate=False): - model = simple_sequential_model( - ) if is_sequential else simple_functional_model() - if is_sequential: +def simple_subclassed_model(): + + class SimpleModel(keras.Model): + + def __init__(self): + super(SimpleModel, self).__init__() + self.dense1 = keras.layers.Dense(16, activation='relu') + self.dp = keras.layers.Dropout(0.1) + self.dense2 = keras.layers.Dense(_NUM_CLASS, activation='softmax') + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dp(x) + return self.dense2(x) + + return SimpleModel() + + +def get_resource_for_simple_model(model_type='sequential', + is_evaluate=False,): + if model_type == 'sequential': + model = simple_sequential_model() model.build() - input_name = model.input_names[0] + elif model_type == 'subclass': + model = simple_subclassed_model() + else: + assert model_type == 'functional' + model = simple_functional_model() + + if model_type == 'subclass': + input_name = 'input_1' + output_name = 'output_1' + else: + input_name = model.input_names[0] + output_name = model.output_names[0] + np.random.seed(_RANDOM_SEED) (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( train_samples=_TRAIN_SIZE, @@ -79,17 +111,19 @@ def get_resource_for_simple_model(is_sequential=True, is_evaluate=False): y_test = keras.utils.to_categorical(y_test) train_input_fn = numpy_io.numpy_input_fn( - x={input_name: x_train}, - y=y_train, + x=randomize_io_type(x_train, input_name), + y=randomize_io_type(y_train, output_name), shuffle=False, num_epochs=None, batch_size=16) evaluate_input_fn = numpy_io.numpy_input_fn( - x={input_name: x_test}, y=y_test, num_epochs=1, shuffle=False) + x=randomize_io_type(x_test, input_name), + y=randomize_io_type(y_test, output_name), + num_epochs=1, shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( - x={input_name: x_test}, num_epochs=1, shuffle=False) + x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False) inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn @@ -97,6 +131,14 @@ def get_resource_for_simple_model(is_sequential=True, is_evaluate=False): y_test), train_input_fn, inference_input_fn +def randomize_io_type(array, name): + switch = np.random.random() + if switch > 0.5: + return array + else: + return {name: array} + + def multi_inputs_multi_outputs_model(): # test multi-input layer a = keras.layers.Input(shape=(16,), name='input_a') @@ -133,10 +175,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): gfile.DeleteRecursively(self._base_dir) def test_train(self): - for is_sequential in [True, False]: + for model_type in ['sequential', 'functional']: keras_model, (_, _), ( _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( - is_sequential=is_sequential, is_evaluate=True) + model_type=model_type, is_evaluate=True) keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', @@ -154,10 +196,87 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_train_with_tf_optimizer(self): + for model_type in ['sequential', 'functional']: + keras_model, (_, _), ( + _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type=model_type, is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + est_keras = keras.estimator.model_to_estimator( + keras_model=keras_model, + # Also use dict config argument to get test coverage for that line. + config={ + 'tf_random_seed': _RANDOM_SEED, + 'model_dir': self._base_dir, + }) + before_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + def test_train_with_subclassed_model(self): + keras_model, (_, _), ( + _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type='subclass', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + est_keras = keras.estimator.model_to_estimator( + keras_model=keras_model, config=self._config) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + before_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + def test_train_with_subclassed_model_with_existing_state(self): + keras_model, (_, _), ( + _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type='subclass', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + # Create state + keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE), + np.random.random((10, _NUM_CLASS))) + original_preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE)) + + est_keras = keras.estimator.model_to_estimator( + keras_model=keras_model, config=self._config) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + before_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + # Check that original model state was not altered + preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE)) + self.assertAllClose(original_preds, preds, atol=1e-5) + # Check that the original model compilation did not break + keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE), + np.random.random((10, _NUM_CLASS))) + def test_evaluate(self): keras_model, (x_train, y_train), ( x_test, y_test), _, eval_input_fn = get_resource_for_simple_model( - is_sequential=False, is_evaluate=True) + model_type='functional', is_evaluate=True) with self.test_session(): metrics = [ @@ -199,7 +318,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): # Check that predict on a pretrained model yield the same result. keras_model, (x_train, y_train), ( x_test, _), _, pred_input_fn = get_resource_for_simple_model( - is_sequential=True, is_evaluate=False) + model_type='sequential', is_evaluate=False) with self.test_session(): keras_model.compile( @@ -261,7 +380,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model, (x_train, y_train), ( x_test, _), _, pred_input_fn = get_resource_for_simple_model( - is_sequential=False, is_evaluate=False) + model_type='functional', is_evaluate=False) with self.test_session(): keras_model.compile( @@ -377,6 +496,22 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) + def test_gpu_config(self): + keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) + sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) + self._config._session_config = sess_config + keras.estimator.model_to_estimator( + keras_model=keras_model, config=self._config) + self.assertEqual(keras.backend.get_session() + ._config.gpu_options.per_process_gpu_memory_fraction, + gpu_options.per_process_gpu_memory_fraction) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py index 7cac17c51a9adcf8fc62154b6633de60bab18387..c40ee109aaea7dacea72e095b1d8cea3ed2e9bf8 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py @@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py index d2792b9636214d21e9658018f853fb6c0808abb4..d95a0942452afa82e277c358be5c3b2ba061ac98 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py @@ -26,7 +26,7 @@ from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py index c612e97a9d67f7398c78a7da1107f8e067bf9371..f4a134b96cec0385cb24a208f3403db944b68edc 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py @@ -553,7 +553,7 @@ class ZeroPaddingTest(test.TestCase): layer = keras.layers.ZeroPadding1D(padding=2) layer.build(shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -564,7 +564,7 @@ class ZeroPaddingTest(test.TestCase): layer = keras.layers.ZeroPadding1D(padding=(1, 2)) layer.build(shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -610,7 +610,7 @@ class ZeroPaddingTest(test.TestCase): padding=(2, 2), data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -629,7 +629,7 @@ class ZeroPaddingTest(test.TestCase): padding=((1, 2), (3, 4)), data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -683,7 +683,7 @@ class ZeroPaddingTest(test.TestCase): layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2)) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -737,7 +737,7 @@ class UpSamplingTest(test.TestCase): size=(length_row, length_col), data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -790,7 +790,7 @@ class UpSamplingTest(test.TestCase): data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -865,7 +865,7 @@ class CroppingTest(test.TestCase): cropping=cropping, data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -892,7 +892,7 @@ class CroppingTest(test.TestCase): cropping=cropping, data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -937,7 +937,7 @@ class CroppingTest(test.TestCase): cropping=cropping, data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - if context.in_eager_mode(): + if context.executing_eagerly(): np_output = output.numpy() else: np_output = keras.backend.eval(output) @@ -954,7 +954,7 @@ class CroppingTest(test.TestCase): cropping[2][0]:-cropping[2][1], :] np.testing.assert_allclose(np_output, expected_out) - # test incorrect use + # test incorrect use with self.assertRaises(ValueError): keras.layers.Cropping3D(cropping=(1, 1)) with self.assertRaises(ValueError): diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py index 50a197c80c3d97f47a071a24297301dddf78a27e..73e4f15f7e259211c892fdc663e14dcb14aec58d 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/_impl/keras/layers/core.py @@ -124,7 +124,7 @@ class Dropout(tf_core_layers.Dropout, Layer): training = K.learning_phase() output = super(Dropout, self).call(inputs, training=training) # EagerTensor object has no attribute _uses_learning_phase - if not context.in_eager_mode() and training is K.learning_phase(): + if not context.executing_eagerly() and training is K.learning_phase(): output._uses_learning_phase = True # pylint: disable=protected-access return output diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py index ca92899a455cd28a756e9efff63655d7c43c9f45..006ecd3135be25d43133daed1603734ecd1be955 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py +++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py @@ -23,7 +23,7 @@ from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py index df0efe6b8b7eaa0259eb6f4e246269551b3e0c15..13d96e939220c11a4090cf535e3efa4365fe8b62 100644 --- a/tensorflow/python/keras/_impl/keras/layers/local.py +++ b/tensorflow/python/keras/_impl/keras/layers/local.py @@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py index cdf2878e83e32147d30d6b29742b7e9013a1facb..c660cbd449b11a139f64cfa8b3a35310a597491c 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge.py @@ -21,8 +21,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine.topology import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py index 9010f4961585af58b7eae43dcd224e0c39606239..e309d160e5a9be97ff5f5356dad9dfaf85430233 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise.py +++ b/tensorflow/python/keras/_impl/keras/layers/noise.py @@ -22,7 +22,7 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py index 0dedd5e8daa2974038c90ae2e8c68ca6516ba725..3b44b20bf822429351002c0f81fe8f9596d595d3 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization.py +++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py @@ -111,7 +111,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer): if training is None: training = K.learning_phase() output = super(BatchNormalization, self).call(inputs, training=training) - if context.in_graph_mode() and training is K.learning_phase(): + if not context.executing_eagerly() and training is K.learning_phase(): output._uses_learning_phase = True # pylint: disable=protected-access return output diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py index 70049f0976b7170005183bb4b028079b39a23afb..bb003c1dddf80e2a745c1268a3a7d045f4e8b036 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py @@ -105,7 +105,7 @@ class Pooling2DTest(test.TestCase): # This part of the test can only run on GPU but doesn't appear # to be properly assigned to a GPU when running in eager mode. - if not context.in_eager_mode(): + if not context.executing_eagerly(): # Only runs on GPU with CUDA, channels_first is not supported on CPU. # TODO(b/62340061): Support channels_first on CPU. if test.is_gpu_available(cuda_only=True): diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index a81971d9eef0070245c31704b527a7f8edbbe9e9..791f9b311300ed05591083d551c040eb25ac8e22 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -546,8 +546,8 @@ class RNN(Layer): raise ValueError('The initial state or constants of an RNN' ' layer cannot be specified with a mix of' ' Keras tensors and non-Keras tensors' - '(a "Keras tensor" is a tensor that was' - 'returned by a Keras layer, or by `Input`)') + ' (a "Keras tensor" is a tensor that was' + ' returned by a Keras layer, or by `Input`)') if is_keras_tensor: # Compute the full input spec, including state and constants @@ -936,7 +936,7 @@ class SimpleRNNCell(Layer): # Properly set learning phase on output tensor. if 0 < self.dropout + self.recurrent_dropout: - if training is None and not context.in_eager_mode(): + if training is None and not context.executing_eagerly(): # This would be harmless to set in eager mode, but eager tensors # disallow setting arbitrary attributes. output._uses_learning_phase = True @@ -1384,7 +1384,7 @@ class GRUCell(Layer): hh = self.activation(x_h + recurrent_h) h = z * h_tm1 + (1 - z) * hh if 0 < self.dropout + self.recurrent_dropout: - if training is None and not context.in_eager_mode(): + if training is None and not context.executing_eagerly(): # This would be harmless to set in eager mode, but eager tensors # disallow setting arbitrary attributes. h._uses_learning_phase = True @@ -1877,7 +1877,7 @@ class LSTMCell(Layer): h = o * self.activation(c) if 0 < self.dropout + self.recurrent_dropout: - if training is None and not context.in_eager_mode(): + if training is None and not context.executing_eagerly(): # This would be harmless to set in eager mode, but eager tensors # disallow setting arbitrary attributes. h._uses_learning_phase = True diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py index 61f1a758e4701e6925af88b7fed9c48cf42ca735..76ddd9299dd669da35d89a6fe8fc521ce4c26337 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py index 3d71a620fcb34d21c41f920eed99b1fe22668899..58b144365be6cd8ea5b2ea82e275eacdee6b6c84 100644 --- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py @@ -174,19 +174,18 @@ class ModelSubclassingTest(test.TestCase): num_samples = 100 input_dim = 50 - with self.test_session(): - model = SimpleTestModel(num_classes=num_classes, - use_dp=True, - use_bn=True) - model.compile(loss='mse', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=['acc']) + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) - x = np.ones((num_samples, input_dim)) - y = np.zeros((num_samples, num_classes)) + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) - model.fit(x, y, epochs=2, batch_size=32, verbose=0) - _ = model.evaluate(x, y, verbose=0) + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) @test_util.run_in_graph_and_eager_modes() def test_multi_io_workflow_with_np_arrays(self): @@ -194,21 +193,20 @@ class ModelSubclassingTest(test.TestCase): num_samples = 1000 input_dim = 50 - with self.test_session(): - model = MultiIOTestModel(num_classes=num_classes, - use_dp=True, - use_bn=True) - model.compile(loss='mse', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=['acc']) + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) - x1 = np.ones((num_samples, input_dim)) - x2 = np.ones((num_samples, input_dim)) - y1 = np.zeros((num_samples, num_classes[0])) - y2 = np.zeros((num_samples, num_classes[1])) + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) - model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) - _ = model.evaluate([x1, x2], [y1, y2], verbose=0) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + _ = model.evaluate([x1, x2], [y1, y2], verbose=0) def test_single_io_workflow_with_tensors(self): @@ -321,14 +319,13 @@ class ModelSubclassingTest(test.TestCase): x = np.ones((num_samples, input_dim)) y = np.ones((num_samples, input_dim)) - with self.test_session(): - model = BNNet() - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - y_ref = model.predict(x) + model = BNNet() + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + y_ref = model.predict(x) - model.train_on_batch(x, y) - y_new = model.predict(x) - self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) + model.train_on_batch(x, y) + y_new = model.predict(x) + self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) @test_util.run_in_graph_and_eager_modes() def test_training_and_inference_behavior(self): @@ -350,14 +347,13 @@ class ModelSubclassingTest(test.TestCase): x = self.dp(inputs) return self.dense(x) - with self.test_session(): - model = DPNet() - x = np.ones((num_samples, input_dim)) - y = model.predict(x) - self.assertEqual(np.sum(y), np.sum(x)) - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - loss = model.train_on_batch(x, y) - self.assertGreater(loss, 0.1) + model = DPNet() + x = np.ones((num_samples, input_dim)) + y = model.predict(x) + self.assertEqual(np.sum(y), np.sum(x)) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + loss = model.train_on_batch(x, y) + self.assertGreater(loss, 0.1) @test_util.run_in_graph_and_eager_modes() def test_training_methods(self): @@ -373,21 +369,20 @@ class ModelSubclassingTest(test.TestCase): y1 = np.zeros((num_samples, num_classes[0])) y2 = np.zeros((num_samples, num_classes[1])) - with self.test_session(): - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) - model.fit({'input_1': x1, 'input_2': x2}, - {'output_1': y1, 'output_2': y2}, - epochs=2, batch_size=32) - model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0, - validation_data=([x1, x2], [y1, y2])) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + model.fit({'input_1': x1, 'input_2': x2}, + {'output_1': y1, 'output_2': y2}, + epochs=2, batch_size=32) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0, + validation_data=([x1, x2], [y1, y2])) - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - model.train_on_batch([x1, x2], [y1, y2]) - model.train_on_batch({'input_1': x1, 'input_2': x2}, - {'output_1': y1, 'output_2': y2}) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.train_on_batch([x1, x2], [y1, y2]) + model.train_on_batch({'input_1': x1, 'input_2': x2}, + {'output_1': y1, 'output_2': y2}) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def test_inference_methods(self): @@ -402,17 +397,16 @@ class ModelSubclassingTest(test.TestCase): y1 = np.zeros((num_samples, num_classes[0])) y2 = np.zeros((num_samples, num_classes[1])) - with self.test_session(): - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - model.evaluate([x1, x2], [y1, y2]) - model.test_on_batch([x1, x2], [y1, y2]) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.evaluate([x1, x2], [y1, y2]) + model.test_on_batch([x1, x2], [y1, y2]) - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - model.predict([x1, x2]) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.predict([x1, x2]) - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - model.predict_on_batch([x1, x2]) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.predict_on_batch([x1, x2]) @test_util.run_in_graph_and_eager_modes() def test_trainable_mutation(self): @@ -435,26 +429,25 @@ class ModelSubclassingTest(test.TestCase): y1 = np.zeros((num_samples, num_classes[0])) y2 = np.zeros((num_samples, num_classes[1])) - with self.test_session(): - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) - y_ref_1, y_ref_2 = model.predict([x1, x2]) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + y_ref_1, y_ref_2 = model.predict([x1, x2]) - fd, fname = tempfile.mkstemp('.h5') - model.save_weights(fname) + fd, fname = tempfile.mkstemp('.h5') + model.save_weights(fname) - model = MultiIOTestModel(num_classes=num_classes, use_bn=True) - # need to build the model before loading weights - # (otherwise no weights to load) - model._set_inputs([x1, x2]) - model.load_weights(fname) + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + # need to build the model before loading weights + # (otherwise no weights to load) + model._set_inputs([x1, x2]) + model.load_weights(fname) - y1, y2 = model.predict([x1, x2]) - self.assertAllClose(y_ref_1, y1, atol=1e-5) - self.assertAllClose(y_ref_2, y2, atol=1e-5) - os.close(fd) - os.remove(fname) + y1, y2 = model.predict([x1, x2]) + self.assertAllClose(y_ref_1, y1, atol=1e-5) + self.assertAllClose(y_ref_2, y2, atol=1e-5) + os.close(fd) + os.remove(fname) @test_util.run_in_graph_and_eager_modes() def test_summary(self): @@ -488,23 +481,22 @@ class ModelSubclassingTest(test.TestCase): num_samples = 100 input_dim = 50 - with self.test_session(): - model = NestedTestModel1(num_classes=num_classes) - model.compile(loss='mse', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=['acc']) + model = NestedTestModel1(num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) - x = np.ones((num_samples, input_dim)) - y = np.zeros((num_samples, num_classes)) + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) - model.fit(x, y, epochs=2, batch_size=32, verbose=0) - _ = model.evaluate(x, y, verbose=0) + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) - self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) - self.assertEqual(len(model.non_trainable_weights), - 2 + len(model.test_net.non_trainable_weights)) - self.assertEqual(len(model.trainable_weights), - 6 + len(model.test_net.trainable_weights)) + self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) + self.assertEqual(len(model.non_trainable_weights), + 2 + len(model.test_net.non_trainable_weights)) + self.assertEqual(len(model.trainable_weights), + 6 + len(model.test_net.trainable_weights)) @test_util.run_in_graph_and_eager_modes() def test_graph_nested_in_subclass(self): @@ -512,23 +504,22 @@ class ModelSubclassingTest(test.TestCase): num_samples = 100 input_dim = 50 - with self.test_session(): - model = NestedTestModel2(num_classes=num_classes) - model.compile(loss='mse', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=['acc']) + model = NestedTestModel2(num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) - x = np.ones((num_samples, input_dim)) - y = np.zeros((num_samples, num_classes)) + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) - model.fit(x, y, epochs=2, batch_size=32, verbose=0) - _ = model.evaluate(x, y, verbose=0) + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) - self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) - self.assertEqual(len(model.non_trainable_weights), - 2 + len(model.test_net.non_trainable_weights)) - self.assertEqual(len(model.trainable_weights), - 6 + len(model.test_net.trainable_weights)) + self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) + self.assertEqual(len(model.non_trainable_weights), + 2 + len(model.test_net.non_trainable_weights)) + self.assertEqual(len(model.trainable_weights), + 6 + len(model.test_net.trainable_weights)) @test_util.run_in_graph_and_eager_modes() def test_subclass_nested_in_graph(self): @@ -536,22 +527,21 @@ class ModelSubclassingTest(test.TestCase): num_samples = 100 input_dim = 50 - with self.test_session(): - model = get_nested_model_3(input_dim=input_dim, num_classes=num_classes) - model.compile(loss='mse', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=['acc']) + model = get_nested_model_3(input_dim=input_dim, num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) - x = np.ones((num_samples, input_dim)) - y = np.zeros((num_samples, num_classes)) + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) - model.fit(x, y, epochs=2, batch_size=32, verbose=0) - _ = model.evaluate(x, y, verbose=0) + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) - self.assertEqual(len(model.weights), 16) - self.assertEqual( - len(model.non_trainable_weights), 4) - self.assertEqual(len(model.trainable_weights), 12) + self.assertEqual(len(model.weights), 16) + self.assertEqual( + len(model.non_trainable_weights), 4) + self.assertEqual(len(model.trainable_weights), 12) @test_util.run_in_graph_and_eager_modes() def test_support_for_manual_training_arg(self): @@ -575,14 +565,13 @@ class ModelSubclassingTest(test.TestCase): x = self.dp(inputs, training=training) return self.dense(x) - with self.test_session(): - model = DPNet() - x = np.ones((10, 10)) - y = model.predict(x) - self.assertEqual(np.sum(y), np.sum(x)) - model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) - loss = model.train_on_batch(x, y) - self.assertGreater(loss, 0.1) + model = DPNet() + x = np.ones((10, 10)) + y = model.predict(x) + self.assertEqual(np.sum(y), np.sum(x)) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + loss = model.train_on_batch(x, y) + self.assertGreater(loss, 0.1) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py index 8000eaabab48dd225b7132af3ffeb11798c3096a..9602e7ba39b290f33c7ca9d0d1b5b35838667531 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/_impl/keras/models.py @@ -13,1305 +13,30 @@ # limitations under the License. # ============================================================================== # pylint: disable=protected-access -"""Home of the Sequential model, and the `save_model`/`load_model` functions. +"""Code for model cloning, plus model-related API entries. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy -import json -import os - -import numpy as np - -from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers as layer_module -from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras.engine import topology -from tensorflow.python.keras._impl.keras.engine.topology import Input -from tensorflow.python.keras._impl.keras.engine.topology import InputLayer -from tensorflow.python.keras._impl.keras.engine.topology import Layer -from tensorflow.python.keras._impl.keras.engine.topology import TFBaseLayer -from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.keras._impl.keras.engine import saving +from tensorflow.python.keras._impl.keras.engine import sequential +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras._impl.keras.engine.input_layer import Input +from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer +from tensorflow.python.keras._impl.keras.utils import generic_utils from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg -from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util.tf_export import tf_export - - -# pylint: disable=g-import-not-at-top -try: - import h5py -except ImportError: - h5py = None - -try: - import yaml -except ImportError: - yaml = None -# pylint: enable=g-import-not-at-top - - -@tf_export('keras.models.save_model') -def save_model(model, filepath, overwrite=True, include_optimizer=True): - """Save a model to a HDF5 file. - - The saved model contains: - - the model's configuration (topology) - - the model's weights - - the model's optimizer's state (if any) - - Thus the saved model can be reinstantiated in - the exact same state, without any of the code - used for model definition or training. - - Arguments: - model: Keras model instance to be saved. - filepath: String, path where to save the model. - overwrite: Whether we should overwrite any existing - model at the target location, or instead - ask the user with a manual prompt. - include_optimizer: If True, save optimizer's state together. - - Raises: - ImportError: if h5py is not available. - """ - - if h5py is None: - raise ImportError('`save_model` requires h5py.') - - def get_json_type(obj): - """Serialize any object to a JSON-serializable structure. - - Arguments: - obj: the object to serialize - - Returns: - JSON-serializable structure representing `obj`. - - Raises: - TypeError: if `obj` cannot be serialized. - """ - # if obj is a serializable Keras class instance - # e.g. optimizer, layer - if hasattr(obj, 'get_config'): - return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} - - # if obj is any numpy type - if type(obj).__module__ == np.__name__: - if isinstance(obj, np.ndarray): - return {'type': type(obj), 'value': obj.tolist()} - else: - return obj.item() - - # misc functions (e.g. loss function) - if callable(obj): - return obj.__name__ - - # if obj is a python 'type' - if type(obj).__name__ == type.__name__: - return obj.__name__ - - raise TypeError('Not JSON Serializable:', obj) - - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top - - # If file exists and should not be overwritten. - if not overwrite and os.path.isfile(filepath): - proceed = ask_to_proceed_with_overwrite(filepath) - if not proceed: - return - - with h5py.File(filepath, mode='w') as f: - f.attrs['keras_version'] = str(keras_version).encode('utf8') - f.attrs['backend'] = K.backend().encode('utf8') - f.attrs['model_config'] = json.dumps( - { - 'class_name': model.__class__.__name__, - 'config': model.get_config() - }, - default=get_json_type).encode('utf8') - - model_weights_group = f.create_group('model_weights') - model_layers = model.layers - topology.save_weights_to_hdf5_group(model_weights_group, model_layers) - - if include_optimizer and hasattr(model, 'optimizer'): - if isinstance(model.optimizer, optimizers.TFOptimizer): - logging.warning( - 'TensorFlow optimizers do not ' - 'make it possible to access ' - 'optimizer attributes or optimizer state ' - 'after instantiation. ' - 'As a result, we cannot save the optimizer ' - 'as part of the model save file.' - 'You will have to compile your model again after loading it. ' - 'Prefer using a Keras optimizer instead ' - '(see keras.io/optimizers).') - else: - f.attrs['training_config'] = json.dumps( - { - 'optimizer_config': { - 'class_name': model.optimizer.__class__.__name__, - 'config': model.optimizer.get_config() - }, - 'loss': model.loss, - 'metrics': model.metrics, - 'sample_weight_mode': model.sample_weight_mode, - 'loss_weights': model.loss_weights, - }, - default=get_json_type).encode('utf8') - - # Save optimizer weights. - symbolic_weights = getattr(model.optimizer, 'weights') - if symbolic_weights: - optimizer_weights_group = f.create_group('optimizer_weights') - weight_values = K.batch_get_value(symbolic_weights) - weight_names = [] - for w, val in zip(symbolic_weights, weight_values): - name = str(w.name) - weight_names.append(name.encode('utf8')) - optimizer_weights_group.attrs['weight_names'] = weight_names - for name, val in zip(weight_names, weight_values): - param_dset = optimizer_weights_group.create_dataset( - name, val.shape, dtype=val.dtype) - if not val.shape: - # scalar - param_dset[()] = val - else: - param_dset[:] = val - f.flush() - - -@tf_export('keras.models.load_model') -def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin - """Loads a model saved via `save_model`. - - Arguments: - filepath: String, path to the saved model. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - compile: Boolean, whether to compile the model - after loading. - - Returns: - A Keras model instance. If an optimizer was found - as part of the saved model, the model is already - compiled. Otherwise, the model is uncompiled and - a warning will be displayed. When `compile` is set - to False, the compilation is omitted without any - warning. - - Raises: - ImportError: if h5py is not available. - ValueError: In case of an invalid savefile. - """ - if h5py is None: - raise ImportError('`load_model` requires h5py.') - - if not custom_objects: - custom_objects = {} - - def convert_custom_objects(obj): - """Handles custom object lookup. - - Arguments: - obj: object, dict, or list. - - Returns: - The same structure, where occurrences - of a custom object name have been replaced - with the custom object. - """ - if isinstance(obj, list): - deserialized = [] - for value in obj: - deserialized.append(convert_custom_objects(value)) - return deserialized - if isinstance(obj, dict): - deserialized = {} - for key, value in obj.items(): - deserialized[key] = convert_custom_objects(value) - return deserialized - if obj in custom_objects: - return custom_objects[obj] - return obj - - with h5py.File(filepath, mode='r') as f: - # instantiate model - model_config = f.attrs.get('model_config') - if model_config is None: - raise ValueError('No model found in config file.') - model_config = json.loads(model_config.decode('utf-8')) - model = model_from_config(model_config, custom_objects=custom_objects) - - # set weights - topology.load_weights_from_hdf5_group(f['model_weights'], model.layers) - - # Early return if compilation is not required. - if not compile: - return model - - # instantiate optimizer - training_config = f.attrs.get('training_config') - if training_config is None: - logging.warning('No training configuration found in save file: ' - 'the model was *not* compiled. Compile it manually.') - return model - training_config = json.loads(training_config.decode('utf-8')) - optimizer_config = training_config['optimizer_config'] - optimizer = optimizers.deserialize( - optimizer_config, custom_objects=custom_objects) - - # Recover loss functions and metrics. - loss = convert_custom_objects(training_config['loss']) - metrics = convert_custom_objects(training_config['metrics']) - sample_weight_mode = training_config['sample_weight_mode'] - loss_weights = training_config['loss_weights'] - - # Compile model. - model.compile( - optimizer=optimizer, - loss=loss, - metrics=metrics, - loss_weights=loss_weights, - sample_weight_mode=sample_weight_mode) - - # Set optimizer weights. - if 'optimizer_weights' in f: - # Build train function (to get weight updates). - if isinstance(model, Sequential): - model.model._make_train_function() - else: - model._make_train_function() - optimizer_weights_group = f['optimizer_weights'] - optimizer_weight_names = [ - n.decode('utf8') - for n in optimizer_weights_group.attrs['weight_names'] - ] - optimizer_weight_values = [ - optimizer_weights_group[n] for n in optimizer_weight_names - ] - try: - model.optimizer.set_weights(optimizer_weight_values) - except ValueError: - logging.warning('Error in loading the saved optimizer ' - 'state. As a result, your model is ' - 'starting with a freshly initialized ' - 'optimizer.') - return model - - -@tf_export('keras.models.model_from_config') -def model_from_config(config, custom_objects=None): - """Instantiates a Keras model from its config. - - Arguments: - config: Configuration dictionary. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A Keras model instance (uncompiled). - - Raises: - TypeError: if `config` is not a dictionary. - """ - if isinstance(config, list): - raise TypeError('`model_from_config` expects a dictionary, not a list. ' - 'Maybe you meant to use ' - '`Sequential.from_config(config)`?') - return layer_module.deserialize(config, custom_objects=custom_objects) - - -@tf_export('keras.models.model_from_yaml') -def model_from_yaml(yaml_string, custom_objects=None): - """Parses a yaml model configuration file and returns a model instance. - - Arguments: - yaml_string: YAML string encoding a model configuration. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A Keras model instance (uncompiled). - - Raises: - ImportError: if yaml module is not found. - """ - if yaml is None: - raise ImportError('Requires yaml module installed.') - config = yaml.load(yaml_string) - return layer_module.deserialize(config, custom_objects=custom_objects) - - -@tf_export('keras.models.model_from_json') -def model_from_json(json_string, custom_objects=None): - """Parses a JSON model configuration file and returns a model instance. - - Arguments: - json_string: JSON string encoding a model configuration. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A Keras model instance (uncompiled). - """ - config = json.loads(json_string) - return layer_module.deserialize(config, custom_objects=custom_objects) - - -@tf_export('keras.models.Sequential', 'keras.Sequential') -class Sequential(Model): - """Linear stack of layers. - - Arguments: - layers: list of layers to add to the model. - - # Note - The first layer passed to a Sequential model - should have a defined input shape. What that - means is that it should have received an `input_shape` - or `batch_input_shape` argument, - or for some type of layers (recurrent, Dense...) - an `input_dim` argument. - - Example: - - ```python - model = Sequential() - # first layer must have a defined input shape - model.add(Dense(32, input_dim=500)) - # afterwards, Keras does automatic shape inference - model.add(Dense(32)) - - # also possible (equivalent to the above): - model = Sequential() - model.add(Dense(32, input_shape=(500,))) - model.add(Dense(32)) - - # also possible (equivalent to the above): - model = Sequential() - # here the batch dimension is None, - # which means any batch size will be accepted by the model. - model.add(Dense(32, batch_input_shape=(None, 500))) - model.add(Dense(32)) - ``` - """ - - def __init__(self, layers=None, name=None): - self._is_graph_network = True - self._is_compiled = False - self._layers = [] # Stack of layers. - self.model = None # Internal Model instance. - self.inputs = [] # List of input tensors - self.outputs = [] # List of length 1: the output tensor (unique). - self._trainable = True - self._initial_weights = None - self._input_layers = [] - - # Model attributes. - self._inbound_nodes = [] - self._outbound_nodes = [] - self.built = False - - # Set model name. - if not name: - prefix = 'sequential_' - name = prefix + str(K.get_uid(prefix)) - self._name = name - - # Used by Layer base class. - self._dtype = None - self._activity_regularizer = None - - # The following properties are not actually used by Keras; - # they exist for compatibility with TF's variable scoping mechanism. - self._updates = [] - self._losses = [] - self._scope = None - self._reuse = None - self._base_name = name - self._graph = ops.get_default_graph() - - # Add to the model any layers passed to the constructor. - if layers: - for layer in layers: - self.add(layer) - - def add(self, layer): - """Adds a layer instance on top of the layer stack. - - Arguments: - layer: layer instance. - - Raises: - TypeError: If `layer` is not a layer instance. - ValueError: In case the `layer` argument does not - know its input shape. - ValueError: In case the `layer` argument has - multiple output tensors, or is already connected - somewhere else (forbidden in `Sequential` models). - """ - if not isinstance(layer, (Layer, TFBaseLayer)): - raise TypeError('The added layer must be ' - 'an instance of class Layer. ' - 'Found: ' + str(layer)) - if not self.outputs: - # First layer in model: check that it is an input layer. - if not isinstance(layer, InputLayer): - # Create an input layer. - # First, we need to infer its expected input shape and dtype. - if isinstance(layer, (Model, Sequential)): - # We were passed a model as first layer. - # This requires a specific way to figure out the - # input shape and dtype. - if not layer.layers: - raise ValueError('Cannot add an empty model ' - 'to a `Sequential` model.') - # In case of nested models: recover the first layer - # of the deepest model to infer input shape and dtype. - first_layer = layer.layers[0] - while isinstance(first_layer, (Model, Sequential)): - first_layer = first_layer.layers[0] - batch_shape = first_layer._batch_input_shape - dtype = first_layer.dtype - else: - # We were passed a regular layer, and it should - # know about its input shape. Otherwise, that's an error. - if not hasattr(layer, '_batch_input_shape'): - raise ValueError('The first layer in a ' - 'Sequential model must ' - 'get an `input_shape` argument.') - batch_shape = layer._batch_input_shape - dtype = layer.dtype - # Instantiate the input layer. - x = Input( - batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input') - # This will build the current layer - # and create the node connecting the current layer - # to the input layer we just created. - layer(x) - - if len(layer._inbound_nodes[-1].output_tensors) != 1: - raise ValueError('All layers in a Sequential model ' - 'should have a single output tensor. ' - 'For multi-output layers, ' - 'use the functional API.') - - self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] - self.inputs = topology.get_source_inputs(self.outputs[0]) - - # We create an input node, which we will keep updated - # as we add more layers - topology.Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self.inputs, - output_tensors=self.outputs) - else: - output_tensor = layer(self.outputs[0]) - if isinstance(output_tensor, list): - raise TypeError('All layers in a Sequential model ' - 'should have a single output tensor. ' - 'For multi-output layers, ' - 'use the functional API.') - self.outputs = [output_tensor] - # update self._inbound_nodes - self._inbound_nodes[0].output_tensors = self.outputs - self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] - - self._layers.append(layer) - self.built = False - - def pop(self): - """Removes the last layer in the model. - - Raises: - TypeError: if there are no layers in the model. - """ - if not self.layers: - raise TypeError('There are no layers in the model.') - - self.layers.pop() - if not self.layers: - self.outputs = [] - self._inbound_nodes = [] - self._outbound_nodes = [] - else: - self.layers[-1]._outbound_nodes = [] - self.outputs = [self.layers[-1].output] - # update self._inbound_nodes - self._inbound_nodes[0].output_tensors = self.outputs - self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] - self.built = False - - def get_layer(self, name=None, index=None): - """Retrieve a layer that is part of the model. - - Returns a layer based on either its name (unique) - or its index in the graph. Indices are based on - order of horizontal graph traversal (bottom-up). - - Arguments: - name: string, name of layer. - index: integer, index of layer. - - Returns: - A layer instance. - """ - if not self.built: - self.build() - return self.model.get_layer(name, index) - - def call(self, inputs, **kwargs): - if not self.built: - self.build() - return self.model.call(inputs, **kwargs) - - def build(self, input_shape=None): - if not self.inputs or not self.outputs: - raise TypeError('Sequential model cannot be built: model is empty.' - ' Add some layers first.') - # actually create the model - self.model = Model(self.inputs, self.outputs[0], name=self.name + '_model') - self.model.trainable = self.trainable - - # mirror model attributes - self.supports_masking = self.model.supports_masking - self._output_mask_cache = self.model._output_mask_cache - self._output_tensor_cache = self.model._output_tensor_cache - self._output_shape_cache = self.model._output_shape_cache - self._input_layers = self.model._input_layers - self._output_layers = self.model._output_layers - self._input_coordinates = self.model._input_coordinates - self._output_coordinates = self.model._output_coordinates - self._nodes_by_depth = self.model._nodes_by_depth - self._network_nodes = self.model._network_nodes - self.output_names = self.model.output_names - self.input_names = self.model.input_names - self._feed_input_names = self.model._feed_input_names - self._feed_inputs = self.model._feed_inputs - - # Make sure child model callbacks - # will call the parent Sequential model. - self.model.callback_model = self - - self.built = True - - @property - def uses_learning_phase(self): - if not self.built: - self.build() - return self.model.uses_learning_phase - - def _gather_list_attr(self, attr): - all_attrs = [] - for layer in self.layers: - all_attrs += getattr(layer, attr, []) - return all_attrs - - @property - def trainable(self): - return self._trainable - - @trainable.setter - def trainable(self, value): - if self.model: - self.model.trainable = value - self._trainable = value - - @property - def trainable_weights(self): - if not self.trainable: - return [] - return self._gather_list_attr('trainable_weights') - - @property - def non_trainable_weights(self): - weights = self._gather_list_attr('non_trainable_weights') - if not self.trainable: - trainable_weights = self._gather_list_attr('trainable_weights') - return trainable_weights + weights - return weights - - @property - def regularizers(self): - if not self.built: - self.build() - return self.model.regularizers - - def get_weights(self): - """Retrieves the weights of the model. - - Returns: - A flat list of Numpy arrays - (one array per model weight). - """ - if not self.built: - self.build() - return self.model.get_weights() - - def set_weights(self, weights): - """Sets the weights of the model. - - Arguments: - weights: Should be a list - of Numpy arrays with shapes and types matching - the output of `model.get_weights()`. - """ - if not self.built: - self.build() - self.model.set_weights(weights) - - def load_weights(self, filepath, by_name=False): - if h5py is None: - raise ImportError('`load_weights` requires h5py.') - f = h5py.File(filepath, mode='r') - if 'layer_names' not in f.attrs and 'model_weights' in f: - f = f['model_weights'] - layers = self.layers - if by_name: - topology.load_weights_from_hdf5_group_by_name(f, layers) - else: - topology.load_weights_from_hdf5_group(f, layers) - if hasattr(f, 'close'): - f.close() - - def save_weights(self, filepath, overwrite=True): - if h5py is None: - raise ImportError('`save_weights` requires h5py.') - # If file exists and should not be overwritten: - if not overwrite and os.path.isfile(filepath): - proceed = ask_to_proceed_with_overwrite(filepath) - if not proceed: - return - layers = self.layers - f = h5py.File(filepath, 'w') - topology.save_weights_to_hdf5_group(f, layers) - f.flush() - f.close() - - def compile(self, - optimizer, - loss, - metrics=None, - sample_weight_mode=None, - weighted_metrics=None, - target_tensors=None, - **kwargs): - """Configures the model for training. - - Arguments: - optimizer: String (name of optimizer) or optimizer object. - See [optimizers](/optimizers). - loss: String (name of objective function) or objective function. - See [losses](/losses). - If the model has multiple outputs, you can use a different loss - on each output by passing a dictionary or a list of losses. - The loss value that will be minimized by the model - will then be the sum of all individual losses. - metrics: List of metrics to be evaluated by the model - during training and testing. - Typically you will use `metrics=['accuracy']`. - To specify different metrics for different outputs of a - multi-output model, you could also pass a dictionary, - such as `metrics={'output_a': 'accuracy'}`. - sample_weight_mode: If you need to do timestep-wise - sample weighting (2D weights), set this to `"temporal"`. - `None` defaults to sample-wise weights (1D). - If the model has multiple outputs, you can use a different - `sample_weight_mode` on each output by passing a - dictionary or a list of modes. - weighted_metrics: list of metrics to be evaluated and weighted - by `sample_weight` or `class_weight` during training and testing. - target_tensors: By default, Keras will create a placeholder for the - model's target, which will be fed with the target data during - training. If instead you would like to use your own - target tensor (in turn, Keras will not expect external - Numpy data for these targets at training time), you - can specify them via the `target_tensors` argument. - It should be a single tensor - (for a single-output `Sequential` model). - **kwargs: These arguments are passed into `tf.Session.run`. - - Example: - ```python - model = Sequential() - model.add(Dense(32, input_shape=(500,))) - model.add(Dense(10, activation='softmax')) - model.compile(optimizer='rmsprop', - loss='categorical_crossentropy', - metrics=['accuracy']) - ``` - """ - # create the underlying model - self.build() - # call compile method of Model class - self.model.compile( - optimizer, - loss, - metrics=metrics, - sample_weight_mode=sample_weight_mode, - weighted_metrics=weighted_metrics, - target_tensors=target_tensors, - **kwargs) - self.optimizer = self.model.optimizer - self.loss = self.model.loss - self.metrics = self.model.metrics - self.loss_weights = self.model.loss_weights - self.sample_weight_mode = self.model.sample_weight_mode - self.weighted_metrics = self.model.weighted_metrics - self.targets = self.model.targets - self.metrics_tensors = self.model.metrics_tensors - self.metrics_names = self.model.metrics_names - self.sample_weights = self.model.sample_weights - self.total_loss = self.model.total_loss - - def fit(self, - x=None, - y=None, - batch_size=None, - epochs=1, - verbose=1, - callbacks=None, - validation_split=0., - validation_data=None, - shuffle=True, - class_weight=None, - sample_weight=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None, - **kwargs): - """Trains the model for a fixed number of epochs. - - Arguments: - x: Numpy array of training data. - If the input layer in the model is named, you can also pass a - dictionary mapping the input name to a Numpy array. - `x` can be `None` (default) if feeding from - TensorFlow data tensors. - y: Numpy array of target (label) data. - If the output layer in the model is named, you can also pass a - dictionary mapping the output name to a Numpy array. - `y` can be `None` (default) if feeding from - TensorFlow data tensors. - batch_size: Integer or `None`. - Number of samples per gradient update. - If unspecified, it will default to 32. - epochs: Integer. Number of epochs to train the model. - An epoch is an iteration over the entire `x` and `y` - data provided. - Note that in conjunction with `initial_epoch`, - `epochs` is to be understood as "final epoch". - The model is not trained for a number of iterations - given by `epochs`, but merely until the epoch - of index `epochs` is reached. - verbose: 0, 1, or 2. Verbosity mode. - 0 = silent, 1 = progress bar, 2 = one line per epoch. - callbacks: List of `keras.callbacks.Callback` instances. - List of callbacks to apply during training. - See [callbacks](/callbacks). - validation_split: Float between 0 and 1: - Fraction of the training data to be used as validation data. - The model will set apart this fraction of the training data, - will not train on it, and will evaluate - the loss and any model metrics - on this data at the end of each epoch. - The validation data is selected from the last samples - in the `x` and `y` data provided, before shuffling. - validation_data: tuple `(x_val, y_val)` or tuple - `(x_val, y_val, val_sample_weights)` on which to evaluate - the loss and any model metrics at the end of each epoch. - The model will not be trained on this data. - This will override `validation_split`. - shuffle: Boolean (whether to shuffle the training data - before each epoch) or str (for 'batch'). - 'batch' is a special option for dealing with the - limitations of HDF5 data; it shuffles in batch-sized chunks. - Has no effect when `steps_per_epoch` is not `None`. - class_weight: Optional dictionary mapping class indices (integers) - to a weight (float) value, used for weighting the loss function - (during training only). - This can be useful to tell the model to - "pay more attention" to samples from - an under-represented class. - sample_weight: Optional Numpy array of weights for - the training samples, used for weighting the loss function - (during training only). You can either pass a flat (1D) - Numpy array with the same length as the input samples - (1:1 mapping between weights and samples), - or in the case of temporal data, - you can pass a 2D array with shape - `(samples, sequence_length)`, - to apply a different weight to every timestep of every sample. - In this case you should make sure to specify - `sample_weight_mode="temporal"` in `compile()`. - initial_epoch: Epoch at which to start training - (useful for resuming a previous training run). - steps_per_epoch: Total number of steps (batches of samples) - before declaring one epoch finished and starting the - next epoch. When training with input tensors such as - TensorFlow data tensors, the default `None` is equal to - the number of unique samples in your dataset divided by - the batch size, or 1 if that cannot be determined. - validation_steps: Only relevant if `steps_per_epoch` - is specified. Total number of steps (batches of samples) - to validate before stopping. - **kwargs: Used for backwards compatibility support. - - Returns: - A `History` object. Its `History.history` attribute is - a record of training loss values and metrics values - at successive epochs, as well as validation loss values - and validation metrics values (if applicable). - - Raises: - RuntimeError: If the model was never compiled. - ValueError: In case of mismatch between the provided input data - and what the model expects. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.fit( - x, - y, - batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - validation_split=validation_split, - validation_data=validation_data, - shuffle=shuffle, - class_weight=class_weight, - sample_weight=sample_weight, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps) - - def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): - """Computes the loss on some input data, batch by batch. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - batch_size: integer. Number of samples per gradient update. - verbose: verbosity mode, 0 or 1. - sample_weight: sample weights, as a Numpy array. - - Returns: - Scalar test loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.evaluate( - x, - y, - batch_size=batch_size, - verbose=verbose, - sample_weight=sample_weight) - - def predict(self, x, batch_size=32, verbose=0): - """Generates output predictions for the input samples. - - The input samples are processed batch by batch. - - Arguments: - x: the input data, as a Numpy array. - batch_size: integer. - verbose: verbosity mode, 0 or 1. - - Returns: - A Numpy array of predictions. - """ - if not self.built: - self.build() - return self.model.predict(x, batch_size=batch_size, verbose=verbose) - - def predict_on_batch(self, x): - """Returns predictions for a single batch of samples. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - - Returns: - A Numpy array of predictions. - """ - if not self.built: - self.build() - return self.model.predict_on_batch(x) - - def train_on_batch(self, x, y, class_weight=None, sample_weight=None): - """Single gradient update over one batch of samples. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - class_weight: dictionary mapping classes to a weight value, - used for scaling the loss function (during training only). - sample_weight: sample weights, as a Numpy array. - - Returns: - Scalar training loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.train_on_batch( - x, y, sample_weight=sample_weight, class_weight=class_weight) - - def test_on_batch(self, x, y, sample_weight=None): - """Evaluates the model over a single batch of samples. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - sample_weight: sample weights, as a Numpy array. - - Returns: - Scalar test loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.test_on_batch(x, y, sample_weight=sample_weight) - - def predict_proba(self, x, batch_size=32, verbose=0): - """Generates class probability predictions for the input samples. - - The input samples are processed batch by batch. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - batch_size: integer. - verbose: verbosity mode, 0 or 1. - - Returns: - A Numpy array of probability predictions. - """ - preds = self.predict(x, batch_size, verbose) - if preds.min() < 0. or preds.max() > 1.: - logging.warning('Network returning invalid probability values. ' - 'The last layer might not normalize predictions ' - 'into probabilities ' - '(like softmax or sigmoid would).') - return preds - - def predict_classes(self, x, batch_size=32, verbose=0): - """Generate class predictions for the input samples. - - The input samples are processed batch by batch. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - batch_size: integer. - verbose: verbosity mode, 0 or 1. - - Returns: - A numpy array of class predictions. - """ - proba = self.predict(x, batch_size=batch_size, verbose=verbose) - if proba.shape[-1] > 1: - return proba.argmax(axis=-1) - else: - return (proba > 0.5).astype('int32') - - def fit_generator(self, - generator, - steps_per_epoch=None, - epochs=1, - verbose=1, - callbacks=None, - validation_data=None, - validation_steps=None, - class_weight=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - shuffle=True, - initial_epoch=0, - **kwargs): - """Fits the model on data generated batch-by-batch by a Python generator. - - The generator is run in parallel to the model, for efficiency. - For instance, this allows you to do real-time data augmentation - on images on CPU in parallel to training your model on GPU. - - Arguments: - generator: A generator. - The output of the generator must be either - - a tuple (inputs, targets) - - a tuple (inputs, targets, sample_weights). - All arrays should contain the same number of samples. - The generator is expected to loop over its data - indefinitely. An epoch finishes when `steps_per_epoch` - batches have been seen by the model. - steps_per_epoch: Total number of steps (batches of samples) - to yield from `generator` before declaring one epoch - finished and starting the next epoch. It should typically - be equal to the number of samples of your dataset - divided by the batch size. - Optional for `Sequence`: if unspecified, will use - the `len(generator)` as a number of steps. - epochs: Integer, total number of iterations on the data. - Note that in conjunction with initial_epoch, the parameter - epochs is to be understood as "final epoch". The model is - not trained for n steps given by epochs, but until the - epoch epochs is reached. - verbose: Verbosity mode, 0, 1, or 2. - callbacks: List of callbacks to be called during training. - validation_data: This can be either - - A generator for the validation data - - A tuple (inputs, targets) - - A tuple (inputs, targets, sample_weights). - validation_steps: Only relevant if `validation_data` - is a generator. - Number of steps to yield from validation generator - at the end of every epoch. It should typically - be equal to the number of samples of your - validation dataset divided by the batch size. - Optional for `Sequence`: if unspecified, will use - the `len(validation_data)` as a number of steps. - class_weight: Dictionary mapping class indices to a weight - for the class. - max_queue_size: Maximum size for the generator queue - workers: Maximum number of processes to spin up - use_multiprocessing: If True, use process based threading. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. - shuffle: Whether to shuffle the order of the batches at - the beginning of each epoch. Only used with instances - of `Sequence` (keras.utils.Sequence). - initial_epoch: Epoch at which to start training - (useful for resuming a previous training run) - **kwargs: support for legacy arguments. - - Returns: - A `History` object. - - Raises: - RuntimeError: if the model was never compiled. - ValueError: In case the generator yields - data in an invalid format. - - Example: - - ```python - def generate_arrays_from_file(path): - while 1: - f = open(path) - for line in f: - # create Numpy arrays of input data - # and labels, from each line in the file - x, y = process_line(line) - yield (x, y) - f.close() - - model.fit_generator(generate_arrays_from_file('/my_file.txt'), - steps_per_epoch=1000, epochs=10) - ``` - """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) - - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.fit_generator( - generator, - steps_per_epoch, - epochs, - verbose=verbose, - callbacks=callbacks, - validation_data=validation_data, - validation_steps=validation_steps, - class_weight=class_weight, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - shuffle=shuffle, - initial_epoch=initial_epoch) - - def evaluate_generator(self, - generator, - steps=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - **kwargs): - """Evaluates the model on a data generator. - - The generator should return the same kind of data - as accepted by `test_on_batch`. - - Arguments: - generator: Generator yielding tuples (inputs, targets) - or (inputs, targets, sample_weights) - steps: Total number of steps (batches of samples) - to yield from `generator` before stopping. - Optional for `Sequence`: if unspecified, will use - the `len(generator)` as a number of steps. - max_queue_size: maximum size for the generator queue - workers: maximum number of processes to spin up - use_multiprocessing: if True, use process based threading. - Note that because this implementation - relies on multiprocessing, you should not pass - non picklable arguments to the generator - as they can't be passed easily to children processes. - **kwargs: support for legacy arguments. - - Returns: - Scalar test loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - ValueError: In case the generator yields - data in an invalid format. - """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) - - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.evaluate_generator( - generator, - steps, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) - - def predict_generator(self, - generator, - steps=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - verbose=0, - **kwargs): - """Generates predictions for the input samples from a data generator. - - The generator should return the same kind of data as accepted by - `predict_on_batch`. - - Arguments: - generator: generator yielding batches of input samples. - steps: Total number of steps (batches of samples) - to yield from `generator` before stopping. - Optional for `Sequence`: if unspecified, will use - the `len(generator)` as a number of steps. - max_queue_size: maximum size for the generator queue - workers: maximum number of processes to spin up - use_multiprocessing: if True, use process based threading. - Note that because this implementation - relies on multiprocessing, you should not pass - non picklable arguments to the generator - as they can't be passed easily to children processes. - verbose: verbosity mode, 0 or 1. - **kwargs: support for legacy arguments. - - Returns: - A Numpy array of predictions. - - Raises: - ValueError: In case the generator yields - data in an invalid format. - """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) - - if not self.built: - self.build() - return self.model.predict_generator( - generator, - steps, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - verbose=verbose) - def get_config(self): - config = [] - for layer in self.layers: - config.append({ - 'class_name': layer.__class__.__name__, - 'config': layer.get_config() - }) - return copy.deepcopy(config) - @classmethod - def from_config(cls, config, custom_objects=None): - model = cls() - for conf in config: - layer = layer_module.deserialize(conf, custom_objects=custom_objects) - model.add(layer) - return model +# API entries importable from `keras.models`: +Model = training.Model # pylint: disable=invalid-name +Sequential = sequential.Sequential # pylint: disable=invalid-name +save_model = saving.save_model +load_model = saving.load_model +model_from_config = saving.model_from_config +model_from_yaml = saving.model_from_yaml +model_from_json = saving.model_from_json def _clone_functional_model(model, input_tensors=None): @@ -1365,7 +90,7 @@ def _clone_functional_model(model, input_tensors=None): else: # Make sure that all input tensors come from a Keras layer. # If tensor comes from an input layer: cache the input layer. - input_tensors = topology.to_list(input_tensors) + input_tensors = generic_utils.to_list(input_tensors) input_tensors_ = [] for i, x in enumerate(input_tensors): if not K.is_keras_tensor(x): @@ -1402,7 +127,7 @@ def _clone_functional_model(model, input_tensors=None): # Reuse previously cloned layer. layer = layer_map[layer] # Don't call InputLayer multiple times. - if isinstance(layer, topology.InputLayer): + if isinstance(layer, InputLayer): continue # Gather inputs to call the new layer. @@ -1427,8 +152,9 @@ def _clone_functional_model(model, input_tensors=None): if has_arg(layer.call, 'mask'): if 'mask' not in kwargs: kwargs['mask'] = computed_mask - output_tensors = topology.to_list(layer(computed_tensor, **kwargs)) - output_masks = topology.to_list( + output_tensors = generic_utils.to_list(layer(computed_tensor, + **kwargs)) + output_masks = generic_utils.to_list( layer.compute_mask(computed_tensor, computed_mask)) computed_tensors = [computed_tensor] computed_masks = [computed_mask] @@ -1438,8 +164,9 @@ def _clone_functional_model(model, input_tensors=None): if has_arg(layer.call, 'mask'): if 'mask' not in kwargs: kwargs['mask'] = computed_masks - output_tensors = topology.to_list(layer(computed_tensors, **kwargs)) - output_masks = topology.to_list( + output_tensors = generic_utils.to_list(layer(computed_tensors, + **kwargs)) + output_masks = generic_utils.to_list( layer.compute_mask(computed_tensors, computed_masks)) # Update tensor_map. for x, y, mask in zip(reference_output_tensors, output_tensors, @@ -1489,14 +216,14 @@ def _clone_sequential_model(model, input_tensors=None): if input_tensors is None: return Sequential(layers=layers, name=model.name) else: - if len(topology.to_list(input_tensors)) != 1: + if len(generic_utils.to_list(input_tensors)) != 1: raise ValueError('To clone a `Sequential` model, we expect ' ' at most one tensor ' 'as part of `input_tensors`.') - x = topology.to_list(input_tensors)[0] + x = generic_utils.to_list(input_tensors)[0] if K.is_keras_tensor(x): origin_layer = x._keras_history[0] - if isinstance(origin_layer, topology.InputLayer): + if isinstance(origin_layer, InputLayer): return Sequential(layers=[origin_layer] + layers, name=model.name) else: raise ValueError('Cannot clone a `Sequential` model on top ' diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py index 04017e4b28b27e52f88a7746fc44510c29edffce..5978ddd987c63b9d87a31be6837172f08512ef73 100644 --- a/tensorflow/python/keras/_impl/keras/models_test.py +++ b/tensorflow/python/keras/_impl/keras/models_test.py @@ -12,362 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for training routines.""" +"""Tests for `models.py` (model cloning, mainly).""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import shutil -import tempfile - import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test -from tensorflow.python.training import training as training_module - -try: - import h5py # pylint:disable=g-import-not-at-top -except ImportError: - h5py = None - - -class TestModelSaving(test.TestCase): - - def test_sequential_model_saving(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.RepeatVector(3)) - model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy], - sample_weight_mode='temporal') - x = np.random.random((1, 3)) - y = np.random.random((1, 3, 3)) - model.train_on_batch(x, y) - - out = model.predict(x) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - new_model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - out2 = new_model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - # test that new updates are the same with both models - x = np.random.random((1, 3)) - y = np.random.random((1, 3, 3)) - model.train_on_batch(x, y) - new_model.train_on_batch(x, y) - out = model.predict(x) - out2 = new_model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - def test_sequential_model_saving_2(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - # test with custom optimizer, loss - - class CustomOp(keras.optimizers.RMSprop): - pass - - def custom_loss(y_true, y_pred): - return keras.losses.mse(y_true, y_pred) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc']) - - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - model.train_on_batch(x, y) - - out = model.predict(x) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - model = keras.models.load_model( - fname, - custom_objects={'CustomOp': CustomOp, - 'custom_loss': custom_loss}) - os.close(fd) - os.remove(fname) - - out2 = model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - def test_functional_model_saving(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - output = keras.layers.Dense(3)(x) - - model = keras.models.Model(inputs, output) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy]) - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - model.train_on_batch(x, y) - - out = model.predict(x) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - out2 = model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - def test_saving_without_compilation(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - def test_saving_with_tf_optimizer(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss='mse', - optimizer=training_module.AdadeltaOptimizer(0.1), - metrics=['acc']) - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - def test_saving_right_after_compilation(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - model.model._make_train_function() - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - def test_saving_lambda_numpy_array_arguments(self): - if h5py is None: - return # Skip test if models cannot be saved. - - mean = np.random.random((4, 2, 3)) - std = np.abs(np.random.random((4, 2, 3))) + 1e-5 - inputs = keras.layers.Input(shape=(4, 2, 3)) - output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, - arguments={'mu': mean, 'std': std})(inputs) - model = keras.models.Model(inputs, output) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - self.assertAllClose(mean, model.layers[1].arguments['mu']) - self.assertAllClose(std, model.layers[1].arguments['std']) - - -class TestSequential(test.TestCase): - """Most Sequential model API tests are covered in `training_test.py`. - """ - - def test_basic_methods(self): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_dim=2)) - model.add(keras.layers.Dropout(0.3, name='dp')) - model.add(keras.layers.Dense(2, kernel_regularizer='l2', - kernel_constraint='max_norm')) - model.build() - self.assertEqual(model.state_updates, model.model.state_updates) - self.assertEqual(model.get_layer(name='dp').name, 'dp') - - def test_sequential_pop(self): - num_hidden = 5 - input_dim = 3 - batch_size = 5 - num_classes = 2 - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) - model.compile(loss='mse', optimizer='sgd') - x = np.random.random((batch_size, input_dim)) - y = np.random.random((batch_size, num_classes)) - model.fit(x, y, epochs=1) - model.pop() - self.assertEqual(len(model.layers), 1) - self.assertEqual(model.output_shape, (None, num_hidden)) - model.compile(loss='mse', optimizer='sgd') - y = np.random.random((batch_size, num_hidden)) - model.fit(x, y, epochs=1) - - # Test popping single-layer model - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.pop() - self.assertEqual(len(model.layers), 0) - self.assertEqual(len(model.outputs), 0) - - # Invalid use case - model = keras.models.Sequential() - with self.assertRaises(TypeError): - model.pop() - - def test_sequential_weight_loading(self): - if h5py is None: - return - - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - h5_path = os.path.join(temp_dir, 'test.h5') - - num_hidden = 5 - input_dim = 3 - batch_size = 5 - num_classes = 2 - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) - - x = np.random.random((batch_size, input_dim)) - ref_y = model.predict(x) - - model.save_weights(h5_path) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) - model.load_weights(h5_path) - y = model.predict(x) - - self.assertAllClose(y, ref_y) - - def test_invalid_use_cases(self): - with self.test_session(): - # Added objects must be layer instances - with self.assertRaises(TypeError): - model = keras.models.Sequential() - model.add(None) - - # Added layers must have an inputs shape - with self.assertRaises(ValueError): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1)) - - # Added layers cannot have multiple outputs - class MyLayer(keras.layers.Layer): - - def call(self, inputs): - return [3 * inputs, 2 * inputs] - - def compute_output_shape(self, input_shape): - return [input_shape, input_shape] - - with self.assertRaises(ValueError): - model = keras.models.Sequential() - model.add(MyLayer(input_shape=(3,))) - with self.assertRaises(TypeError): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_dim=1)) - model.add(MyLayer()) - - # Building empty model - model = keras.models.Sequential() - with self.assertRaises(TypeError): - model.build() - - def test_nested_sequential_trainability(self): - input_dim = 20 - num_units = 10 - num_classes = 2 - - inner_model = keras.models.Sequential() - inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,))) - - model = keras.models.Sequential() - model.add(inner_model) - model.add(keras.layers.Dense(num_classes)) - - self.assertEqual(len(model.trainable_weights), 4) - inner_model.trainable = False - self.assertEqual(len(model.trainable_weights), 2) - inner_model.trainable = True - self.assertEqual(len(model.trainable_weights), 4) - - def test_sequential_update_disabling(self): - val_a = np.random.random((10, 4)) - val_out = np.random.random((10, 4)) - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.BatchNormalization(input_shape=(4,))) - - model.trainable = False - assert not model.updates - - model.compile('sgd', 'mse') - assert not model.updates - assert not model.model.updates - - x1 = model.predict(val_a) - model.train_on_batch(val_a, val_out) - x2 = model.predict(val_a) - self.assertAllClose(x1, x2, atol=1e-7) - - model.trainable = True - model.compile('sgd', 'mse') - assert model.updates - assert model.model.updates - - model.train_on_batch(val_a, val_out) - x2 = model.predict(val_a) - assert np.abs(np.sum(x1 - x2)) > 1e-5 class TestModelCloning(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py index 6520128c5b65451aef20ec9626245fba5ef29927..b715d722b98b9db3bdf0985da0954356a2facdfe 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/_impl/keras/optimizers.py @@ -95,7 +95,26 @@ class Optimizer(object): raise NotImplementedError def get_gradients(self, loss, params): + """Returns gradients of `loss` with respect to `params`. + + Arguments: + loss: Loss tensor. + params: List of variables. + + Returns: + List of gradient tensors. + + Raises: + ValueError: In case any gradient cannot be computed (e.g. if gradient + function not implemented). + """ grads = K.gradients(loss, params) + if None in grads: + raise ValueError('An operation has `None` for gradient. ' + 'Please make sure that all of your ops have a ' + 'gradient defined (i.e. are differentiable). ' + 'Common ops without gradient: ' + 'K.argmax, K.round, K.eval.') if hasattr(self, 'clipnorm') and self.clipnorm > 0: norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) grads = [clip_norm(g, self.clipnorm, norm) for g in grads] @@ -120,6 +139,11 @@ class Optimizer(object): ValueError: in case of incompatible weight shapes. """ params = self.weights + if len(params) != len(weights): + raise ValueError( + 'Length of the specified weight list (' + str(len(weights)) + + ') does not match the number of weights ' + 'of the optimizer (' + str(len(params)) + ')') weight_value_tuples = [] param_values = K.batch_get_value(params) for pv, p, w in zip(param_values, params, weights): diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py index d12f10863921ee7d635930f34e8bc701c89864e8..6299445c34b99f20d7ae5090fc979d0ac2611109 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py @@ -43,6 +43,7 @@ except ImportError: try: + from PIL import ImageEnhance from PIL import Image as pil_image except ImportError: pil_image = None @@ -227,6 +228,32 @@ def random_channel_shift(x, intensity, channel_axis=0): return x +@tf_export('keras.preprocessing.image.random_brightness') +def random_brightness(x, brightness_range): + """Performs a random adjustment of brightness of a Numpy image tensor. + + Arguments: + x: Input tensor. Must be 3D. + brightness_range: Tuple of floats; range to pick a brightness value from. + + Returns: + Brightness adjusted Numpy image tensor. + + Raises: + ValueError: if `brightness_range` isn't a tuple. + """ + if len(brightness_range) != 2: + raise ValueError('`brightness_range should be tuple or list of two floats. ' + 'Received arg: ', brightness_range) + + x = array_to_img(x) + x = ImageEnhance.Brightness(x) + u = np.random.uniform(brightness_range[0], brightness_range[1]) + x = x.enhance(u) + x = img_to_array(x) + return x + + def transform_matrix_offset_center(matrix, x, y): o_x = float(x) / 2 + 0.5 o_y = float(y) / 2 + 0.5 @@ -265,7 +292,7 @@ def apply_transform(x, x_channel, final_affine_matrix, final_offset, - order=0, + order=1, mode=fill_mode, cval=cval) for x_channel in x ] @@ -436,6 +463,7 @@ class ImageDataGenerator(object): rotation_range: degrees (0 to 180). width_shift_range: fraction of total width, if < 1, or pixels if >= 1. height_shift_range: fraction of total height, if < 1, or pixels if >= 1. + brightness_range: the range of brightness to apply shear_range: shear intensity (shear angle in degrees). zoom_range: amount of zoom. if scalar z, zoom will be randomly picked in the range [1-z, 1+z]. A sequence of two can be passed instead @@ -469,6 +497,8 @@ class ImageDataGenerator(object): It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". + validation_split: fraction of images reserved for validation (strictly + between 0 and 1). """ def __init__(self, @@ -481,6 +511,7 @@ class ImageDataGenerator(object): rotation_range=0., width_shift_range=0., height_shift_range=0., + brightness_range=None, shear_range=0., zoom_range=0., channel_shift_range=0., @@ -490,7 +521,8 @@ class ImageDataGenerator(object): vertical_flip=False, rescale=None, preprocessing_function=None, - data_format=None): + data_format=None, + validation_split=0.0): if data_format is None: data_format = K.image_data_format() self.featurewise_center = featurewise_center @@ -502,6 +534,7 @@ class ImageDataGenerator(object): self.rotation_range = rotation_range self.width_shift_range = width_shift_range self.height_shift_range = height_shift_range + self.brightness_range = brightness_range self.shear_range = shear_range self.zoom_range = zoom_range self.channel_shift_range = channel_shift_range @@ -526,6 +559,10 @@ class ImageDataGenerator(object): self.channel_axis = 3 self.row_axis = 1 self.col_axis = 2 + if validation_split and not 0 < validation_split < 1: + raise ValueError('`validation_split` must be strictly between 0 and 1. ' + 'Received arg: ', validation_split) + self.validation_split = validation_split self.mean = None self.std = None @@ -574,7 +611,8 @@ class ImageDataGenerator(object): seed=None, save_to_dir=None, save_prefix='', - save_format='png'): + save_format='png', + subset=None): return NumpyArrayIterator( x, y, @@ -585,7 +623,8 @@ class ImageDataGenerator(object): data_format=self.data_format, save_to_dir=save_to_dir, save_prefix=save_prefix, - save_format=save_format) + save_format=save_format, + subset=subset) def flow_from_directory(self, directory, @@ -600,6 +639,7 @@ class ImageDataGenerator(object): save_prefix='', save_format='png', follow_links=False, + subset=None, interpolation='nearest'): return DirectoryIterator( directory, @@ -616,6 +656,7 @@ class ImageDataGenerator(object): save_prefix=save_prefix, save_format=save_format, follow_links=follow_links, + subset=subset, interpolation=interpolation) def standardize(self, x): @@ -628,7 +669,7 @@ class ImageDataGenerator(object): The inputs, normalized. """ if self.preprocessing_function: - x = self.preprocessing_function(x) + x = self.image_data_generator.preprocessing_function(x) if self.rescale: x *= self.rescale if self.samplewise_center: @@ -762,6 +803,9 @@ class ImageDataGenerator(object): if np.random.random() < 0.5: x = flip_axis(x, img_row_axis) + if self.brightness_range is not None: + x = random_brightness(x, self.brightness_range) + return x def fit(self, x, augment=False, rounds=1, seed=None): @@ -828,12 +872,10 @@ class ImageDataGenerator(object): raise ImportError('Scipy is required for zca_whitening.') flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) - num_examples = flat_x.shape[0] - _, s, vt = linalg.svd(flat_x / np.sqrt(num_examples)) - s_expand = np.hstack( - (s, np.zeros(vt.shape[0] - num_examples, dtype=flat_x.dtype))) - self.principal_components = ( - vt.T / np.sqrt(s_expand**2 + self.zca_epsilon)).dot(vt) + sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] + u, s, _ = linalg.svd(sigma) + s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon) + self.principal_components = (u * s_inv).dot(u.T) @tf_export('keras.preprocessing.image.Iterator') @@ -947,6 +989,8 @@ class NumpyArrayIterator(Iterator): images (if `save_to_dir` is set). save_format: Format to use for saving sample images (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. """ def __init__(self, @@ -959,17 +1003,29 @@ class NumpyArrayIterator(Iterator): data_format=None, save_to_dir=None, save_prefix='', - save_format='png'): + save_format='png', + subset=None): if y is not None and len(x) != len(y): - raise ValueError('X (images tensor) and y (labels) ' + raise ValueError('`x` (images tensor) and `y` (labels) ' 'should have the same length. ' - 'Found: X.shape = %s, y.shape = %s' % + 'Found: x.shape = %s, y.shape = %s' % (np.asarray(x).shape, np.asarray(y).shape)) - + if subset is not None: + if subset not in {'training', 'validation'}: + raise ValueError('Invalid subset name:', subset, + '; expected "training" or "validation".') + split_idx = int(len(x) * image_data_generator.validation_split) + if subset == 'validation': + x = x[:split_idx] + if y is not None: + y = y[:split_idx] + else: + x = x[split_idx:] + if y is not None: + y = y[split_idx:] if data_format is None: data_format = K.image_data_format() self.x = np.asarray(x, dtype=K.floatx()) - if self.x.ndim != 4: raise ValueError('Input data in `NumpyArrayIterator` ' 'should have rank 4. You passed an array ' @@ -1032,8 +1088,7 @@ class NumpyArrayIterator(Iterator): return self._get_batches_of_transformed_samples(index_array) -def _count_valid_files_in_directory(directory, white_list_formats, - follow_links): +def _iter_valid_files(directory, white_list_formats, follow_links): """Count files with extension in `white_list_formats` contained in directory. Arguments: @@ -1043,29 +1098,54 @@ def _count_valid_files_in_directory(directory, white_list_formats, the files to be counted. follow_links: boolean. - Returns: - the count of files with extension in `white_list_formats` contained in - the directory. + Yields: + tuple of (root, filename) with extension in `white_list_formats`. """ def _recursive_list(subpath): return sorted( - os.walk(subpath, followlinks=follow_links), key=lambda tpl: tpl[0]) + os.walk(subpath, followlinks=follow_links), key=lambda x: x[0]) - samples = 0 - for _, _, files in _recursive_list(directory): - for fname in files: - is_valid = False + for root, _, files in _recursive_list(directory): + for fname in sorted(files): for extension in white_list_formats: + if fname.lower().endswith('.tiff'): + logging.warning( + 'Using \'.tiff\' files with multiple bands will cause ' + 'distortion. Please verify your output.') if fname.lower().endswith('.' + extension): - is_valid = True - break - if is_valid: - samples += 1 - return samples + yield root, fname -def _list_valid_filenames_in_directory(directory, white_list_formats, +def _count_valid_files_in_directory(directory, white_list_formats, split, + follow_links): + """Count files with extension in `white_list_formats` contained in directory. + + Arguments: + directory: absolute path to the directory + containing files to be counted + white_list_formats: set of strings containing allowed extensions for + the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. + follow_links: boolean. + + Returns: + the count of files with extension in `white_list_formats` contained in + the directory. + """ + num_files = len( + list(_iter_valid_files(directory, white_list_formats, follow_links))) + if split: + start, stop = int(split[0] * num_files), int(split[1] * num_files) + else: + start, stop = 0, num_files + return stop - start + + +def _list_valid_filenames_in_directory(directory, white_list_formats, split, class_indices, follow_links): """List paths of files in `subdir` with extensions in `white_list_formats`. @@ -1075,6 +1155,10 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, `class_indices`. white_list_formats: set of strings containing allowed extensions for the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. class_indices: dictionary mapping a class name to its index. follow_links: boolean. @@ -1084,27 +1168,26 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, `directory`'s parent (e.g., if `directory` is "dataset/class1", the filenames will be ["class1/file1.jpg", "class1/file2.jpg", ...]). """ - - def _recursive_list(subpath): - return sorted( - os.walk(subpath, followlinks=follow_links), key=lambda tpl: tpl[0]) + dirname = os.path.basename(directory) + if split: + num_files = len( + list(_iter_valid_files(directory, white_list_formats, follow_links))) + start, stop = int(split[0] * num_files), int(split[1] * num_files) + valid_files = list( + _iter_valid_files(directory, white_list_formats, + follow_links))[start:stop] + else: + valid_files = _iter_valid_files(directory, white_list_formats, follow_links) classes = [] filenames = [] - subdir = os.path.basename(directory) - basedir = os.path.dirname(directory) - for root, _, files in _recursive_list(directory): - for fname in sorted(files): - is_valid = False - for extension in white_list_formats: - if fname.lower().endswith('.' + extension): - is_valid = True - break - if is_valid: - classes.append(class_indices[subdir]) - # add filename relative to directory - absolute_path = os.path.join(root, fname) - filenames.append(os.path.relpath(absolute_path, basedir)) + for root, fname in valid_files: + classes.append(class_indices[dirname]) + absolute_path = os.path.join(root, fname) + relative_path = os.path.join(dirname, + os.path.relpath(absolute_path, directory)) + filenames.append(relative_path) + return classes, filenames @@ -1144,6 +1227,8 @@ class DirectoryIterator(Iterator): images (if `save_to_dir` is set). save_format: Format to use for saving sample images (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. interpolation: Interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear", and "bicubic". @@ -1167,6 +1252,7 @@ class DirectoryIterator(Iterator): save_prefix='', save_format='png', follow_links=False, + subset=None, interpolation='nearest'): if data_format is None: data_format = K.image_data_format() @@ -1200,7 +1286,20 @@ class DirectoryIterator(Iterator): self.save_format = save_format self.interpolation = interpolation - white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm'} + if subset is not None: + validation_split = self.image_data_generator.validation_split + if subset == 'validation': + split = (0, validation_split) + elif subset == 'training': + split = (validation_split, 1) + else: + raise ValueError('Invalid subset name: ', subset, + '; expected "training" or "validation"') + else: + split = None + self.subset = subset + + white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff'} # first, count the number of samples and classes self.samples = 0 @@ -1217,7 +1316,8 @@ class DirectoryIterator(Iterator): function_partial = partial( _count_valid_files_in_directory, white_list_formats=white_list_formats, - follow_links=follow_links) + follow_links=follow_links, + split=split) self.samples = sum( pool.map(function_partial, (os.path.join(directory, subdir) for subdir in classes))) @@ -1233,14 +1333,15 @@ class DirectoryIterator(Iterator): i = 0 for dirpath in (os.path.join(directory, subdir) for subdir in classes): results.append( - pool.apply_async( - _list_valid_filenames_in_directory, - (dirpath, white_list_formats, self.class_indices, follow_links))) + pool.apply_async(_list_valid_filenames_in_directory, + (dirpath, white_list_formats, split, + self.class_indices, follow_links))) for res in results: classes, filenames = res.get() self.classes[i:i + len(classes)] = classes self.filenames += filenames i += len(classes) + pool.close() pool.join() super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py index c0790b5a5140193b18907d9375530f4f06e137da..001fee91f9ed609c0b3cd88d4079e75c0e585b02 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os import shutil +import tempfile import numpy as np @@ -74,6 +75,7 @@ class TestImage(test.TestCase): shear_range=0.5, zoom_range=0.2, channel_shift_range=0., + brightness_range=(1, 5), fill_mode='nearest', cval=0.5, horizontal_flip=True, @@ -92,6 +94,47 @@ class TestImage(test.TestCase): self.assertEqual(x.shape[1:], images.shape[1:]) break + def test_image_data_generator_with_validation_split(self): + if PIL is None: + return # Skip test if PIL is not available. + + for test_images in _generate_test_images(): + img_list = [] + for im in test_images: + img_list.append(keras.preprocessing.image.img_to_array(im)[None, ...]) + + images = np.vstack(img_list) + generator = keras.preprocessing.image.ImageDataGenerator( + validation_split=0.5) + seq = generator.flow( + images, + np.arange(images.shape[0]), + shuffle=False, + batch_size=3, + subset='validation') + _, y = seq[0] + self.assertEqual(list(y), [0, 1, 2]) + seq = generator.flow( + images, + np.arange(images.shape[0]), + shuffle=False, + batch_size=3, + subset='training') + _, y2 = seq[0] + self.assertEqual(list(y2), [4, 5, 6]) + + with self.assertRaises(ValueError): + generator.flow( + images, + np.arange(images.shape[0]), + shuffle=False, + batch_size=3, + subset='foo') + + def test_image_data_generator_with_split_value_error(self): + with self.assertRaises(ValueError): + keras.preprocessing.image.ImageDataGenerator(validation_split=5) + def test_image_data_generator_invalid_data(self): generator = keras.preprocessing.image.ImageDataGenerator( featurewise_center=True, @@ -202,9 +245,80 @@ class TestImage(test.TestCase): # check number of classes and images self.assertEqual(len(dir_iterator.class_indices), num_classes) self.assertEqual(len(dir_iterator.classes), count) - self.assertEqual(sorted(dir_iterator.filenames), sorted(filenames)) + self.assertEqual(set(dir_iterator.filenames), set(filenames)) _ = dir_iterator.next() + def directory_iterator_with_validation_split_test_helper( + self, validation_split): + if PIL is None: + return # Skip test if PIL is not available. + + num_classes = 2 + tmp_folder = tempfile.mkdtemp(prefix='test_images') + + # create folders and subfolders + paths = [] + for cl in range(num_classes): + class_directory = 'class-{}'.format(cl) + classpaths = [ + class_directory, + os.path.join(class_directory, 'subfolder-1'), + os.path.join(class_directory, 'subfolder-2'), + os.path.join(class_directory, 'subfolder-1', 'sub-subfolder') + ] + for path in classpaths: + os.mkdir(os.path.join(tmp_folder, path)) + paths.append(classpaths) + + # save the images in the paths + count = 0 + filenames = [] + for test_images in _generate_test_images(): + for im in test_images: + # rotate image class + im_class = count % num_classes + # rotate subfolders + classpaths = paths[im_class] + filename = os.path.join(classpaths[count % len(classpaths)], + 'image-{}.jpg'.format(count)) + filenames.append(filename) + im.save(os.path.join(tmp_folder, filename)) + count += 1 + + # create iterator + generator = keras.preprocessing.image.ImageDataGenerator( + validation_split=validation_split) + + with self.assertRaises(ValueError): + generator.flow_from_directory(tmp_folder, subset='foo') + + num_validation = int(count * validation_split) + num_training = count - num_validation + train_iterator = generator.flow_from_directory( + tmp_folder, subset='training') + self.assertEqual(train_iterator.samples, num_training) + + valid_iterator = generator.flow_from_directory( + tmp_folder, subset='validation') + self.assertEqual(valid_iterator.samples, num_validation) + + # check number of classes and images + self.assertEqual(len(train_iterator.class_indices), num_classes) + self.assertEqual(len(train_iterator.classes), num_training) + self.assertEqual( + len(set(train_iterator.filenames) & set(filenames)), num_training) + + shutil.rmtree(tmp_folder) + + def test_directory_iterator_with_validation_split_25_percent(self): + self.directory_iterator_with_validation_split_test_helper(0.25) + + def test_directory_iterator_with_validation_split_40_percent(self): + self.directory_iterator_with_validation_split_test_helper(0.40) + + def test_directory_iterator_with_validation_split_50_percent(self): + self.directory_iterator_with_validation_split_test_helper(0.50) + def test_img_utils(self): if PIL is None: return # Skip test if PIL is not available. @@ -241,6 +355,41 @@ class TestImage(test.TestCase): x = keras.preprocessing.image.img_to_array(img, data_format='channels_last') self.assertEqual(x.shape, (height, width, 1)) + def test_batch_standardize(self): + if PIL is None: + return # Skip test if PIL is not available. + + # ImageDataGenerator.standardize should work on batches + for test_images in _generate_test_images(): + img_list = [] + for im in test_images: + img_list.append(keras.preprocessing.image.img_to_array(im)[None, ...]) + + images = np.vstack(img_list) + generator = keras.preprocessing.image.ImageDataGenerator( + featurewise_center=True, + samplewise_center=True, + featurewise_std_normalization=True, + samplewise_std_normalization=True, + zca_whitening=True, + rotation_range=90., + width_shift_range=0.1, + height_shift_range=0.1, + shear_range=0.5, + zoom_range=0.2, + channel_shift_range=0., + brightness_range=(1, 5), + fill_mode='nearest', + cval=0.5, + horizontal_flip=True, + vertical_flip=True) + generator.fit(images, augment=True) + + transformed = np.copy(images) + for i, im in enumerate(transformed): + transformed[i] = generator.random_transform(im) + transformed = generator.standardize(transformed) + def test_img_transforms(self): x = np.random.random((3, 200, 200)) _ = keras.preprocessing.image.random_rotation(x, 20) diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py index a423d96d3d8578df347b7ee36fb53dfd335e0d65..e68c171d9c7e33d7e932f5d5b7f15859faa2348b 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py @@ -22,6 +22,8 @@ import random import numpy as np from six.moves import range # pylint: disable=redefined-builtin + +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence from tensorflow.python.util.tf_export import tf_export @@ -32,29 +34,40 @@ def pad_sequences(sequences, padding='pre', truncating='pre', value=0.): - """Pads each sequence to the same length (length of the longest sequence). + """Pads sequences to the same length. + + This function transforms a list of + `num_samples` sequences (lists of integers) + into a 2D Numpy array of shape `(num_samples, num_timesteps)`. + `num_timesteps` is either the `maxlen` argument if provided, + or the length of the longest sequence otherwise. + + Sequences that are shorter than `num_timesteps` + are padded with `value` at the end. - If maxlen is provided, any sequence longer - than maxlen is truncated to maxlen. - Truncation happens off either the beginning (default) or - the end of the sequence. + Sequences longer than `num_timesteps` are truncated + so that they fit the desired length. + The position where padding or truncation happens is determined by + the arguments `padding` and `truncating`, respectively. - Supports post-padding and pre-padding (default). + Pre-padding is the default. Arguments: - sequences: list of lists where each element is a sequence - maxlen: int, maximum length - dtype: type to cast the resulting sequence. - padding: 'pre' or 'post', pad either before or after each sequence. - truncating: 'pre' or 'post', remove values from sequences larger than - maxlen either in the beginning or in the end of the sequence - value: float, value to pad the sequences to the desired value. + sequences: List of lists, where each element is a sequence. + maxlen: Int, maximum length of all sequences. + dtype: Type of the output sequences. + padding: String, 'pre' or 'post': + pad either before or after each sequence. + truncating: String, 'pre' or 'post': + remove values from sequences larger than + `maxlen`, either at the beginning or at the end of the sequences. + value: Float, padding value. Returns: - x: numpy array with dimensions (number_of_sequences, maxlen) + x: Numpy array with shape `(len(sequences), maxlen)` Raises: - ValueError: in case of invalid values for `truncating` or `padding`, + ValueError: In case of invalid values for `truncating` or `padding`, or in case of invalid shape for a `sequences` entry. """ if not hasattr(sequences, '__len__'): @@ -92,10 +105,9 @@ def pad_sequences(sequences, # check `trunc` has expected shape trunc = np.asarray(trunc, dtype=dtype) if trunc.shape[1:] != sample_shape: - raise ValueError( - 'Shape of sample %s of sequence at position %s is different from ' - 'expected shape %s' - % (trunc.shape[1:], idx, sample_shape)) + raise ValueError('Shape of sample %s of sequence at position %s ' + 'is different from expected shape %s' % + (trunc.shape[1:], idx, sample_shape)) if padding == 'post': x[idx, :len(trunc)] = trunc @@ -110,22 +122,26 @@ def pad_sequences(sequences, def make_sampling_table(size, sampling_factor=1e-5): """Generates a word rank-based probabilistic sampling table. - This generates an array where the ith element - is the probability that a word of rank i would be sampled, - according to the sampling distribution used in word2vec. + Used for generating the `sampling_table` argument for `skipgrams`. + `sampling_table[i]` is the probability of sampling + the word i-th most common word in a dataset + (more common words should be sampled less frequently, for balance). - The word2vec formula is: - p(word) = min(1, sqrt(word.frequency/sampling_factor) / - (word.frequency/sampling_factor)) + The sampling probabilities are generated according + to the sampling distribution used in word2vec: + + `p(word) = min(1, sqrt(word_frequency / sampling_factor) / (word_frequency / + sampling_factor))` We assume that the word frequencies follow Zipf's law (s=1) to derive a numerical approximation of frequency(rank): - frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank)) - where gamma is the Euler-Mascheroni constant. + + `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))` + where `gamma` is the Euler-Mascheroni constant. Arguments: - size: int, number of possible words to sample. - sampling_factor: the sampling factor in the word2vec formula. + size: Int, number of possible words to sample. + sampling_factor: The sampling factor in the word2vec formula. Returns: A 1D Numpy array of length `size` where the ith entry @@ -151,30 +167,37 @@ def skipgrams(sequence, seed=None): """Generates skipgram word pairs. - Takes a sequence (list of indexes of words), - returns couples of [word_index, other_word index] and labels (1s or 0s), - where label = 1 if 'other_word' belongs to the context of 'word', - and label=0 if 'other_word' is randomly sampled + This function transforms a sequence of word indexes (list of integers) + into tuples of words of the form: + + - (word, word in the same window), with label 1 (positive samples). + - (word, random word from the vocabulary), with label 0 (negative samples). + + Read more about Skipgram in this gnomic paper by Mikolov et al.: + [Efficient Estimation of Word Representations in + Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf) Arguments: - sequence: a word sequence (sentence), encoded as a list + sequence: A word sequence (sentence), encoded as a list of word indices (integers). If using a `sampling_table`, word indices are expected to match the rank of the words in a reference dataset (e.g. 10 would encode the 10-th most frequently occurring token). Note that index 0 is expected to be a non-word and will be skipped. - vocabulary_size: int. maximum possible word index + 1 - window_size: int. actually half-window. - The window of a word wi will be [i-window_size, i+window_size+1] - negative_samples: float >= 0. 0 for no negative (=random) samples. - 1 for same number as positive samples. etc. - shuffle: whether to shuffle the word couples before returning them. + vocabulary_size: Int, maximum possible word index + 1 + window_size: Int, size of sampling windows (technically half-window). + The window of a word `w_i` will be + `[i - window_size, i + window_size+1]`. + negative_samples: Float >= 0. 0 for no negative (i.e. random) samples. + 1 for same number as positive samples. + shuffle: Whether to shuffle the word couples before returning them. categorical: bool. if False, labels will be - integers (eg. [0, 1, 1 .. ]), - if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ] + integers (eg. `[0, 1, 1 .. ]`), + if `True`, labels will be categorical, e.g. + `[[1,0],[0,1],[0,1] .. ]`. sampling_table: 1D array of size `vocabulary_size` where the entry i encodes the probability to sample a word of rank i. - seed: random seed. + seed: Random seed. Returns: couples, labels: where `couples` are int pairs and @@ -234,9 +257,9 @@ def _remove_long_seq(maxlen, seq, label): """Removes sequences that exceed the maximum length. Arguments: - maxlen: int, maximum length - seq: list of lists where each sublist is a sequence - label: list where each element is an integer + maxlen: Int, maximum length of the output sequences. + seq: List of lists, where each sublist is a sequence. + label: List where each element is an integer. Returns: new_seq, new_label: shortened lists for `seq` and `label`. @@ -247,3 +270,120 @@ def _remove_long_seq(maxlen, seq, label): new_seq.append(x) new_label.append(y) return new_seq, new_label + + +@tf_export('keras.preprocessing.sequence.TimeseriesGenerator') +class TimeseriesGenerator(Sequence): + """Utility class for generating batches of temporal data. + + This class takes in a sequence of data-points gathered at + equal intervals, along with time series parameters such as + stride, length of history, etc., to produce batches for + training/validation. + + Arguments: + data: Indexable generator (such as list or Numpy array) + containing consecutive data points (timesteps). + The data should be at 2D, and axis 0 is expected + to be the time dimension. + targets: Targets corresponding to timesteps in `data`. + It should have same length as `data`. + length: Length of the output sequences (in number of timesteps). + sampling_rate: Period between successive individual timesteps + within sequences. For rate `r`, timesteps + `data[i]`, `data[i-r]`, ... `data[i - length]` + are used for create a sample sequence. + stride: Period between successive output sequences. + For stride `s`, consecutive output samples would + be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. + start_index, end_index: Data points earlier than `start_index` + or later than `end_index` will not be used in the output sequences. + This is useful to reserve part of the data for test or validation. + shuffle: Whether to shuffle output samples, + or instead draw them in chronological order. + reverse: Boolean: if `true`, timesteps in each output sample will be + in reverse chronological order. + batch_size: Number of timeseries samples in each batch + (except maybe the last one). + + Returns: + A [Sequence](/utils/#sequence) instance. + + Examples: + + ```python + from keras.preprocessing.sequence import TimeseriesGenerator + import numpy as np + + data = np.array([[i] for i in range(50)]) + targets = np.array([[i] for i in range(50)]) + + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, + batch_size=2) + assert len(data_gen) == 20 + + batch_0 = data_gen[0] + x, y = batch_0 + assert np.array_equal(x, + np.array([[[0], [2], [4], [6], [8]], + [[1], [3], [5], [7], [9]]])) + assert np.array_equal(y, + np.array([[10], [11]])) + ``` + """ + + def __init__(self, + data, + targets, + length, + sampling_rate=1, + stride=1, + start_index=0, + end_index=None, + shuffle=False, + reverse=False, + batch_size=128): + self.data = data + self.targets = targets + self.length = length + self.sampling_rate = sampling_rate + self.stride = stride + self.start_index = start_index + length + if end_index is None: + end_index = len(data) - 1 + self.end_index = end_index + self.shuffle = shuffle + self.reverse = reverse + self.batch_size = batch_size + + def __len__(self): + length = int( + np.ceil((self.end_index - self.start_index) / + (self.batch_size * self.stride))) + return length if length >= 0 else 0 + + def _empty_batch(self, num_rows): + samples_shape = [num_rows, self.length // self.sampling_rate] + samples_shape.extend(self.data.shape[1:]) + targets_shape = [num_rows] + targets_shape.extend(self.targets.shape[1:]) + return np.empty(samples_shape), np.empty(targets_shape) + + def __getitem__(self, index): + if self.shuffle: + rows = np.random.randint( + self.start_index, self.end_index, size=self.batch_size) + else: + i = self.start_index + self.batch_size * self.stride * index + rows = np.arange(i, min(i + self.batch_size * self.stride, + self.end_index), self.stride) + + samples, targets = self._empty_batch(len(rows)) + for j in range(len(rows)): + indices = range(rows[j] - self.length, rows[j], self.sampling_rate) + samples[j] = self.data[indices] + targets[j] = self.targets[rows[j]] + if self.reverse: + return samples[:, ::-1, ...], targets + return samples, targets diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py index 4529e6e94fc42661fb0474c1a827863ddb654776..b9bfdd000484665e8771f4bccef59738e5c26120 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py @@ -84,15 +84,91 @@ class TestSequence(test.TestCase): couples, labels = keras.preprocessing.sequence.skipgrams( np.arange(3), vocabulary_size=3) for couple in couples: - assert couple[0] in [0, 1, 2] and couple[1] in [0, 1, 2] + self.assertIn(couple[0], [0, 1, 2]) + self.assertIn(couple[1], [0, 1, 2]) # test window size and categorical labels couples, labels = keras.preprocessing.sequence.skipgrams( np.arange(5), vocabulary_size=5, window_size=1, categorical=True) for couple in couples: - assert couple[0] - couple[1] <= 3 + self.assertLessEqual(couple[0] - couple[1], 3) for l in labels: - assert len(l) == 2 + self.assertEqual(len(l), 2) + + def test_TimeseriesGenerator(self): + data = np.array([[i] for i in range(50)]) + targets = np.array([[i] for i in range(50)]) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, batch_size=2) + self.assertEqual(len(data_gen), 20) + self.assertAllClose(data_gen[0][0], + np.array([[[0], [2], [4], [6], [8]], [[1], [3], [5], + [7], [9]]])) + self.assertAllClose(data_gen[0][1], np.array([[10], [11]])) + self.assertAllClose(data_gen[1][0], + np.array([[[2], [4], [6], [8], [10]], [[3], [5], [7], + [9], [11]]])) + self.assertAllClose(data_gen[1][1], np.array([[12], [13]])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, reverse=True, batch_size=2) + self.assertEqual(len(data_gen), 20) + self.assertAllClose(data_gen[0][0], + np.array([[[8], [6], [4], [2], [0]], [[9], [7], [5], + [3], [1]]])) + self.assertAllClose(data_gen[0][1], np.array([[10], [11]])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, shuffle=True, batch_size=1) + batch = data_gen[0] + r = batch[1][0][0] + self.assertAllClose(batch[0], + np.array([[[r - 10], [r - 8], [r - 6], [r - 4], + [r - 2]]])) + self.assertAllClose(batch[1], np.array([ + [r], + ])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, stride=2, batch_size=2) + self.assertEqual(len(data_gen), 10) + self.assertAllClose(data_gen[1][0], + np.array([[[4], [6], [8], [10], [12]], [[6], [8], [10], + [12], [14]]])) + self.assertAllClose(data_gen[1][1], np.array([[14], [16]])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, + targets, + length=10, + sampling_rate=2, + start_index=10, + end_index=30, + batch_size=2) + self.assertEqual(len(data_gen), 5) + self.assertAllClose(data_gen[0][0], + np.array([[[10], [12], [14], [16], [18]], + [[11], [13], [15], [17], [19]]])) + self.assertAllClose(data_gen[0][1], np.array([[20], [21]])) + + data = np.array([np.random.random_sample((1, 2, 3, 4)) for i in range(50)]) + targets = np.array([np.random.random_sample((3, 2, 1)) for i in range(50)]) + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, + targets, + length=10, + sampling_rate=2, + start_index=10, + end_index=30, + batch_size=2) + + self.assertEqual(len(data_gen), 5) + self.assertAllClose(data_gen[0][0], + np.array( + [np.array(data[10:19:2]), + np.array(data[11:20:2])])) + self.assertAllClose(data_gen[0][1], np.array([targets[20], targets[21]])) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py index 1e3828ccf1e3bf9c443691e1c1da5697bedb4653..f652f318f3d6dae20b1113a50cd02930abb851af 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py @@ -91,6 +91,7 @@ def one_hot(text, text, n, hash_function=hash, filters=filters, lower=lower, split=split) +@tf_export('keras.preprocessing.text.hashing_trick') def hashing_trick(text, n, hash_function=None, @@ -187,21 +188,27 @@ class Tokenizer(object): self.document_count = 0 self.char_level = char_level self.oov_token = oov_token + self.index_docs = {} def fit_on_texts(self, texts): """Updates internal vocabulary based on a list of texts. + In the case where texts contains lists, we assume each entry of the lists + to be a token. + Required before using `texts_to_sequences` or `texts_to_matrix`. Arguments: texts: can be a list of strings, - or a generator of strings (for memory-efficiency) + a generator of strings (for memory-efficiency), + or a list of list of strings. """ - self.document_count = 0 for text in texts: self.document_count += 1 - seq = text if self.char_level else text_to_word_sequence( - text, self.filters, self.lower, self.split) + if self.char_level or isinstance(text, list): + seq = text + else: + seq = text_to_word_sequence(text, self.filters, self.lower, self.split) for w in seq: if w in self.word_counts: self.word_counts[w] += 1 @@ -226,7 +233,6 @@ class Tokenizer(object): if i is None: self.word_index[self.oov_token] = len(self.word_index) + 1 - self.index_docs = {} for w, c in list(self.word_docs.items()): self.index_docs[self.word_index[w]] = c @@ -240,8 +246,7 @@ class Tokenizer(object): sequences: A list of sequence. A "sequence" is a list of integer word indices. """ - self.document_count = len(sequences) - self.index_docs = {} + self.document_count += len(sequences) for seq in sequences: seq = set(seq) for i in seq: @@ -268,7 +273,11 @@ class Tokenizer(object): return res def texts_to_sequences_generator(self, texts): - """Transforms each text in texts in a sequence of integers. + """Transforms each text in `texts` in a sequence of integers. + + Each item in texts can also be a list, in which case we assume each item of + that list + to be a token. Only top "num_words" most frequent words will be taken into account. Only words known by the tokenizer will be taken into account. @@ -281,8 +290,10 @@ class Tokenizer(object): """ num_words = self.num_words for text in texts: - seq = text if self.char_level else text_to_word_sequence( - text, self.filters, self.lower, self.split) + if self.char_level or isinstance(text, list): + seq = text + else: + seq = text_to_word_sequence(text, self.filters, self.lower, self.split) vect = [] for w in seq: i = self.word_index.get(w) diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py index a934e331c4a14d9bd170258b6b6183e6a15bb561..c6a267e57e4e2dc04156483d1cf85a42a78eb395 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -80,17 +81,52 @@ class TestText(test.TestCase): x_train = ['This text has only known words'] x_test = ['This text has some unknown words'] # 2 OOVs: some, unknown - # Defalut, without OOV flag + # Default, without OOV flag tokenizer = keras.preprocessing.text.Tokenizer() tokenizer.fit_on_texts(x_train) x_test_seq = tokenizer.texts_to_sequences(x_test) - assert len(x_test_seq[0]) == 4 # discards 2 OOVs + self.assertEqual(len(x_test_seq[0]), 4) # discards 2 OOVs # With OOV feature tokenizer = keras.preprocessing.text.Tokenizer(oov_token='') tokenizer.fit_on_texts(x_train) x_test_seq = tokenizer.texts_to_sequences(x_test) - assert len(x_test_seq[0]) == 6 # OOVs marked in place + self.assertEqual(len(x_test_seq[0]), 6) # OOVs marked in place + + def test_sequential_fit(self): + texts = [ + 'The cat sat on the mat.', 'The dog sat on the log.', + 'Dogs and cats living together.' + ] + word_sequences = [['The', 'cat', 'is', 'sitting'], + ['The', 'dog', 'is', 'standing']] + tokenizer = keras.preprocessing.text.Tokenizer() + tokenizer.fit_on_texts(texts) + tokenizer.fit_on_texts(word_sequences) + + self.assertEqual(tokenizer.document_count, 5) + + tokenizer.texts_to_matrix(texts) + tokenizer.texts_to_matrix(word_sequences) + + def test_text_to_word_sequence(self): + text = 'hello! ? world!' + seq = keras.preprocessing.text.text_to_word_sequence(text) + self.assertEqual(seq, ['hello', 'world']) + + def test_text_to_word_sequence_unicode(self): + text = u'ali! veli? kırk dokuz elli' + seq = keras.preprocessing.text.text_to_word_sequence(text) + self.assertEqual(seq, [u'ali', u'veli', u'kırk', u'dokuz', u'elli']) + + def test_tokenizer_unicode(self): + texts = [ + u'ali veli kırk dokuz elli', u'ali veli kırk dokuz elli veli kırk dokuz' + ] + tokenizer = keras.preprocessing.text.Tokenizer(num_words=5) + tokenizer.fit_on_texts(texts) + + self.assertEqual(len(tokenizer.word_counts), 5) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/utils/__init__.py b/tensorflow/python/keras/_impl/keras/utils/__init__.py index 370ae0dd0f0d00059f1b0cc79459abe75c8ca494..0c9f19a0c8dcf3bf929e102b31679a03b27728f7 100644 --- a/tensorflow/python/keras/_impl/keras/utils/__init__.py +++ b/tensorflow/python/keras/_impl/keras/utils/__init__.py @@ -31,8 +31,8 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_ke from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary +from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model from tensorflow.python.keras._impl.keras.utils.np_utils import normalize from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.training_utils import multi_gpu_model from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py index e87c8f48ef0967d561db1ab841a669d783f9b1ec..4c49544c6a63c4e5a0b79d31b074ad352c512bfa 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py @@ -393,6 +393,16 @@ class Sequence(object): """ pass + def __iter__(self): + """Creates an infinite generator that iterate over the Sequence. + + Yields: + Sequence items. + """ + while True: + for item in (self[i] for i in range(len(self))): + yield item + # Global variables to be shared across processes _SHARED_SEQUENCES = {} @@ -400,6 +410,11 @@ _SHARED_SEQUENCES = {} _SEQUENCE_COUNTER = None +def init_pool(seqs): + global _SHARED_SEQUENCES + _SHARED_SEQUENCES = seqs + + def get_index(uid, i): """Get the value from the Sequence `uid` at index `i`. @@ -532,9 +547,11 @@ class OrderedEnqueuer(SequenceEnqueuer): (when full, workers could block on `put()`) """ if self.use_multiprocessing: - self.executor_fn = lambda: multiprocessing.Pool(workers) + self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda + workers, initializer=init_pool, initargs=(seqs,)) else: - self.executor_fn = lambda: ThreadPool(workers) + # We do not need the init since it's threads. + self.executor_fn = lambda _: ThreadPool(workers) self.workers = workers self.queue = queue.Queue(max_queue_size) self.stop_signal = threading.Event() @@ -557,7 +574,7 @@ class OrderedEnqueuer(SequenceEnqueuer): if self.shuffle: random.shuffle(sequence) - with closing(self.executor_fn()) as executor: + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: for i in sequence: if self.stop_signal.is_set(): return diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py index 462d600bf827768b0f2e6265aebdaad48e70fcd9..3bbe87f92d8f7eac27033344550ca65397eab986 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py @@ -490,8 +490,8 @@ def slice_arrays(arrays, start=None, stop=None): if arrays is None: return [None] if isinstance(start, list) and stop is not None: - raise ValueError('The stop argument has to be None if the value of start is' - 'a list.') + raise ValueError('The stop argument has to be None if the value of start ' + 'is a list.') elif isinstance(arrays, list): if hasattr(start, '__len__'): # hdf5 datasets only support list objects as indices @@ -509,3 +509,20 @@ def slice_arrays(arrays, start=None, stop=None): return arrays[start:stop] else: return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Arguments: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils.py b/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py similarity index 98% rename from tensorflow/python/keras/_impl/keras/utils/training_utils.py rename to tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py index ce7402e9d279278eaaf5aab58a3973eec6de8e99..231ace2a0b4a4f25cebf06a5216cf3d30aadc49b 100644 --- a/tensorflow/python/keras/_impl/keras/utils/training_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py @@ -125,7 +125,7 @@ def multi_gpu_model(model, gpus): if gpus <= 1: raise ValueError('For multi-gpu usage to be effective, ' 'call `multi_gpu_model` with `gpus >= 2`. ' - 'Received: `gpus=%d`' % gpus) + 'Received: `gpus=%s`' % gpus) num_gpus = gpus target_gpu_ids = range(num_gpus) @@ -136,7 +136,7 @@ def multi_gpu_model(model, gpus): ] for device in target_devices: if device not in available_devices: - raise ValueError('To call `multi_gpu_model` with `gpus=%d`, ' + raise ValueError('To call `multi_gpu_model` with `gpus=%s`, ' 'we expect the following devices to be available: %s. ' 'However this machine only has: %s. ' 'Try reducing `gpus`.' % (gpus, target_devices, diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils_test.py similarity index 64% rename from tensorflow/python/keras/_impl/keras/utils/training_utils_test.py rename to tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils_test.py index 12354c49ca72cddc0f395bcfcfabab18c1189227..0a38d6b5228fe791ce14adc7e37e0b7a6926fadf 100644 --- a/tensorflow/python/keras/_impl/keras/utils/training_utils_test.py +++ b/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils_test.py @@ -19,21 +19,34 @@ from __future__ import print_function import numpy as np - +from tensorflow.python import data from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test +def check_if_compatible_devices(gpus=2): + available_devices = [ + keras.utils.multi_gpu_utils._normalize_device_name(name) + for name in keras.utils.multi_gpu_utils._get_available_devices() + ] + if '/gpu:%d' % (gpus - 1) not in available_devices: + return False + return True + + class TestMultiGPUModel(test.TestCase): - def multi_gpu_test_simple_model(self): + def test_multi_gpu_test_simple_model(self): gpus = 2 num_samples = 1000 input_dim = 10 output_dim = 1 hidden_dim = 10 epochs = 2 - target_gpu_id = [0, 2, 4] + target_gpu_id = [0, 1] + + if not check_if_compatible_devices(gpus=gpus): + return with self.test_session(): model = keras.models.Sequential() @@ -47,12 +60,11 @@ class TestMultiGPUModel(test.TestCase): parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus) parallel_model.compile(loss='mse', optimizer='rmsprop') parallel_model.fit(x, y, epochs=epochs) - parallel_model = keras.utils.multi_gpu_model(model, gpus=target_gpu_id) parallel_model.compile(loss='mse', optimizer='rmsprop') parallel_model.fit(x, y, epochs=epochs) - def multi_gpu_test_multi_io_model(self): + def test_multi_gpu_test_multi_io_model(self): gpus = 2 num_samples = 1000 input_dim_a = 10 @@ -61,7 +73,10 @@ class TestMultiGPUModel(test.TestCase): output_dim_b = 2 hidden_dim = 10 epochs = 2 - target_gpu_id = [0, 2, 4] + target_gpu_id = [0, 1] + + if not check_if_compatible_devices(gpus=gpus): + return with self.test_session(): input_a = keras.Input((input_dim_a,)) @@ -86,7 +101,10 @@ class TestMultiGPUModel(test.TestCase): parallel_model.compile(loss='mse', optimizer='rmsprop') parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs) - def multi_gpu_test_invalid_devices(self): + def test_multi_gpu_test_invalid_devices(self): + if not check_if_compatible_devices(gpus=2): + return + with self.test_session(): input_shape = (1000, 10) model = keras.models.Sequential() @@ -115,3 +133,53 @@ class TestMultiGPUModel(test.TestCase): with self.assertRaises(ValueError): parallel_model = keras.utils.multi_gpu_model(model, gpus=[0]) parallel_model.fit(x, y, epochs=2) + + def test_nested_model_with_tensor_input(self): + gpus = 2 + input_dim = 10 + shape = (input_dim,) + num_samples = 16 + num_classes = 10 + + if not check_if_compatible_devices(gpus=gpus): + return + + with self.test_session(): + input_shape = (num_samples,) + shape + x_train = np.random.randint(0, 255, input_shape) + y_train = np.random.randint(0, num_classes, (input_shape[0],)) + keras.backend.set_learning_phase(True) + + y_train = keras.utils.to_categorical(y_train, num_classes) + + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + dataset = data.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.repeat() + dataset = dataset.batch(4) + iterator = dataset.make_one_shot_iterator() + + inputs, targets = iterator.get_next() + + input_tensor = keras.layers.Input(tensor=inputs) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(3, + input_shape=(input_dim,))) + model.add(keras.layers.Dense(num_classes)) + + output = model(input_tensor) + outer_model = keras.Model(input_tensor, output) + parallel_model = keras.utils.multi_gpu_model(outer_model, gpus=gpus) + + parallel_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.RMSprop(lr=0.0001, decay=1e-6), + metrics=['accuracy'], + target_tensors=[targets]) + parallel_model.fit(epochs=1, steps_per_epoch=3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py index 45c1b92075c50956fee004409e98898411e83d27..4761cece82c727e4962d0374f8efb80dfaeac3c6 100644 --- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py @@ -120,7 +120,7 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): layer_id = str(id(layer)) for i, node in enumerate(layer._inbound_nodes): node_key = layer.name + '_ib-' + str(i) - if node_key in model._container_nodes: + if node_key in model._network_nodes: # pylint: disable=protected-access for inbound_layer in node.inbound_layers: inbound_layer_id = str(id(inbound_layer)) layer_id = str(id(layer)) diff --git a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py b/tensorflow/python/keras/datasets/fashion_mnist/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..7f5ddecc4707334d52ebf4966f2ec6141cce0d46 100644 --- a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py +++ b/tensorflow/python/keras/datasets/fashion_mnist/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fashion-MNIST dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras._impl.keras.datasets.fashion_mnist import load_data + +del absolute_import +del division +del print_function diff --git a/tensorflow/python/keras/preprocessing/image/__init__.py b/tensorflow/python/keras/preprocessing/image/__init__.py index b96e7675527041d3952b049f5f431d3df36eea4c..6aba5fc8252e1acf604a89a4e66c2a7db080aa73 100644 --- a/tensorflow/python/keras/preprocessing/image/__init__.py +++ b/tensorflow/python/keras/preprocessing/image/__init__.py @@ -27,6 +27,7 @@ from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator from tensorflow.python.keras._impl.keras.preprocessing.image import load_img from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras._impl.keras.preprocessing.image import random_brightness from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear diff --git a/tensorflow/python/keras/preprocessing/sequence/__init__.py b/tensorflow/python/keras/preprocessing/sequence/__init__.py index 112f6af5e588bcb2e85fdbecea86f402742d44e7..b7a7149cc40654c878e3c0db1fc78d8912abf498 100644 --- a/tensorflow/python/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/python/keras/preprocessing/sequence/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras._impl.keras.preprocessing.sequence import TimeseriesGenerator del absolute_import del division diff --git a/tensorflow/python/keras/preprocessing/text/__init__.py b/tensorflow/python/keras/preprocessing/text/__init__.py index 5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f..000ad68a0c01e9067f8852836ba5d502deb3fcd4 100644 --- a/tensorflow/python/keras/preprocessing/text/__init__.py +++ b/tensorflow/python/keras/preprocessing/text/__init__.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.keras._impl.keras.preprocessing.text import hashing_trick from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py index 91cc8607274a80a14dd27a64274da7f8f0aafab1..2f74cf031d0520c8d874b7269c52e3b9e1b9931b 100644 --- a/tensorflow/python/keras/utils/__init__.py +++ b/tensorflow/python/keras/utils/__init__.py @@ -30,9 +30,9 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model from tensorflow.python.keras._impl.keras.utils.np_utils import normalize from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.training_utils import multi_gpu_model from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model del absolute_import diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d4ceb2e489c8a20d26eaf9d89b12992d2b8673d7..228d1c245248c972d7d504df10251e5e45076a2e 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -393,6 +393,7 @@ tf_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", ], + shard_count = 5, ) tf_py_test( @@ -408,6 +409,7 @@ tf_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", ], + shard_count = 5, ) tf_py_test( @@ -712,6 +714,18 @@ cuda_py_test( ], ) +tf_py_test( + name = "regex_replace_op_test", + size = "small", + srcs = ["regex_replace_op_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:string_ops", + ], +) + tf_py_test( name = "save_restore_ops_test", size = "small", @@ -1075,6 +1089,8 @@ cuda_py_test( tags = [ "no_windows", "noasan", + "noguitar", + "notap", ], ) @@ -1559,12 +1575,15 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", + "//tensorflow/python:random_ops", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], @@ -1892,7 +1911,7 @@ cuda_py_test( cuda_py_test( name = "softmax_op_test", - size = "small", + size = "medium", srcs = ["softmax_op_test.py"], additional_deps = [ "//third_party/py/numpy", @@ -2705,6 +2724,7 @@ cuda_py_test( "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", ], + data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"], shard_count = 20, ) @@ -2892,6 +2912,40 @@ tf_py_test( ], ) +tf_py_test( + name = "accumulate_n_test", + size = "small", + srcs = ["accumulate_n_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "accumulate_n_eager_test", + size = "small", + srcs = ["accumulate_n_eager_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py similarity index 72% rename from tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py rename to tensorflow/python/kernel_tests/accumulate_n_eager_test.py index 35974b9e21d2d7423777a95a99f51c9cb4b453b2..dc11b7deceb9040584aca1f629f4d003aef39428 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py @@ -12,48 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into -`ops.math_ops`. - -These test cases spefically exercise the `eager` APIs. They need to be in a -separate file from the remaining tests because eager mode is currently something -you can turn on but can't turn off for the lifetime of the current process.""" +"""Tests for new version of accumulate_n op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 - from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test - class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): - """Tests of the new, differentiable version of accumulate_n""" + """Tests of the new, differentiable version of accumulate_n.""" def testMinimalEagerMode(self): forty = constant_op.constant(40) two = constant_op.constant(2) - answer = av2.accumulate_n_v2([forty, two]) + answer = math_ops.accumulate_n([forty, two]) self.assertEqual(42, answer.numpy()) - def testFloat(self): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).numpy()) - self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).numpy()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).numpy()) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5).numpy()) def testGrad(self): np.random.seed(42) @@ -65,16 +58,14 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): ] def fn(first, second, third): - return av2.accumulate_n_v2([first, second, third]) + return math_ops.accumulate_n([first, second, third]) grad_fn = backprop.gradients_function(fn) grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) - self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 [elem.numpy() for elem in grad]) - if __name__ == "__main__": ops.enable_eager_execution() test.main() - diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/python/kernel_tests/accumulate_n_test.py similarity index 75% rename from tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py rename to tensorflow/python/kernel_tests/accumulate_n_test.py index 45962098e93acfac414396ddbeaa847701ff2b4b..b793906fac2cd12a5c0c663dd169000ad6067759 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_test.py @@ -12,42 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into -`ops.math_ops`.""" +"""Tests for new version of accumulate_n op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 - from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest class AccumulateNV2Test(test_util.TensorFlowTestCase): - """Tests of the new, differentiable version of accumulate_n""" + """Tests of the new, differentiable version of accumulate_n.""" def testFloat(self): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).eval()) - self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).eval()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).eval()) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5).eval()) def testInt(self): np.random.seed(54321) x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllEqual(sum(x), av2.accumulate_n_v2(tf_x).eval()) - self.assertAllEqual(x[0] * 6, av2.accumulate_n_v2([tf_x[0]] * 6).eval()) + self.assertAllEqual(sum(x), math_ops.accumulate_n(tf_x).eval()) + self.assertAllEqual(x[0] * 6, + math_ops.accumulate_n([tf_x[0]] * 6).eval()) + + def testUnknownShape(self): + with self.test_session(use_gpu=True): + x0 = array_ops.placeholder(dtype=dtypes_lib.int32, shape=[None]) + acc = math_ops.accumulate_n([x0, x0], shape=[None]) + self.assertAllEqual([2, 4], acc.eval(feed_dict={x0: [1, 2]})) def testGrad(self): np.random.seed(42) @@ -55,9 +62,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True) as sess: input_vars = [ variables.Variable(10.0 * np.random.random()) - for i in range(0, num_inputs) + for _ in range(0, num_inputs) ] - accum_n = av2.accumulate_n_v2(input_vars) + accum_n = math_ops.accumulate_n(input_vars) sess.run(variables.global_variables_initializer()) accum_n_grad = gradients.gradients(accum_n, input_vars) self.assertAllEqual( @@ -77,7 +84,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): ops.convert_to_tensor(x, dtype=dtypes_lib.float32) for x in random_arrays ] - tf_val = av2.accumulate_n_v2(random_tensors) + tf_val = math_ops.accumulate_n(random_tensors) np_val = random_arrays[0] for random_array in random_arrays[1:]: np_val += random_array @@ -86,7 +93,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): def testZeroArgs(self): with self.test_session(): with self.assertRaises(ValueError): - tf_val = av2.accumulate_n_v2([]) + tf_val = math_ops.accumulate_n([]) tf_val.eval() def testWrongShape(self): @@ -94,28 +101,28 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(0.2) b = variables.Variable(0.1) - tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[] + math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[] def testIncompatibleShapes(self): with self.test_session(): with self.assertRaises(ValueError): a = variables.Variable(np.array([0.1, 0.2])) b = variables.Variable(np.array([[0.3], [0.4]])) - tf_val = av2.accumulate_n_v2([a, b]) + math_ops.accumulate_n([a, b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32) + math_ops.accumulate_n([a, b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + math_ops.accumulate_n([a], tensor_dtype=np.int32) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 365cf72108de5a1e5e1eb47891a6ad64151add22..64c1760d5e72c8dd2b0b8adb09cc3612f85228b0 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -315,21 +315,39 @@ class ReverseV2Test(test_util.TensorFlowTestCase): self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1]) self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1]) + # This test covers the axis validation in the shape function + # (no eval()) + def testInvalidAxis(self): + x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + with self.assertRaisesRegexp(ValueError, + "is out of valid range"): + array_ops.reverse_v2(x_np, [-30]) + with self.assertRaisesRegexp(ValueError, + "is out of valid range"): + array_ops.reverse_v2(x_np, [2]) + with self.assertRaisesRegexp(ValueError, + "axis 0 specified more than once"): + array_ops.reverse_v2(x_np, [0, -2]) + # This is the version of reverse that uses axis indices rather than # bool tensors # TODO(b/32254538): Change this test to use array_ops.reverse + # + # Note: this test passes placeholder as constant axis is validated + # in shape function (see testInvalidAxis) def testInvalid(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + axis = array_ops.placeholder(dtypes.int32) with self.test_session(): with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is out of valid range"): - array_ops.reverse_v2(x_np, [-30]).eval() + array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [-30]}) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "is out of valid range"): - array_ops.reverse_v2(x_np, [2]).eval() + array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [2]}) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "axis 0 specified more than once"): - array_ops.reverse_v2(x_np, [0, -2]).eval() + array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [0, -2]}) def testReverse1DimAuto(self): for dtype in [ @@ -890,7 +908,7 @@ class StridedSliceAssignChecker(object): var = resource_variable_ops.ResourceVariable(self.x) else: var = variables.Variable(self.x) - sess.run(variables.initialize_variables([var])) + sess.run(variables.variables_initializer([var])) val = sess.run(var[index].assign(value)) # val_copy is used to check that tf.assign works equivalently to the # assign method above. @@ -1024,6 +1042,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): [[True, False, False, False, False], [True, True, True, False, False], [True, True, False, False, False]]) + @test_util.enable_c_shapes def testOneDimensionalDtypeWithoutMaxlen(self): with self.test_session(): # test dtype and default maxlen: @@ -1037,6 +1056,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): res.eval(), [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]) + @test_util.enable_c_shapes def testOneDimensionalWithoutMaxlen(self): with self.test_session(): res = array_ops.sequence_mask( @@ -1051,6 +1071,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): [True, False, False, False], [True, True, True, True]]) + @test_util.enable_c_shapes def testTwoDimensional(self): with self.test_session(): res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5) @@ -1223,7 +1244,7 @@ class SnapshotOpTest(test_util.TensorFlowTestCase): for dtype in [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]: with self.test_session(use_gpu=True): x = constant_op.constant([0, 1, 2, 3], dtype=dtype) - y = gen_array_ops._snapshot(x) + y = gen_array_ops.snapshot(x) self.assertAllEqual(y.eval(), [0, 1, 2, 3]) diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py index 2d1b3d9b7e836591646a2d0e59742bf6139446d1..0ef08581c9f931b991ef0c1218dc503345e248c2 100644 --- a/tensorflow/python/kernel_tests/atrous_convolution_test.py +++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py @@ -83,14 +83,14 @@ class AtrousConvolutionTest(test.TestCase): checks = [] def add_check(check, *args, **kwargs): - if context.in_eager_mode(): + if context.executing_eagerly(): args_val, kwargs_val = self.evaluate([args, kwargs]) check(*args_val, **kwargs_val) else: checks.append((check, args, kwargs)) yield add_check - if context.in_graph_mode(): + if not context.executing_eagerly(): all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks]) for (check, _, _), (args, kwargs) in zip(checks, all_values): check(*args, **kwargs) diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py index 405651e8ae97fbc5eefd4aba0a95a99ff8fd8c26..987a6ffcd4b18eb5857ff9e82206de7f6ebe8a27 100644 --- a/tensorflow/python/kernel_tests/basic_gpu_test.py +++ b/tensorflow/python/kernel_tests/basic_gpu_test.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables -from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args +from tensorflow.python.ops.gen_array_ops import broadcast_gradient_args from tensorflow.python.platform import test @@ -157,7 +157,7 @@ class BroadcastSimpleTest(test.TestCase): def _GetGradientArgs(self, xs, ys): with self.test_session(use_gpu=True) as sess: - return sess.run(_broadcast_gradient_args(xs, ys)) + return sess.run(broadcast_gradient_args(xs, ys)) def testBroadcast(self): r0, r1 = self._GetGradientArgs([2, 3, 5], [1]) diff --git a/tensorflow/python/kernel_tests/batchtospace_op_test.py b/tensorflow/python/kernel_tests/batchtospace_op_test.py index 0c802476a0e788aff3de84ab736fa8f1de5daab4..6143cd3baa6317fc512d80f94b494710037d4082 100644 --- a/tensorflow/python/kernel_tests/batchtospace_op_test.py +++ b/tensorflow/python/kernel_tests/batchtospace_op_test.py @@ -44,7 +44,7 @@ class CppOpImpl(object): @staticmethod def batch_to_space(*args, **kwargs): - return gen_array_ops._batch_to_space(*args, **kwargs) + return gen_array_ops.batch_to_space(*args, **kwargs) class BatchToSpaceDepthToSpace(test.TestCase, PythonOpImpl): diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py index 9e512346053a4c3af089170f47313606c4a307c2..3305e55c05bd03d31c46fd333db09dbab9a5d09c 100644 --- a/tensorflow/python/kernel_tests/bcast_ops_test.py +++ b/tensorflow/python/kernel_tests/bcast_ops_test.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.ops.gen_array_ops import _broadcast_args -from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args +from tensorflow.python.ops.gen_array_ops import broadcast_args +from tensorflow.python.ops.gen_array_ops import broadcast_gradient_args from tensorflow.python.platform import test @@ -29,11 +29,11 @@ class BcastOpsTest(test.TestCase): def _GetBroadcastShape(self, xs, ys): with self.test_session() as sess: - return sess.run(_broadcast_args(xs, ys)) + return sess.run(broadcast_args(xs, ys)) def _GetGradientArgs(self, xs, ys): with self.test_session() as sess: - return sess.run(_broadcast_gradient_args(xs, ys)) + return sess.run(broadcast_gradient_args(xs, ys)) def testBasic(self): r = self._GetBroadcastShape([2, 3, 5], [1]) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 2e94603a3f3d4ca9074320cfb4e9bf06b6640e82..5a83ec8d302b4c26aef7abfa7465eb9fd0cca019 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -102,17 +102,15 @@ class AssertEqualTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): check_ops.assert_equal(static_big, static_small, message="fail") - # Dynamic check - if context.in_graph_mode(): - with self.test_session(): - small = array_ops.placeholder(dtypes.int32, name="small") - big = array_ops.placeholder(dtypes.int32, name="big") - with ops.control_dependencies( - [check_ops.assert_equal( - big, small, message="fail")]): - out = array_ops.identity(small) - with self.assertRaisesOpError("fail.*big.*small"): - out.eval(feed_dict={small: [1, 2], big: [3, 4]}) + def test_raises_when_greater_dynamic(self): + with self.test_session(): + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") + with ops.control_dependencies( + [check_ops.assert_equal(big, small, message="fail")]): + out = array_ops.identity(small) + with self.assertRaisesOpError("fail.*big.*small"): + out.eval(feed_dict={small: [1, 2], big: [3, 4]}) def test_error_message_eager(self): expected_error_msg_full = r"""big does not equal small @@ -182,15 +180,14 @@ First 2 elements of y: with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): check_ops.assert_equal(static_big, static_small, message="fail") - # Dynamic check - if context.in_graph_mode(): - with self.test_session(): - small = array_ops.placeholder(dtypes.int32, name="small") - big = array_ops.placeholder(dtypes.int32, name="big") - with ops.control_dependencies([check_ops.assert_equal(small, big)]): - out = array_ops.identity(small) - with self.assertRaisesOpError("small.*big"): - out.eval(feed_dict={small: [3, 1], big: [4, 2]}) + def test_raises_when_less_dynamic(self): + with self.test_session(): + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") + with ops.control_dependencies([check_ops.assert_equal(small, big)]): + out = array_ops.identity(small) + with self.assertRaisesOpError("small.*big"): + out.eval(feed_dict={small: [3, 1], big: [4, 2]}) @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): @@ -215,6 +212,12 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() + def test_raises_when_not_equal_and_broadcastable_shapes(self): + cond = constant_op.constant([True, False], name="small") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(cond, False, message="fail") + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py index a786d0a47e569f71812086fb93c21dc12660a2a5..7f147ba53a71539962f424158731e359724f664f 100644 --- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py +++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py @@ -50,7 +50,7 @@ class GenerateVocabRemappingTest(test.TestCase): def test_generate_remapping_with_no_vocab_changes(self): """Tests where vocab does not change at all.""" - remapping, num_present = gen_checkpoint_ops._generate_vocab_remapping( + remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=self.old_vocab_file, old_vocab_file=self.old_vocab_file, num_new_vocab=3, @@ -63,7 +63,7 @@ class GenerateVocabRemappingTest(test.TestCase): def test_generate_remapping_with_shifted_vocab(self): """Tests where vocab is the same, but shifted / ordered differently.""" - remapping, num_present = gen_checkpoint_ops._generate_vocab_remapping( + remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=self.new_vocab_file, old_vocab_file=self.old_vocab_file, num_new_vocab=3, @@ -76,7 +76,7 @@ class GenerateVocabRemappingTest(test.TestCase): def test_generate_remapping_with_offset(self): """Tests offset and num_new_vocab logic.""" - remapping, num_present = gen_checkpoint_ops._generate_vocab_remapping( + remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=self.new_vocab_file, old_vocab_file=self.old_vocab_file, num_new_vocab=1, @@ -89,7 +89,7 @@ class GenerateVocabRemappingTest(test.TestCase): def test_generate_remapping_with_old_vocab_size(self): """Tests where old_vocab_size is specified.""" - remapping, num_present = gen_checkpoint_ops._generate_vocab_remapping( + remapping, num_present = gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=self.new_vocab_file, old_vocab_file=self.old_vocab_file, num_new_vocab=3, @@ -132,7 +132,7 @@ class LoadAndRemapMatrixTest(test.TestCase): # No column remapping, new weight matrix has second row, then first row. row_remapping = [1, 0] - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=row_remapping, @@ -147,7 +147,7 @@ class LoadAndRemapMatrixTest(test.TestCase): # No row remapping, new weight matrix has third col, then first col. row_remapping = list(range(self.old_num_rows)) col_remapping = [2, 0] - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=row_remapping, @@ -162,7 +162,7 @@ class LoadAndRemapMatrixTest(test.TestCase): # Both row and column remappings. row_remapping = [1, 0, 4] col_remapping = [1, 15] - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=row_remapping, @@ -177,7 +177,7 @@ class LoadAndRemapMatrixTest(test.TestCase): def test_load_and_remap_with_init(self): """Tests the op's load and remap where there are missing entries.""" init_val = 42 - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[2, -1, 0], @@ -196,7 +196,7 @@ class LoadAndRemapMatrixTest(test.TestCase): """Tests when all the rows are missing and need to be initialized.""" num_rows = 7 initializing_values = [42] * num_rows * self.old_num_cols - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[-1] * num_rows, @@ -214,7 +214,7 @@ class LoadAndRemapMatrixTest(test.TestCase): num_rows = 7 num_cols = 4 initializing_values = [42] * num_rows * num_cols - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[-1] * num_rows, @@ -235,7 +235,7 @@ class LoadAndRemapMatrixTest(test.TestCase): invalid_remapping = [1, 0, 0, 0, 1, 2] # Invalid row remapping. - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=invalid_remapping, @@ -247,7 +247,7 @@ class LoadAndRemapMatrixTest(test.TestCase): remapped_matrix.eval() # Invalid column remapping. - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=list(range(self.old_num_rows)), @@ -260,7 +260,7 @@ class LoadAndRemapMatrixTest(test.TestCase): def test_load_and_remap_incorrect_initializing_values(self): """Tests that errors are raised with incorrect number of init values.""" - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[2, -1, 0], @@ -275,7 +275,7 @@ class LoadAndRemapMatrixTest(test.TestCase): with self.test_session(), self.assertRaises(errors.InvalidArgumentError): remapped_matrix.eval() - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=[self.bundle_file], old_tensor_name=self.old_tensor_name, row_remapping=[2, -1, 0], @@ -314,7 +314,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase): num_rows, num_cols = np_value.shape # Tests loading the entire tensor (except reversed). - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=ckpt_path, old_tensor_name=old_tensor_name, # Simply reverses the rows of the matrix. @@ -332,7 +332,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase): self.assertGreater(num_rows, 2) prefix_rows = 2 suffix_rows = 3 - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=ckpt_path, old_tensor_name=old_tensor_name, # Reverses the rows of the matrix, then prepends and appends @@ -353,7 +353,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase): # Tests when everything is taken from initializing_values. new_rows = 7 initializing_values = [42] * new_rows * num_cols - remapped_matrix = gen_checkpoint_ops._load_and_remap_matrix( + remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=ckpt_path, old_tensor_name=old_tensor_name, # Nothing is loaded from the old tensor. diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index 127bc6bb20ae6b415da94672de68cc4b8ceaa287..c22934ce47543ab11b6a5b9acde2e2ec3aec9da7 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -526,7 +526,7 @@ class ConcatOpTest(test.TestCase): with self.test_session(use_gpu=True): t1 = [] t2 = [] - output = gen_array_ops._concat_v2([t1, t2], 0).eval() + output = gen_array_ops.concat_v2([t1, t2], 0).eval() self.assertFalse(output) # Checks that output is empty def testConcatInvalidAxis(self): @@ -534,20 +534,20 @@ class ConcatOpTest(test.TestCase): with self.test_session(use_gpu=True): t1 = [1] t2 = [2] - gen_array_ops._concat_v2([t1, t2], 1).eval() + gen_array_ops.concat_v2([t1, t2], 1).eval() def testConcatNegativeAxis(self): with self.test_session(use_gpu=True): t1 = [[1, 2, 3], [4, 5, 6]] t2 = [[7, 8, 9], [10, 11, 12]] - c = gen_array_ops._concat_v2([t1, t2], -2) + c = gen_array_ops.concat_v2([t1, t2], -2) self.assertEqual([4, 3], c.get_shape().as_list()) output = c.eval() self.assertAllEqual([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], output) - c = gen_array_ops._concat_v2([t1, t2], -1) + c = gen_array_ops.concat_v2([t1, t2], -1) self.assertEqual([2, 6], c.get_shape().as_list()) output = c.eval() self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output) @@ -606,6 +606,17 @@ class ConcatOpTest(test.TestCase): inp_tensors_placeholders, -2, output_shape=[2, 3], gather_indexes=[2, 0], feed_dict=feed_dict) + def testConcatAxisType(self): + for dtype in [dtypes.int32, dtypes.int64]: + with self.test_session(use_gpu=True): + t1 = [[1, 2, 3], [4, 5, 6]] + t2 = [[7, 8, 9], [10, 11, 12]] + + c = gen_array_ops.concat_v2([t1, t2], + constant_op.constant(1, dtype=dtype)) + self.assertEqual([2, 6], c.get_shape().as_list()) + output = c.eval() + self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output) class ConcatOffsetTest(test.TestCase): @@ -615,7 +626,7 @@ class ConcatOffsetTest(test.TestCase): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) ans = sess.run(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) @@ -624,7 +635,7 @@ class ConcatOffsetTest(test.TestCase): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([[2, 3, 5]], dtypes.int32) s1 = constant_op.constant([[2, 7, 5]], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1]) + off = gen_array_ops.concat_offset(cdim, [s0, s1]) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, r"should be a vector"): sess.run(off) @@ -634,7 +645,7 @@ class ConcatOffsetTest(test.TestCase): cdim = constant_op.constant(4, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1]) + off = gen_array_ops.concat_offset(cdim, [s0, s1]) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, r"Concat dim is out of range: 4 vs. 3"): sess.run(off) @@ -644,7 +655,7 @@ class ConcatOffsetTest(test.TestCase): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1]) + off = gen_array_ops.concat_offset(cdim, [s0, s1]) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, r"should contain 3 elem"): sess.run(off) @@ -654,7 +665,7 @@ class ConcatOffsetTest(test.TestCase): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 10], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1]) + off = gen_array_ops.concat_offset(cdim, [s0, s1]) with self.assertRaisesRegexp( errors_impl.InvalidArgumentError, r"All dimensions except 1 must match. Input 1 has shape \[2 7 10\] " @@ -667,7 +678,7 @@ class ConcatOffsetTest(test.TestCase): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) ans = sess.run(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) @@ -675,7 +686,7 @@ class ConcatOffsetTest(test.TestCase): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([1, 3, 5], dtypes.int32) s2 = constant_op.constant([3, 3, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) ans = sess.run(off) self.assertAllEqual(ans, [[0, 0, 0], [2, 0, 0], [3, 0, 0]]) diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 16e56349c45dd56a335f6f881826d975e24bd110..18796f709566f022258806ce46cc706e8fe34354 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import logging_ops @@ -180,6 +181,11 @@ class ConstantTest(test.TestCase): shape=[2, 3, 5]) self.assertEqual(c.get_shape(), [2, 3, 5]) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerMemory(self): + """Tests PyObject refs are managed correctly when executing eagerly.""" + constant_op.constant([[1.]]) + def testImplicitShapeNumPy(self): with ops.Graph().as_default(): c = constant_op.constant( @@ -875,7 +881,7 @@ versions { class PlaceholderWithDefaultTest(test.TestCase): def testFullShape(self): - with self.test_session(): + with self.test_session(force_gpu=test_util.is_gpu_available()): p = array_ops.placeholder_with_default([[2, 2], [2, 2]], shape=[2, 2]) a = array_ops.identity(p) self.assertAllEqual([[2, 2], [2, 2]], a.eval()) @@ -886,7 +892,7 @@ class PlaceholderWithDefaultTest(test.TestCase): a.eval(feed_dict={p: [[6, 6, 6], [6, 6, 6]]}) def testPartialShape(self): - with self.test_session(): + with self.test_session(force_gpu=test_util.is_gpu_available()): p = array_ops.placeholder_with_default([1, 2, 3], shape=[None]) a = array_ops.identity(p) self.assertAllEqual([1, 2, 3], a.eval()) @@ -896,7 +902,7 @@ class PlaceholderWithDefaultTest(test.TestCase): a.eval(feed_dict={p: [[2, 2], [2, 2]]}) def testNoShape(self): - with self.test_session(): + with self.test_session(force_gpu=test_util.is_gpu_available()): p = array_ops.placeholder_with_default([17], shape=None) a = array_ops.identity(p) self.assertAllEqual([17], a.eval()) @@ -905,11 +911,12 @@ class PlaceholderWithDefaultTest(test.TestCase): [[3, 3], [3, 3]], a.eval(feed_dict={p: [[3, 3], [3, 3]]})) def testGradient(self): - with self.test_session(): + with self.test_session(force_gpu=test_util.is_gpu_available()): x = array_ops.placeholder(dtypes_lib.float32, [5, 7]) y = array_ops.placeholder_with_default(x, None) err = gradient_checker.compute_gradient_error(x, [5, 7], y, [5, 7]) self.assertLess(err, 1e-3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 58f38650eb526e98edf35b2425e0e9e1296ab353..75f8644f694c4cebb7dbdac4599244dda427bc05 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -144,7 +144,7 @@ class ControlFlowTest(test.TestCase): enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) nine = constant_op.constant(9) - enter_nine = gen_control_flow_ops._enter(nine, "foo_1") + enter_nine = gen_control_flow_ops.enter(nine, "foo_1") op = state_ops.assign(enter_v, enter_nine) v2 = control_flow_ops.with_dependencies([op], enter_v) v3 = control_flow_ops.exit(v2) @@ -164,9 +164,9 @@ class ControlFlowTest(test.TestCase): def testEnterMulExit(self): with self.test_session(): data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - enter_data = gen_control_flow_ops._enter(data, "foo_1", False) + enter_data = gen_control_flow_ops.enter(data, "foo_1", False) five = constant_op.constant(5) - enter_five = gen_control_flow_ops._enter(five, "foo_1", False) + enter_five = gen_control_flow_ops.enter(five, "foo_1", False) mul_op = math_ops.multiply(enter_data, enter_five) exit_op = control_flow_ops.exit(mul_op) @@ -178,12 +178,12 @@ class ControlFlowTest(test.TestCase): v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) # If is_constant=True, the shape information should be propagated. - enter_v_constant = gen_control_flow_ops._enter( + enter_v_constant = gen_control_flow_ops.enter( v, "frame1", is_constant=True) self.assertEqual(enter_v_constant.shape, [2]) # Otherwise, the shape should be unknown. - enter_v_non_constant = gen_control_flow_ops._enter( + enter_v_non_constant = gen_control_flow_ops.enter( v, "frame2", is_constant=False) self.assertEqual(enter_v_non_constant.shape, None) @@ -257,8 +257,8 @@ class ControlFlowTest(test.TestCase): false = ops.convert_to_tensor(False) n = constant_op.constant(10) - enter_false = gen_control_flow_ops._enter(false, "foo_1", False) - enter_n = gen_control_flow_ops._enter(n, "foo_1", False) + enter_false = gen_control_flow_ops.enter(false, "foo_1", False) + enter_n = gen_control_flow_ops.enter(n, "foo_1", False) merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] switch_n = control_flow_ops.switch(merge_n, enter_false) @@ -275,9 +275,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = gen_control_flow_ops._enter(zero, "foo", False) - enter_one = gen_control_flow_ops._enter(one, "foo", True) - enter_n = gen_control_flow_ops._enter(n, "foo", True) + enter_i = gen_control_flow_ops.enter(zero, "foo", False) + enter_one = gen_control_flow_ops.enter(one, "foo", True) + enter_n = gen_control_flow_ops.enter(n, "foo", True) with ops.device(test.gpu_device_name()): merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -301,9 +301,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = gen_control_flow_ops._enter(zero, "foo", False) - enter_one = gen_control_flow_ops._enter(one, "foo", True) - enter_n = gen_control_flow_ops._enter(n, "foo", True) + enter_i = gen_control_flow_ops.enter(zero, "foo", False) + enter_one = gen_control_flow_ops.enter(one, "foo", True) + enter_n = gen_control_flow_ops.enter(n, "foo", True) merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -324,8 +324,8 @@ class ControlFlowTest(test.TestCase): def testDifferentFrame(self): with self.test_session(): data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = gen_control_flow_ops._enter(data, "foo_1", False) - enter_2 = gen_control_flow_ops._enter(data, "foo_2", False) + enter_1 = gen_control_flow_ops.enter(data, "foo_1", False) + enter_2 = gen_control_flow_ops.enter(data, "foo_2", False) res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError("has inputs from different frames"): res.eval(feed_dict={data: 1.0}) @@ -552,7 +552,7 @@ class ControlFlowTest(test.TestCase): def testCondRef(self): with self.test_session(): - x = gen_state_ops._variable( + x = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="x", @@ -580,7 +580,7 @@ class ControlFlowTest(test.TestCase): def testUninitializedRefIdentity(self): with self.test_session() as sess: - v = gen_state_ops._variable( + v = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="v", @@ -591,10 +591,10 @@ class ControlFlowTest(test.TestCase): # Both v_f and v_t are uninitialized references. However, an actual use # of the reference in the 'true' branch in the 'tf.identity' op will # not 'fire' when v is uninitialized, so this is a valid construction. - # This test tests that _ref_identity allows uninitialized ref as input + # This test tests that ref_identity allows uninitialized ref as input # so that this construction is allowed. - v_f_op = gen_array_ops._ref_identity(v_f) - v_t_op = gen_array_ops._ref_identity(v_t) + v_f_op = gen_array_ops.ref_identity(v_f) + v_t_op = gen_array_ops.ref_identity(v_t) with ops.control_dependencies([v_f_op]): assign_v = state_ops.assign(v, [1.0]) with ops.control_dependencies([v_t_op]): @@ -633,7 +633,8 @@ class ControlFlowTest(test.TestCase): sess.run(r) def testCondGrad_1(self): - with self.test_session(): + graph = ops.Graph() + with graph.as_default(): x = constant_op.constant(10.0, name="x") pred = math_ops.less(1, 2) fn1 = lambda: array_ops.identity(x) @@ -641,8 +642,14 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.cond(pred, fn1, fn2) grad = gradients_impl.gradients(r, [x])[0] - result = grad.eval() - self.assertAllEqual(1.0, result) + with self.test_session(): + self.assertAllEqual(1.0, grad.eval()) + # The gradients computation creates a tensor with zeros by broadcasting a + # zeros constant to the required shape. Verify that the zero constant + # feeding into the fill is dominated by a Switch. + zero = graph.get_operation_by_name("gradients/zeros/Const") + self.assertEqual(len(zero.control_inputs), 1) + self.assertEqual(zero.control_inputs[0].type, "Switch") def testCondGrad_2(self): with self.test_session(): @@ -744,7 +751,7 @@ class ControlFlowTest(test.TestCase): def b(i, x): self.assertEqual(x.dtype, dtypes.int32_ref) - return (i + 1, gen_array_ops._ref_identity(x)) + return (i + 1, gen_array_ops.ref_identity(x)) r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5) @@ -1620,7 +1627,7 @@ class ControlFlowTest(test.TestCase): def testWhileStack_1(self): with self.test_session(): - s = gen_data_flow_ops._stack_v2(-1, dtypes.int32, stack_name="foo") + s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo") i = constant_op.constant(0) def c(i): @@ -1629,7 +1636,7 @@ class ControlFlowTest(test.TestCase): def b(i): ni = math_ops.add(i, 1) ni = control_flow_ops.with_dependencies( - [gen_data_flow_ops._stack_push_v2(s, i)], ni) + [gen_data_flow_ops.stack_push_v2(s, i)], ni) return ni r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) @@ -1641,7 +1648,7 @@ class ControlFlowTest(test.TestCase): def b1(i, x): ni = math_ops.subtract(i, 1) - nx = x + gen_data_flow_ops._stack_pop_v2(s, dtypes.int32) + nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32) return [ni, nx] _, rx = control_flow_ops.while_loop( @@ -2205,12 +2212,9 @@ class ControlFlowTest(test.TestCase): self.assertEqual(x.dtype, dtypes.int32_ref) - # pylint: disable=protected-access def body(i, x): self.assertEqual(x.dtype, dtypes.int32_ref) - return [i + 1, gen_array_ops._ref_identity(x)] - - # pylint: enable=protected-access + return [i + 1, gen_array_ops.ref_identity(x)] r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5) diff --git a/tensorflow/python/kernel_tests/control_flow_util_test.py b/tensorflow/python/kernel_tests/control_flow_util_test.py index 23185eaeece0d56fd83ecdf9e02c778712420465..39e96f74b0461da0cf499e303b30a4a41aae4899 100644 --- a/tensorflow/python/kernel_tests/control_flow_util_test.py +++ b/tensorflow/python/kernel_tests/control_flow_util_test.py @@ -41,17 +41,17 @@ class ControlFlowUtilTest(test.TestCase): self.assertFalse(control_flow_util.IsSwitch(test_ops.int_output().op)) def testIsLoopEnter(self): - enter = gen_control_flow_ops._enter(1, frame_name="name").op + enter = gen_control_flow_ops.enter(1, frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(enter)) - ref_enter = gen_control_flow_ops._ref_enter(test_ops.ref_output(), - frame_name="name").op + ref_enter = gen_control_flow_ops.ref_enter(test_ops.ref_output(), + frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(ref_enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(ref_enter)) - const_enter = gen_control_flow_ops._enter(1, frame_name="name", - is_constant=True).op + const_enter = gen_control_flow_ops.enter(1, frame_name="name", + is_constant=True).op self.assertTrue(control_flow_util.IsLoopEnter(const_enter)) self.assertTrue(control_flow_util.IsLoopConstantEnter(const_enter)) diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index f4fe01f868da25660171c614bbf84390aead3ade..a291bef0ad6f16184ff29f665457a53b77447d54 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -159,11 +159,11 @@ class Conv2DTest(test.TestCase): def _DtypesToTest(self, use_gpu): if use_gpu and not test_util.CudaSupportsHalfMatMulAndConv(): - return [dtypes.float32] + return [dtypes.float32, dtypes.float64] else: # It is important that float32 comes before float16 here, # as we will be using its gradients as reference for fp16 gradients. - return [dtypes.float32, dtypes.float16] + return [dtypes.float32, dtypes.float16, dtypes.float64] def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, dilations, strides, padding, data_format, dtype, use_gpu): @@ -970,7 +970,7 @@ class Conv2DTest(test.TestCase): self.assertArrayNear(value_2.flatten(), value.flatten(), err) def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 6, 1], @@ -984,7 +984,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 2, 3, 1], @@ -998,7 +998,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2DEmptyBackpropFilterDilation1x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 2, 3, 1], @@ -1012,7 +1012,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 4, 3], @@ -1026,7 +1026,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 3, 1], @@ -1040,7 +1040,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 3, 6, 1], @@ -1054,7 +1054,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 2, 3, 1], @@ -1068,7 +1068,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2DEmptyBackpropInputDilation1x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[0, 2, 3, 1], @@ -1082,7 +1082,7 @@ class Conv2DTest(test.TestCase): err=1e-5) def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): # The GPU version of this test is not very stable. So adjusting the # error threshold to 1e-4. @@ -1098,7 +1098,7 @@ class Conv2DTest(test.TestCase): err=1e-4) def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 3, 3, 1], diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 0d9b46c30dbbed20dd940e0427fbf6f6d5415106..8db0bb6f0dc495e7be2cd717787acf87156f42af 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -495,11 +495,11 @@ class UnaryOpTest(test.TestCase): dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4), (np.complex128, 1e-6)] op_range = [ - (gen_math_ops._reciprocal_grad, [-2, 2]), - (gen_math_ops._rsqrt_grad, [0.1, 3]), - (gen_math_ops._sigmoid_grad, [-2, 2]), - (gen_math_ops._sqrt_grad, [0.1, 3]), - (gen_math_ops._tanh_grad, [-2, 2]), + (gen_math_ops.reciprocal_grad, [-2, 2]), + (gen_math_ops.rsqrt_grad, [0.1, 3]), + (gen_math_ops.sigmoid_grad, [-2, 2]), + (gen_math_ops.sqrt_grad, [0.1, 3]), + (gen_math_ops.tanh_grad, [-2, 2]), ] def rand(dtype): diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py index 7df2366954f3a6f3f37aef447479ba67c263025f..f0beabb4e20e4ec0a2fc7a487bf2541d19568927 100644 --- a/tensorflow/python/kernel_tests/depthtospace_op_test.py +++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py @@ -35,8 +35,8 @@ from tensorflow.python.platform import tf_logging class DepthToSpaceTest(test.TestCase): - def _testOne(self, inputs, block_size, outputs): - input_nhwc = math_ops.to_float(inputs) + def _testOne(self, inputs, block_size, outputs, dtype=dtypes.float32): + input_nhwc = math_ops.cast(inputs, dtype) with self.test_session(use_gpu=False): # test NHWC (default) on CPU x_tf = array_ops.depth_to_space(input_nhwc, block_size) @@ -59,6 +59,12 @@ class DepthToSpaceTest(test.TestCase): x_out = [[[[1], [2]], [[3], [4]]]] self._testOne(x_np, block_size, x_out) + def testBasicFloat16(self): + x_np = [[[[1, 2, 3, 4]]]] + block_size = 2 + x_out = [[[[1], [2]], [[3], [4]]]] + self._testOne(x_np, block_size, x_out, dtype=dtypes.float16) + # Tests for larger input dimensions. To make sure elements are # correctly ordered spatially. def testBlockSize2(self): @@ -90,6 +96,24 @@ class DepthToSpaceTest(test.TestCase): x_out = [batch_output_elt(i) for i in range(batch_size)] self._testOne(x_np, block_size, x_out) + def testBatchSize0(self): + block_size = 2 + batch_size = 0 + input_nhwc = array_ops.ones([batch_size, 2, 3, 12]) + x_out = array_ops.ones([batch_size, 4, 6, 3]) + + with self.test_session(use_gpu=False): + # test NHWC (default) on CPU + x_tf = array_ops.depth_to_space(input_nhwc, block_size) + self.assertAllEqual(x_tf.shape, x_out.shape) + x_tf.eval() + if test.is_gpu_available(): + with self.test_session(use_gpu=True): + # test NHWC (default) on GPU + x_tf = array_ops.depth_to_space(input_nhwc, block_size) + self.assertAllEqual(x_tf.shape, x_out.shape) + x_tf.eval() + # Tests for different width and height. def testNonSquare(self): x_np = [[[[1, 10, 2, 20, 3, 30, 4, 40]], diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py index 222038b22ef3c766efd14fd9b1c9044a0b6e9125..a52b2c0dc32c26ecd5ef08aa3f8678f0006cd4fe 100644 --- a/tensorflow/python/kernel_tests/determinant_op_test.py +++ b/tensorflow/python/kernel_tests/determinant_op_test.py @@ -65,7 +65,7 @@ class DeterminantOpTest(test.TestCase): self._compareDeterminantBase(matrix_x, linalg_ops.matrix_determinant(matrix_x)) self._compareLogDeterminantBase( - matrix_x, gen_linalg_ops._log_matrix_determinant(matrix_x)) + matrix_x, gen_linalg_ops.log_matrix_determinant(matrix_x)) def testBasic(self): # 2x2 matrices diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py index fedbf9e696923a34968e7a907e4099c520d1447b..5e8937ad2c36afb2b1ddb58ffb238a45e09e4b30 100644 --- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py @@ -326,6 +326,18 @@ class DynamicPartitionTest(test.TestCase): with self.assertRaises(ValueError): data_flow_ops.dynamic_partition(data, indices, num_partitions=4) + # see https://github.com/tensorflow/tensorflow/issues/17106 + def testCUBBug(self): + x = constant_op.constant(np.random.randn(3072)) + inds = [0]*189 + [1]*184 + [2]*184 + [3]*191 + [4]*192 + [5]*195 + [6]*195 + inds += [7]*195 + [8]*188 + [9]*195 + [10]*188 + [11]*202 + [12]*194 + inds += [13]*194 + [14]*194 + [15]*192 + self.assertEqual(len(inds), x.shape[0]) + partitioned = data_flow_ops.dynamic_partition(x, inds, 16) + with self.test_session() as sess: + res = sess.run(partitioned) + self.assertEqual(res[-1].shape[0], 192) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py index feec9934e459590bb1dd0bc5c7cf40013d3d8b88..faac7d8365dfaa1b6b32f8fe66a76c3694aa0d5b 100644 --- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py +++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py @@ -347,7 +347,7 @@ class FractionalAvgPoolGradTest(test.TestCase): Two types of tests for FractionalAvgPoolGrad. 1) Test fractional_avg_pool_grad() directly. - This type of test relies on gen_nn_ops._avg_pool_grad() returns the + This type of test relies on gen_nn_ops.avg_pool_grad() returns the correct result. For example: * input_tensor_shape = (1, 10, 10, 1) * window_size = (1, 2, 2, 1) @@ -404,13 +404,13 @@ class FractionalAvgPoolGradTest(test.TestCase): num_elements *= dim_size output_backprop = (self._PRNG.rand(num_elements) * 1000).reshape(output_data.shape) - input_backprop_tensor = gen_nn_ops._avg_pool_grad( + input_backprop_tensor = gen_nn_ops.avg_pool_grad( input_tensor.get_shape(), output_backprop, window_size, stride_size, padding) input_backprop = input_backprop_tensor.eval() row_seq = list(range(0, num_rows + 1, row_window_size)) col_seq = list(range(0, num_cols + 1, col_window_size)) - fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad( + fap_input_backprop_tensor = gen_nn_ops.fractional_avg_pool_grad( input_tensor.get_shape(), output_backprop, row_seq, @@ -443,7 +443,7 @@ class FractionalAvgPoolGradTest(test.TestCase): num_elements *= dim_size output_backprop = (self._PRNG.rand(num_elements) * 1000).reshape(output_data.shape) - input_backprop_tensor = gen_nn_ops._avg_pool_grad( + input_backprop_tensor = gen_nn_ops.avg_pool_grad( input_tensor.get_shape(), output_backprop, window_size, stride_size, padding) input_backprop = input_backprop_tensor.eval() @@ -451,7 +451,7 @@ class FractionalAvgPoolGradTest(test.TestCase): col_seq = list(range(0, num_cols, col_window_size - 1)) row_seq[-1] += 1 col_seq[-1] += 1 - fap_input_backprop_tensor = gen_nn_ops._fractional_avg_pool_grad( + fap_input_backprop_tensor = gen_nn_ops.fractional_avg_pool_grad( input_tensor.get_shape(), output_backprop, row_seq, diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py index 5983ae7759dbf3eb2db9867def829ce8dbeb4b73..6477c9ebc4c35fcc5963b27a0f5c50624a73fa09 100644 --- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py +++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py @@ -318,7 +318,7 @@ class FractionalMaxPoolGradTest(test.TestCase): Two types of tests for FractionalMaxPoolGrad. 1) Test fractional_max_pool_grad() directly. - This type of test relies on gen_nn_ops._max_pool_grad() returns the correct + This type of test relies on gen_nn_ops.max_pool_grad() returns the correct result. For example: * input_tensor_shape = (1, 10, 10, 1) * window_size = (1, 2, 2, 1) @@ -384,16 +384,13 @@ class FractionalMaxPoolGradTest(test.TestCase): stride_size, padding) output_data = output_tensor.eval() output_backprop = self._PRNG.randint(100, size=output_data.shape) - input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor, - output_tensor, - output_backprop, - window_size, - stride_size, - padding) + input_backprop_tensor = gen_nn_ops.max_pool_grad( + input_tensor, output_tensor, output_backprop, window_size, + stride_size, padding) input_backprop = input_backprop_tensor.eval() row_seq = list(range(0, num_rows + 1, row_window_size)) col_seq = list(range(0, num_cols + 1, col_window_size)) - fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad( + fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad( input_tensor, output_tensor, output_backprop, @@ -422,18 +419,15 @@ class FractionalMaxPoolGradTest(test.TestCase): stride_size, padding) output_data = output_tensor.eval() output_backprop = self._PRNG.randint(100, size=output_data.shape) - input_backprop_tensor = gen_nn_ops._max_pool_grad(input_tensor, - output_tensor, - output_backprop, - window_size, - stride_size, - padding) + input_backprop_tensor = gen_nn_ops.max_pool_grad( + input_tensor, output_tensor, output_backprop, window_size, + stride_size, padding) input_backprop = input_backprop_tensor.eval() row_seq = list(range(0, num_rows, row_window_size - 1)) col_seq = list(range(0, num_cols, col_window_size - 1)) row_seq[-1] += 1 col_seq[-1] += 1 - fmp_input_backprop_tensor = gen_nn_ops._fractional_max_pool_grad( + fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad( input_tensor, output_tensor, output_backprop, @@ -591,7 +585,7 @@ class FractionalMaxPoolGradTest(test.TestCase): output_tensor = constant_op.constant( output_data_not_overlapping, shape=output_size) grad = constant_op.constant(output_backprop, shape=output_size) - r = gen_nn_ops._fractional_max_pool_grad( + r = gen_nn_ops.fractional_max_pool_grad( input_tensor, output_tensor, grad, @@ -606,7 +600,7 @@ class FractionalMaxPoolGradTest(test.TestCase): # Test when overlapping is True output_tensor = constant_op.constant( output_data_overlapping, shape=output_size) - r = gen_nn_ops._fractional_max_pool_grad( + r = gen_nn_ops.fractional_max_pool_grad( input_tensor, output_tensor, grad, row_seq, col_seq, overlapping=True) input_backprop_overlapping = r.eval() self.assertShapeEqual( diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py index 2cfe420bd49ec44815d1386bd873b234d8710e9d..49fb76d5b41de18ed3ba2187e85cb288e7344c38 100644 --- a/tensorflow/python/kernel_tests/identity_op_py_test.py +++ b/tensorflow/python/kernel_tests/identity_op_py_test.py @@ -65,7 +65,7 @@ class IdentityOpTest(test.TestCase): constant_op.constant( [[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32)) self.assertEquals(shape, tensor.get_shape()) - self.assertEquals(shape, gen_array_ops._ref_identity(tensor).get_shape()) + self.assertEquals(shape, gen_array_ops.ref_identity(tensor).get_shape()) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 19a7d2f9d51fff46ee817ad03ef62383f6727791..c1755985ee85c62005c8d3d5fb916859193aa5f3 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -25,10 +25,13 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -571,6 +574,82 @@ class OrthogonalInitializerTest(test.TestCase): np.dot(t, t.T), np.eye(t.shape[0]), rtol=tol, atol=tol) +class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): + + def testInitializerIdentical(self): + for dtype in [dtypes.float32, dtypes.float64]: + init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype) + init2 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype) + self.assertTrue(identicaltest(self, init1, init2, (3, 3, 10, 10))) + + def testInitializerDifferent(self): + for dtype in [dtypes.float32, dtypes.float64]: + init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype) + init2 = init_ops.convolutional_delta_orthogonal(seed=2, dtype=dtype) + self.assertFalse(identicaltest(self, init1, init2, (3, 3, 10, 10))) + + def testDuplicatedInitializer(self): + init = init_ops.convolutional_delta_orthogonal() + self.assertFalse(duplicated_initializer(self, init, 1, (3, 3, 10, 10))) + + def testInvalidDataType(self): + self.assertRaises( + ValueError, init_ops.convolutional_delta_orthogonal, + dtype=dtypes.string) + + def testInvalidShape(self): + init1 = init_ops.convolutional_delta_orthogonal() + with self.test_session(graph=ops.Graph(), use_gpu=True): + self.assertRaises(ValueError, init1, shape=[3, 3, 6, 5]) + + def testGain(self): + shape = (3, 3, 10, 10) + for dtype in [dtypes.float32, dtypes.float64]: + init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype) + init2 = init_ops.convolutional_delta_orthogonal(gain=3.14, + seed=1, dtype=dtype) + with self.test_session(graph=ops.Graph(), use_gpu=True): + t1 = init1(shape).eval() + with self.test_session(graph=ops.Graph(), use_gpu=True): + t2 = init2(shape).eval() + return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15) + + def testShapesValues(self): + for dtype in [dtypes.float32]: + for kernel_size in [[3], [8], [3, 5], [2, 4], [3, 3, 3], [2, 2, 2]]: + tol = 1e-2 + # Check orthogonality by computing the 2-norms of the inputs and ouputs. + if len(kernel_size) == 1: + shape = [4, 32, 64] + convolution = convolutional.conv1d + elif len(kernel_size) == 2: + convolution = convolutional.conv2d + shape = [4, 32, 32, 64] + else: + shape = [4, 16, 16, 16, 64] + convolution = convolutional.conv3d + inputs = random_ops.random_normal(shape, dtype=dtype) + inputs_2norm = linalg_ops.norm(inputs) + outputs = convolution( + inputs, padding="same", filters=128, + kernel_size=kernel_size, use_bias=False, + kernel_initializer=init_ops.convolutional_delta_orthogonal( + gain=3.14)) + outputs_shape = shape[0:-1] + [128] + outputs_2norm = linalg_ops.norm(outputs) + my_ops = variables.global_variables_initializer() + with self.test_session(use_gpu=True) as sess: + sess.run(my_ops) + # Check the shape of the outputs + t = outputs.eval() + self.assertAllEqual(t.shape, outputs_shape) + # Check isometry of the delta-orthogonal kernel. + self.assertAllClose( + sess.run(inputs_2norm)/np.sqrt(np.prod(shape)), + sess.run(outputs_2norm)/(np.sqrt(np.prod(shape))*np.sqrt(3.14)), + rtol=tol, atol=tol) + + class IdentityInitializerTest(test.TestCase): def testInvalidDataType(self): diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py index 4d79365dbefc74fe8412b65ec089fb2af4255aea..f96b9ccdaacae7d8e0552ed3d74ce53808fed963 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py @@ -44,9 +44,9 @@ class SquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() - shape = list(shape) + shape = list(build_info.shape) # Either 1 or 2 matrices, depending. num_operators = rng.randint(low=1, high=3) @@ -148,9 +148,9 @@ class NonSquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() - shape = list(shape) + shape = list(build_info.shape) # Test only the case of 2 matrices. # The Square test uses either 1 or 2, so we have tested the case of 1 matrix diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py index 8cb9f9e6213cda8daae7b629fc31d4721fd48fa7..0a0e31c716ecfa10ed93cff92fa908a240f8495e 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py @@ -34,7 +34,8 @@ class LinearOperatorDiagTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) diag = linear_operator_test_util.random_sign_uniform( shape[:-1], minval=1., maxval=2., dtype=dtype) if use_placeholder: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py index 50d6f524e9ad75715d7a57348638fdfeee667f40..b3da623b5e8d8c99c6777e75e2d49f24dab1c96b 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py @@ -36,11 +36,11 @@ class SquareLinearOperatorFullMatrixTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): - shape = list(shape) + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) - matrix = linear_operator_test_util.random_positive_definite_matrix(shape, - dtype) + matrix = linear_operator_test_util.random_positive_definite_matrix( + shape, dtype) if use_placeholder: matrix_ph = array_ops.placeholder(dtype=dtype) @@ -136,8 +136,8 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( def _dtypes_to_test(self): return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): - shape = list(shape) + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype, force_well_conditioned=True) @@ -210,7 +210,8 @@ class NonSquareLinearOperatorFullMatrixTest( linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) matrix = linear_operator_test_util.random_normal(shape, dtype=dtype) if use_placeholder: matrix_ph = array_ops.placeholder(dtype=dtype) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py index 6d635707683f4500919073a4f43c320a44b65018..59f63f949e96991193412d3574603e58a75cb6e5 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py @@ -43,8 +43,8 @@ class LinearOperatorIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): - shape = list(shape) + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) assert shape[-1] == shape[-2] batch_shape = shape[:-2] @@ -261,8 +261,8 @@ class LinearOperatorScaledIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): - shape = list(shape) + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) assert shape[-1] == shape[-2] batch_shape = shape[:-2] diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index d3a47da946b12277c4c390a4a320d7c91ed81b32..8095f6419ef0d9543339cf1f4ee9cd4783f852b9 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -55,16 +55,22 @@ class BaseLinearOperatorLowRankUpdatetest(object): return [dtypes.float32, dtypes.float64] @property - def _shapes_to_test(self): + def _operator_build_infos(self): + build_info = linear_operator_test_util.OperatorBuildInfo # Previously we had a (2, 10, 10) shape at the end. We did this to test the # inversion and determinant lemmas on not-tiny matrices, since these are # known to have stability issues. This resulted in test timeouts, so this # shape has been removed, but rest assured, the tests did pass. - return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)] - - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + return [ + build_info((0, 0)), + build_info((1, 1)), + build_info((1, 3, 3)), + build_info((3, 4, 4)), + build_info((2, 1, 4, 4))] + + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): # Recall A = L + UDV^H - shape = list(shape) + shape = list(build_info.shape) diag_shape = shape[:-1] k = shape[-2] // 2 + 1 u_perturbation_shape = shape[:-1] + [k] diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py index db3918f9983c5b7d05fa4ba3bc85b26a485f2f00..a57d2f085e089fb913f09fdd9b07cf13aa7f3c35 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py @@ -38,7 +38,8 @@ class LinearOperatorLowerTriangularTest( # matrix_triangular_solve. return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) # Upper triangle will be nonzero, but ignored. # Use a diagonal that ensures this matrix is well conditioned. tril = linear_operator_test_util.random_tril_matrix( diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 1577b7bc8021a326eb720bdf059b8d1c568c0cc1..dbbed39c727f01ed1fae271375575c690958c7d8 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -30,7 +30,9 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -123,6 +125,78 @@ class ListOpsTest(test_util.TensorFlowTestCase): l_cpu, element_dtype=dtypes.float32)[1], 2.0) + def testGraphStack(self): + with context.graph_mode(), self.test_session(): + tl = list_ops.empty_tensor_list( + element_shape=constant_op.constant([1], dtype=dtypes.int32), + element_dtype=dtypes.int32) + tl = list_ops.tensor_list_push_back(tl, [1]) + self.assertAllEqual( + list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(), + [[1]]) + + def testGraphStackInLoop(self): + with context.graph_mode(), self.test_session(): + t1 = list_ops.empty_tensor_list( + element_shape=constant_op.constant([], dtype=dtypes.int32), + element_dtype=dtypes.int32) + i = constant_op.constant(0, dtype=dtypes.int32) + + def body(i, t1): + t1 = list_ops.tensor_list_push_back(t1, i) + i += 1 + return i, t1 + + i, t1 = control_flow_ops.while_loop(lambda i, t1: math_ops.less(i, 4), + body, [i, t1]) + s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32).eval() + self.assertAllEqual(s1, [0, 1, 2, 3]) + + def testGraphStackSwitchDtype(self): + with context.graph_mode(), self.test_session(): + list_ = list_ops.empty_tensor_list( + element_shape=constant_op.constant([], dtype=dtypes.int32), + element_dtype=dtypes.int32) + m = constant_op.constant([1, 2, 3], dtype=dtypes.float32) + + def body(list_, m): + list_ = control_flow_ops.cond( + math_ops.equal(list_ops.tensor_list_length(list_), 0), + lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: list_) + list_ = list_ops.tensor_list_push_back(list_, m) + return list_, m + + for _ in range(2): + list_, m = body(list_, m) + + s1 = list_ops.tensor_list_stack( + list_, element_dtype=dtypes.float32).eval() + np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + self.assertAllEqual(s1, np_s1) + + def testGraphStackInLoopSwitchDtype(self): + with context.graph_mode(), self.test_session(): + t1 = list_ops.empty_tensor_list( + element_shape=constant_op.constant([], dtype=dtypes.int32), + element_dtype=dtypes.int32) + i = constant_op.constant(0, dtype=dtypes.float32) + m = constant_op.constant([1, 2, 3], dtype=dtypes.float32) + + def body(i, m, t1): + t1 = control_flow_ops.cond( + math_ops.equal(list_ops.tensor_list_length(t1), 0), + lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: t1) + + t1 = list_ops.tensor_list_push_back(t1, m * i) + i += 1.0 + return i, m, t1 + + i, m, t1 = control_flow_ops.while_loop( + lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1]) + s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32).eval() + np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)]) + self.assertAllEqual(s1, np_s1) + def testSerialize(self): # pylint: disable=g-import-not-at-top try: diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py index 6203a412d7faec4fe9f6179141301579b5900291..a0c66c77d8850d3144678870983730537a253556 100644 --- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py @@ -48,7 +48,7 @@ class ExponentialOpTest(test.TestCase): def _verifyExponential(self, x, np_type): inp = x.astype(np_type) with self.test_session(use_gpu=True): - tf_ans = gen_linalg_ops._matrix_exponential(inp) + tf_ans = gen_linalg_ops.matrix_exponential(inp) if x.size == 0: np_ans = np.empty(x.shape, dtype=np_type) else: @@ -116,13 +116,13 @@ class ExponentialOpTest(test.TestCase): # When the exponential of a non-square matrix is attempted we should return # an error with self.assertRaises(ValueError): - gen_linalg_ops._matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]])) + gen_linalg_ops.matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]])) def testWrongDimensions(self): # The input to the exponential should be at least a 2-dimensional tensor. tensor3 = constant_op.constant([1., 2.]) with self.assertRaises(ValueError): - gen_linalg_ops._matrix_exponential(tensor3) + gen_linalg_ops.matrix_exponential(tensor3) def testEmpty(self): self._verifyExponentialReal(np.empty([0, 2, 2])) @@ -143,8 +143,8 @@ class ExponentialOpTest(test.TestCase): with self.test_session(use_gpu=True) as sess: matrix1 = random_ops.random_normal([5, 5], seed=42) matrix2 = random_ops.random_normal([5, 5], seed=42) - expm1 = gen_linalg_ops._matrix_exponential(matrix1) - expm2 = gen_linalg_ops._matrix_exponential(matrix2) + expm1 = gen_linalg_ops.matrix_exponential(matrix1) + expm2 = gen_linalg_ops.matrix_exponential(matrix2) expm = sess.run([expm1, expm2]) self.assertAllEqual(expm[0], expm[1]) @@ -180,7 +180,7 @@ class MatrixExponentialBenchmark(test.Benchmark): session.Session() as sess, \ ops.device("/cpu:0"): matrix = self._GenerateMatrix(shape) - expm = gen_linalg_ops._matrix_exponential(matrix) + expm = gen_linalg_ops.matrix_exponential(matrix) variables.global_variables_initializer().run() self.run_op_benchmark( sess, diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py index 18ed59828c15f5ad21fe054cd6e40991c02bb356..24edc4f59fe6dd84da6732036eb53e2ad367bd06 100644 --- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py @@ -39,8 +39,8 @@ class LogarithmOpTest(test.TestCase): inp = x.astype(np_type) with self.test_session(use_gpu=True): # Verify that expm(logm(A)) == A. - tf_ans = gen_linalg_ops._matrix_exponential( - gen_linalg_ops._matrix_logarithm(inp)) + tf_ans = gen_linalg_ops.matrix_exponential( + gen_linalg_ops.matrix_logarithm(inp)) out = tf_ans.eval() self.assertAllClose(inp, out, rtol=1e-4, atol=1e-3) @@ -85,14 +85,14 @@ class LogarithmOpTest(test.TestCase): # When the logarithm of a non-square matrix is attempted we should return # an error with self.assertRaises(ValueError): - gen_linalg_ops._matrix_logarithm( + gen_linalg_ops.matrix_logarithm( np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.complex64)) def testWrongDimensions(self): # The input to the logarithm should be at least a 2-dimensional tensor. tensor3 = constant_op.constant([1., 2.], dtype=dtypes.complex64) with self.assertRaises(ValueError): - gen_linalg_ops._matrix_logarithm(tensor3) + gen_linalg_ops.matrix_logarithm(tensor3) def testEmpty(self): self._verifyLogarithmComplex(np.empty([0, 2, 2], dtype=np.complex64)) @@ -115,8 +115,8 @@ class LogarithmOpTest(test.TestCase): random_ops.random_normal([5, 5], seed=42), dtypes.complex64) matrix2 = math_ops.cast( random_ops.random_normal([5, 5], seed=42), dtypes.complex64) - logm1 = gen_linalg_ops._matrix_logarithm(matrix1) - logm2 = gen_linalg_ops._matrix_logarithm(matrix2) + logm1 = gen_linalg_ops.matrix_logarithm(matrix1) + logm2 = gen_linalg_ops.matrix_logarithm(matrix2) logm = sess.run([logm1, logm2]) self.assertAllEqual(logm[0], logm[1]) @@ -152,7 +152,7 @@ class MatrixLogarithmBenchmark(test.Benchmark): session.Session() as sess, \ ops.device("/cpu:0"): matrix = self._GenerateMatrix(shape) - logm = gen_linalg_ops._matrix_logarithm(matrix) + logm = gen_linalg_ops.matrix_logarithm(matrix) variables.global_variables_initializer().run() self.run_op_benchmark( sess, diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index 59e7afa2dcb1e02ed9c66e5cf75753f96552b4e0..ad802f7e1f72f6cbc3dda1ca98e46e6da4e5110a 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -1132,9 +1132,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1146,9 +1146,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1160,26 +1160,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3) - - def testFourthAUCPRSpecialCase(self): - # Create the labels and data. - labels = np.array([ - 0, 0, 0, 0, 0, 0, 0, 1, 0, 1]) - predictions = np.array([ - 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35]) - - with self.test_session() as sess: - auc, _ = metrics.auc( - labels, predictions, curve='PR', num_thresholds=11) - - sess.run(variables.local_variables_initializer()) - # Since this is only approximate, we can't expect a 6 digits match. - # Although with higher number of samples/thresholds we should see the - # accuracy improving - self.assertAlmostEqual(0.0, auc.eval(), delta=0.001) + self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1205,16 +1188,16 @@ class AUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) - def testRecallOneAndPrecisionOne(self): + def testRecallOneAndPrecisionOneGivesOnePRAUC(self): with self.test_session() as sess: predictions = array_ops.ones([4], dtype=dtypes_lib.float32) labels = array_ops.ones([4]) auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.5, sess.run(update_op), 6) + self.assertAlmostEqual(1, sess.run(update_op), 6) - self.assertAlmostEqual(0.5, auc.eval(), 6) + self.assertAlmostEqual(1, auc.eval(), 6) def np_auc(self, predictions, labels, weights): """Computes the AUC explicitly using Numpy. diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py index 2c766e364073fc8c92156f19d08753367982e7fc..361853448ce2c8477af6920257c58c1eba0fa952 100644 --- a/tensorflow/python/kernel_tests/pad_op_test.py +++ b/tensorflow/python/kernel_tests/pad_op_test.py @@ -215,13 +215,13 @@ class PadOpTest(test.TestCase): def testIntTypes(self): # TODO(touts): Figure out why the padding tests do not work on GPU # for int types and rank > 2. - for t in [np.int32, np.int64]: + for t in [np.int8, np.int32, np.int64]: self._testAll( np.random.randint(-100, 100, (4, 4, 3)).astype(t), [[1, 0], [2, 3], [0, 2]], 0) self._testAll( np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), - [[0, 0], [0, 0], [0, 0], [0, 0]], -1234) + [[0, 0], [0, 0], [0, 0], [0, 0]], -123) def testFloatTypes(self): for t in [np.float32, np.float64]: @@ -238,6 +238,29 @@ class PadOpTest(test.TestCase): x = np.random.rand(3, 2, 1, 1).astype(t) self._testAll(x + 1j * x, [[0, 0], [0, 0], [0, 0], [0, 0]], 0 + 0j) + def testString(self): + # Numpy does not support padding strings so we compare padding manually. + x = ops.convert_to_tensor([["Hello", "World"], + ["Goodnight", "Moon"]]) + + constant = array_ops.pad(x, [[1, 0], [0, 1]], mode="CONSTANT", + constant_values="PAD") + reflect = array_ops.pad(x, [[1, 0], [0, 1]], mode="REFLECT", + constant_values="PAD") + symmetric = array_ops.pad(x, [[1, 0], [0, 1]], mode="SYMMETRIC", + constant_values="PAD") + with self.test_session(use_gpu=True): + self.assertAllEqual([[b"PAD", b"PAD", b"PAD"], + [b"Hello", b"World", b"PAD"], + [b"Goodnight", b"Moon", b"PAD"]], constant.eval()) + self.assertAllEqual([[b"Goodnight", b"Moon", b"Goodnight"], + [b"Hello", b"World", b"Hello"], + [b"Goodnight", b"Moon", b"Goodnight"]], + reflect.eval()) + self.assertAllEqual([[b"Hello", b"World", b"World"], + [b"Hello", b"World", b"World"], + [b"Goodnight", b"Moon", b"Moon"]], symmetric.eval()) + def testShapeFunctionEdgeCases(self): # Unknown paddings shape. inp = constant_op.constant(0.0, shape=[4, 4, 4, 4]) @@ -313,5 +336,32 @@ class PadOpTest(test.TestCase): self.assertAllEqual(inp, out) self.assertShapeEqual(inp, tf_val) + def testCollapseAdjacentNonPaddedDimensions(self): + # pyformat: disable + paddings_values = [[[0, 0], [0, 0], [0, 0], [0, 1]], + [[0, 0], [2, 3], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0], [0, 0]]] + # pyformat: enable + for paddings_value in paddings_values: + for dtype in [dtypes.float32, dtypes.int32]: + inp = constant_op.constant(1, shape=[8, 28, 28, 3], dtype=dtype) + paddings = constant_op.constant(paddings_value, dtype=dtypes.int32) + padded = array_ops.pad(inp, paddings) + middle = array_ops.slice(padded, [row[0] for row in paddings_value], + [dim.value for dim in inp.shape.dims]) + left = array_ops.slice(padded, [0, 0, 0, 0], + [row[0] for row in paddings_value]) + right = array_ops.slice( + padded, + [paddings_value[i][0] + inp.shape.dims[i].value for i in range(4)], + [-1, -1, -1, -1]) + with self.test_session(use_gpu=True): + self.assertAllEqual(inp.eval(), middle.eval()) + self.assertAllEqual( + np.zeros([row[0] for row in paddings_value]), left.eval()) + self.assertAllEqual( + np.zeros([row[1] for row in paddings_value]), right.eval()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 4466beeec96509b3761e34d885276e1510c62d10..ed44a1a4d16a94d3aa75a50bf059e33326757c4d 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variables import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -122,8 +123,9 @@ class PoolingTest(test.TestCase): if input_sizes[-1] % 4 != 0: tf_logging.info("Skipping test for depth %d", input_sizes[-1]) return - tf_logging.info("Running %s test. %r %r %d %r %r %r", data_format, v2, - input_sizes, total_size, pool_func, ksize, strides) + tf_logging.info("Running %s test. %r %r %d %r %r %r %s", data_format, v2, + input_sizes, total_size, pool_func, ksize, strides, + data_type) # Initializes the input tensor with array containing incrementing # numbers from 1, wrapping round to -127 after 127 to support int8. x = [((f + 128) % 255) - 127 for f in range(total_size)] @@ -192,6 +194,8 @@ class PoolingTest(test.TestCase): self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, data_format, dtypes.float32, expected, use_gpu, v2) + self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, + data_format, dtypes.float64, expected, use_gpu, v2) if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv(): self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, @@ -405,7 +409,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 3, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], @@ -427,7 +431,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 2, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], @@ -456,7 +460,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 2, 2, 1], ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1], @@ -485,7 +489,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1], @@ -494,7 +498,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu, v2=v2) self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1], @@ -519,7 +523,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 4, 4, 4], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], @@ -554,7 +558,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 8, 8, 8], ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], @@ -565,7 +569,7 @@ class PoolingTest(test.TestCase): def _testMaxPoolEmptyInput(self, use_gpu): self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[0, 8, 8, 8], ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], @@ -600,7 +604,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 1, 1, 10], ksize=[1, 1, 1, 2], strides=[1, 1, 1, 2], @@ -626,7 +630,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 2, 2, 6], ksize=[1, 1, 1, 3], strides=[1, 1, 1, 3], @@ -648,7 +652,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 7, 7, 1], ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1], @@ -689,7 +693,7 @@ class PoolingTest(test.TestCase): for v2 in [True, False]: self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 3, 3, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], @@ -699,7 +703,7 @@ class PoolingTest(test.TestCase): v2=v2) self._VerifyValues( - gen_nn_ops._max_pool_v2, + gen_nn_ops.max_pool_v2, input_sizes=[1, 4, 4, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], @@ -731,7 +735,8 @@ class PoolingTest(test.TestCase): [1, 1, 1, 3], "evenly divide") if test.is_gpu_available(): with self.test_session(use_gpu=True): - t = constant_op.constant(1.0, shape=[1, 2, 2, 4]) + t = variables.Variable(np.ones([1, 2, 2, 4])) + variables.global_variables_initializer().run() with self.assertRaisesOpError("for CPU devices"): nn_ops.max_pool( t, ksize=[1, 1, 1, 2], strides=[1, 1, 1, 2], @@ -764,8 +769,8 @@ class PoolingTest(test.TestCase): _, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding) argmax = argmax_op.eval() grad_in = constant_op.constant(tensor_output, shape=output_shape) - out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax, - ksize, strides, padding) + out_op = gen_nn_ops.max_pool_grad_with_argmax(t, grad_in, argmax, ksize, + strides, padding) gpu_val = out_op.eval() self.assertShapeEqual(gpu_val, out_op) with self.test_session(use_gpu=False): @@ -773,8 +778,8 @@ class PoolingTest(test.TestCase): out_op = nn_ops.max_pool(t, ksize, strides, padding) orig_out = out_op.eval() grad_in = constant_op.constant(tensor_output, shape=output_shape) - out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize, strides, - padding) + out_op = gen_nn_ops.max_pool_grad(t, orig_out, grad_in, ksize, strides, + padding) cpu_val = out_op.eval() self.assertShapeEqual(cpu_val, out_op) # The CPU version accumulates its gradient on fp16, so it's less @@ -793,7 +798,7 @@ class PoolingTest(test.TestCase): _, argmax_op = nn_ops.max_pool_with_argmax(t, ksize, strides, padding) argmax = argmax_op.eval() grad_in = constant_op.constant(tensor_input, shape=input_shape) - out_op = gen_nn_ops._max_pool_grad_grad_with_argmax( + out_op = gen_nn_ops.max_pool_grad_grad_with_argmax( t, grad_in, argmax, ksize, strides, padding) gpu_val = out_op.eval() self.assertShapeEqual(gpu_val, out_op) @@ -802,8 +807,8 @@ class PoolingTest(test.TestCase): out_op = nn_ops.max_pool(t, ksize, strides, padding) orig_out = out_op.eval() grad_in = constant_op.constant(tensor_input, shape=input_shape) - out_op = gen_nn_ops._max_pool_grad_grad(t, orig_out, grad_in, ksize, - strides, padding) + out_op = gen_nn_ops.max_pool_grad_grad(t, orig_out, grad_in, ksize, + strides, padding) cpu_val = out_op.eval() self.assertShapeEqual(cpu_val, out_op) # The CPU version accumulates its gradient on fp16, so it's less @@ -842,7 +847,7 @@ class PoolingTest(test.TestCase): t = constant_op.constant(tensor_input, shape=[1, 2, 2, 1]) argmax = constant_op.constant( tensor_argmax, shape=[1, 2, 2, 1], dtype=dtypes.int64) - out_op = gen_nn_ops._max_pool_grad_with_argmax( + out_op = gen_nn_ops.max_pool_grad_with_argmax( orig_in, t, argmax, @@ -865,7 +870,7 @@ class PoolingTest(test.TestCase): t = constant_op.constant(tensor_input, shape=[1, 3, 3, 1]) argmax = constant_op.constant( tensor_argmax, shape=[1, 2, 2, 1], dtype=dtypes.int64) - out_op = gen_nn_ops._max_pool_grad_grad_with_argmax( + out_op = gen_nn_ops.max_pool_grad_grad_with_argmax( orig_in, t, argmax, @@ -1029,7 +1034,7 @@ class PoolingTest(test.TestCase): self.assertLess(err, err_tolerance) def _testMaxPoolGradValidPadding1_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[1, 3, 3, 1], @@ -1043,7 +1048,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradValidPadding2_1_6(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 6, 6, 3], @@ -1057,7 +1062,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradValidPadding2_1_7(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 7, 7, 3], @@ -1071,7 +1076,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradValidPadding1_2(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[1, 3, 3, 1], @@ -1085,7 +1090,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradValidPadding2_2(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 2, 2, 3], @@ -1099,7 +1104,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradSamePadding1_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1113,7 +1118,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradSamePadding1_2(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1127,7 +1132,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradSamePadding2_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1141,7 +1146,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradSamePadding2_2(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1155,7 +1160,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( pool_func, input_sizes=[1, 7, 7, 1], @@ -1199,7 +1204,7 @@ class PoolingTest(test.TestCase): Returns: A Tensor. """ - pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops._max_pool_grad + pool_func = gen_nn_ops.max_pool_grad_v2 if v2 else gen_nn_ops.max_pool_grad return pool_func(orig_input, orig_output, grad, [1, window_rows, window_cols, 1], [1, row_stride, col_stride, 1], padding) @@ -1208,9 +1213,11 @@ class PoolingTest(test.TestCase): expected_input_backprop, input_sizes, output_sizes, window_rows, window_cols, row_stride, col_stride, padding, use_gpu, v2): - pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool + pool_func = gen_nn_ops.max_pool_v2 if v2 else nn_ops.max_pool with self.test_session(use_gpu=use_gpu): - input_tensor = constant_op.constant(input_data, shape=input_sizes) + input_tensor = variables.Variable( + np.array(input_data, dtype=np.float32).reshape(input_sizes)) + variables.global_variables_initializer().run() output_tensor = pool_func(input_tensor, [1, window_rows, window_cols, 1], [1, row_stride, col_stride, 1], padding) output_backprop_tensor = constant_op.constant( @@ -1504,7 +1511,7 @@ class PoolingTest(test.TestCase): self._testMaxPoolGradDirectWithNans2_2() def _testMaxPoolGradGradValidPadding1_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[1, 3, 3, 1], @@ -1518,7 +1525,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradValidPadding2_1_6(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[2, 6, 6, 3], @@ -1532,7 +1539,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradValidPadding2_1_7(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[2, 7, 7, 3], @@ -1546,7 +1553,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradValidPadding2_2(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[2, 2, 2, 3], @@ -1560,7 +1567,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradSamePadding1_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1574,7 +1581,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradSamePadding2_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1588,7 +1595,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradSamePadding2_2(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[2, 2, 4, 3], @@ -1602,7 +1609,7 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) def _testMaxPoolGradGradSamePadding3_1(self, data_format, use_gpu): - for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: + for pool_func in [gen_nn_ops.max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestSecondGradient( pool_func, input_sizes=[1, 7, 7, 1], @@ -1644,7 +1651,7 @@ class PoolingTest(test.TestCase): Returns: A Tensor. """ - return gen_nn_ops._max_pool_grad_grad( + return gen_nn_ops.max_pool_grad_grad( orig_input, orig_output, grad, [1, window_rows, window_cols, 1], [1, row_stride, col_stride, 1], padding) diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 61fb3f12e45ea5ae3bc4f0a26c2116b54c003624..5b508b7c0e72180194fa1a4c95bc4282d4694605 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + import numpy as np from six.moves import queue from six.moves import xrange # pylint: disable=redefined-builtin @@ -33,6 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -356,12 +359,22 @@ class PyFuncTest(test.TestCase): def _testExceptionHandling(self, py_exp, tf_exp, eager=False): - def raise_exception(): + def inner_exception(): raise py_exp("blah") # pylint: disable=not-callable + def raise_exception(): + inner_exception() + + expected_regexp = r": blah.*" # Error at the top + expected_regexp += r"in raise_exception.*" # Stacktrace outer + expected_regexp += r"in inner_exception.*" # Stacktrace inner + expected_regexp += r": blah" # Stacktrace of raise + def expected_error_check(exception): + return re.search(expected_regexp, str(exception), re.DOTALL) + if eager: - if context.in_eager_mode(): - with self.assertRaisesRegexp(tf_exp, "blah"): + if context.executing_eagerly(): + with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): f = script_ops.eager_py_func(raise_exception, [], []) return else: @@ -370,7 +383,7 @@ class PyFuncTest(test.TestCase): f = script_ops.py_func(raise_exception, [], []) with self.test_session(): - with self.assertRaisesRegexp(tf_exp, "blah"): + with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): self.evaluate(f) def testExceptionHandling(self): @@ -432,7 +445,7 @@ class PyFuncTest(test.TestCase): output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) ret = self.evaluate(output) - if context.in_eager_mode(): + if context.executing_eagerly(): self.assertEquals(len(ret), 0) else: self.assertIsNone(ret) @@ -468,6 +481,18 @@ class PyFuncTest(test.TestCase): self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) + @test_util.run_in_graph_and_eager_modes() + def testEagerReturningVariableRaisesError(self): + def return_variable(): + variable = resource_variable_ops.ResourceVariable(0.0) + return variable + + with self.assertRaisesRegexp(errors.UnknownError, + "Attempting to return a variable"): + output = script_ops.eager_py_func( + return_variable, inp=[], Tout=dtypes.float32) + self.evaluate(output) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index d306d1b8d64f292dc299deee2e3c36981b933d1e..589ea54973c10902c461f552d5c54b6fad6ecf67 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test # The maximum input rank to test. @@ -212,7 +213,8 @@ class SumReductionTest(BaseReductionTest): arr = np.ones([68000], dtype=np.float16) with self.test_session(graph=ops.Graph(), use_gpu=True) as sess: - tf_arr = array_ops.constant(arr) + tf_arr = variables.Variable(arr) + variables.global_variables_initializer().run() tf_mean = math_ops.reduce_mean(tf_arr, 0, False) tf_out_mean = sess.run(tf_mean) self.assertAllClose(tf_out_mean, 1.) diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6739ac32245668e98d37673fe9e9fe9d55cc0c5f --- /dev/null +++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for RegexReplace op from string_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class RegexReplaceOpTest(test.TestCase): + + def testRemovePrefix(self): + values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace( + input_vector, "^(a:|b:)", "", replace_global=False).eval() + self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"], + stripped) + + def testRegexReplace(self): + values = ["aba\naba", "abcdabcde"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace(input_vector, "a.*a", "(\\0)").eval() + self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped) + + def testEmptyMatch(self): + values = ["abc", "1"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace(input_vector, "", "x").eval() + self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped) + + def testInvalidPattern(self): + values = ["abc", "1"] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + invalid_pattern = "A[" + replace = string_ops.regex_replace(input_vector, invalid_pattern, "x") + with self.assertRaisesOpError("Invalid pattern"): + replace.eval() + + def testGlobal(self): + values = ["ababababab", "abcabcabc", ""] + with self.test_session(): + input_vector = constant_op.constant(values, dtypes.string) + stripped = string_ops.regex_replace(input_vector, "ab", "abc", + True).eval() + self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 6b4091ae5d3c6e469a9cd5237b978eae4c75485f..25e947f09e137b37ea129ba6015a060aa01f02e4 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -19,12 +19,14 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -87,6 +89,35 @@ class ReluTest(test.TestCase): print("relu (float32) gradient err = ", err) self.assertLess(err, 1e-4) + # The gradient for fp16 is inaccurate due to the low-precision. + # Instead of relying on compute_gradient_error, we compare the fp16 analytical + # gradient against their fp32 counterpart. + def testGradientFloat16(self): + with self.test_session(use_gpu=True) as sess: + # Randomly construct a 1D shape from [1, 40) + shape = random_ops.random_uniform( + [1], minval=1, maxval=40, dtype=dtypes.int32) + + # Construct the fp32 graph and its gradient. + x = random_ops.random_uniform(shape, minval=-1, maxval=1, name="x") + y1 = nn_ops.relu(x, name="relu_fp32") + l1 = nn_ops.l2_loss(y1) + dx_f32 = gradients_impl.gradients(l1, x) + + # Construct the fp16 graph and its gradient. + # It starts with the same x, in fp32. But before it reaches Relu, it is + # cast into fp16. So during backprop, the gradient computation is in fp16. + x2 = math_ops.cast(x, dtype=dtypes.float16, name="cast") + y2 = nn_ops.relu(x2, name="relu_fp16") + l2 = nn_ops.l2_loss(y2) + dx_f16 = gradients_impl.gradients(l2, x) + + # Repeat the experiment for 100 times. All tensor shapes and its tensor + # values are randomly generated for each run. + for _ in xrange(100): + dx_f32_v, dx_f16_v = sess.run([dx_f32, dx_f16]) + self.assertAllClose(dx_f32_v, dx_f16_v, atol=3e-4) + def testGradientFloat64(self): with self.test_session(): x = constant_op.constant( diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 8503f3e0310125bb714942b32bbbf46596f9bddb..742564f9bf671bc0da87c8b6d8e3ee6ed0ef2549 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -103,6 +103,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(False, name="bool_test") self.assertAllEqual(bool(v), False) + def testFetchHandle(self): + with self.test_session(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1], name="foo") + self.assertGreater(len(handle.eval()), 0) + def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( @@ -179,6 +185,204 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterSub(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_sub(handle, [0], + constant_op.constant( + [[2]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[-1]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMul(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_mul(handle, [0], + constant_op.constant( + [[5]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[5]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterDiv(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_div(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[2]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMin(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_min(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMax(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_max(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[6]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterAddScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_add(handle, [0], + constant_op.constant( + 2, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterSubScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_sub(handle, [0], + constant_op.constant( + 2, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[-1]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMulScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_mul(handle, [0], + constant_op.constant( + 5, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[5]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterDivScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_div(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[2]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMinScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_min(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMaxScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_max(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[6]]) + def testScatterUpdateString(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.string, shape=[1, 1]) @@ -190,6 +394,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) + def testScatterUpdateStringScalar(self): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.string, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [["a"]], + dtype=dtypes.string))) + self.evaluate( + resource_variable_ops.resource_scatter_update(handle, [0], + constant_op.constant( + "b", + dtype=dtypes.string))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) + self.assertEqual( + compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) + # TODO(alive): get this to work in Eager mode. def testGPU(self): with self.test_session(use_gpu=True): @@ -277,6 +498,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(v.assign(2.0)) self.assertEqual(2.0, self.evaluate(v.value())) + # Tests for the 'read_value' argument: + assign_with_read = v.assign(3.0, read_value=True) + self.assertEqual(3.0, self.evaluate(assign_with_read)) + assign_without_read = v.assign(4.0, read_value=False) + if context.executing_eagerly(): + self.assertIsNone(assign_without_read) + else: + self.assertIsInstance(assign_without_read, ops.Operation) + self.evaluate(assign_without_read) + self.assertEqual(4.0, self.evaluate(v.value())) + @test_util.run_in_graph_and_eager_modes() def testLoad(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") @@ -329,6 +561,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto()) self.assertEquals(2, math_ops.add(w, 1).eval()) + self.assertEquals(v._handle, w._handle) + self.assertEquals(v._graph_element, w._graph_element) + @test_util.run_in_graph_and_eager_modes() def testAssignAddMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") @@ -336,6 +571,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(v.assign_add(1.0)) self.assertEqual(2.0, self.evaluate(v.value())) + # Tests for the 'read_value' argument: + assign_with_read = v.assign_add(1.0, read_value=True) + self.assertEqual(3.0, self.evaluate(assign_with_read)) + assign_without_read = v.assign_add(1.0, read_value=False) + if context.executing_eagerly(): + self.assertIsNone(assign_without_read) + else: + self.assertIsInstance(assign_without_read, ops.Operation) + self.evaluate(assign_without_read) + self.assertEqual(4.0, self.evaluate(v.value())) + @test_util.run_in_graph_and_eager_modes() def testAssignSubMethod(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") @@ -343,6 +589,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(v.assign_sub(1.0)) self.assertEqual(2.0, self.evaluate(v.value())) + # Tests for the 'read_value' argument: + assign_with_read = v.assign_sub(1.0, read_value=True) + self.assertEqual(1.0, self.evaluate(assign_with_read)) + assign_without_read = v.assign_sub(1.0, read_value=False) + if context.executing_eagerly(): + self.assertIsNone(assign_without_read) + else: + self.assertIsInstance(assign_without_read, ops.Operation) + self.evaluate(assign_without_read) + self.assertEqual(0.0, self.evaluate(v.value())) + @test_util.run_in_graph_and_eager_modes() def testDestroyResource(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") @@ -440,7 +697,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual("(10, 20, 35)", str(v.get_shape())) self.assertEqual("(10, 20, 35)", str(v.value().shape)) self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape)) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( "", str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape)) @@ -481,7 +738,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(dtypes.int32, v.dtype) self.assertEqual("foo/var7:0", v.name) self.assertAllEqual([10, 20, 35], v.shape.as_list()) - self.assertEqual(context.get_default_context().device_name, v.device) self.assertTrue(isinstance(v.handle, ops.EagerTensor)) self.assertEqual(constraint, v.constraint) self.assertAllEqual(init.numpy(), v.read_value().numpy()) @@ -551,6 +807,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(v, [1], [3]) self.assertAllEqual([1.0, 3.0], v.numpy()) + @test_util.run_in_graph_and_eager_modes() + def testScatterUpdateInvalidArgs(self): + v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") + # The exact error and message differ between graph construction (where the + # error is realized during shape inference at graph construction time) and + # eager execution (where the error is realized during kernel execution). + with self.assertRaisesRegexp(Exception, r"shape.*2.*3"): + state_ops.scatter_update(v, [0, 1], [0, 1, 2]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index daa42938e6af205425d7e423ce162294b9002be4..9a0409c796ab60da3d47cf7d46ef6fbd5bd82394 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -111,10 +111,10 @@ class RNNTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() - if context.in_graph_mode(): - inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] - else: + if context.executing_eagerly(): inputs = [constant_op.constant(np.ones((3, 4)))] + else: + inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] with self.assertRaisesRegexp(ValueError, "must be a vector"): rnn.dynamic_rnn( cell, @@ -125,38 +125,30 @@ class RNNTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testBatchSizeFromInput(self): cell = Plus1RNNCell() - in_graph_mode = context.in_graph_mode() + in_eager_mode = context.executing_eagerly() # With static batch size - if in_graph_mode: - inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) - initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5)) - else: + if in_eager_mode: inputs = np.zeros((3, 4, 5), dtype=np.float32) initial_state = np.zeros((3, 5), dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) + initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5)) # - Without initial_state outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - if in_graph_mode: - self.assertEqual(3, outputs.shape[0].value) - self.assertEqual(3, state.shape[0].value) - else: - self.assertEqual(3, outputs.shape[0]) - self.assertEqual(3, state.shape[0]) + self.assertEqual(3, outputs.shape[0]) + self.assertEqual(3, state.shape[0]) # - With initial_state outputs, state = rnn.dynamic_rnn( cell, inputs, initial_state=initial_state) - if in_graph_mode: - self.assertEqual(3, outputs.shape[0].value) - self.assertEqual(3, state.shape[0].value) - else: - self.assertEqual(3, outputs.shape[0]) - self.assertEqual(3, state.shape[0]) + self.assertEqual(3, outputs.shape[0]) + self.assertEqual(3, state.shape[0]) # Without static batch size - # Tensor shapes are fully determined in Eager mode, so only run this - # test in graph mode. - if in_graph_mode: + # Tensor shapes are fully determined with eager execution enabled, + # so only run this test for graph construction. + if not in_eager_mode: inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) # - Without initial_state outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) @@ -173,56 +165,46 @@ class RNNTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testScalarStateIsAccepted(self): cell = ScalarStateRNNCell() - in_graph_mode = context.in_graph_mode() + in_eager_mode = context.executing_eagerly() - if in_graph_mode: - inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - else: + if in_eager_mode: inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) with self.test_session() as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) - if in_graph_mode: + if not in_eager_mode: outputs, state = sess.run( [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) - if in_graph_mode: - self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]])) - self.assertEqual(state, 4) - else: - self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]])) - self.assertEqual(state.numpy(), 4) + self.assertAllEqual([[[1], [2], [3], [4]]], outputs) + self.assertAllEqual(4, state) @test_util.run_in_graph_and_eager_modes() def testTensorArrayStateIsAccepted(self): cell = TensorArrayStateRNNCell() - in_graph_mode = context.in_graph_mode() + in_eager_mode = context.executing_eagerly() - if in_graph_mode: - inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - else: + if in_eager_mode: inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) with self.test_session() as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) state = (state[0], state[1].stack()) - if in_graph_mode: + if not in_eager_mode: outputs, state = sess.run( [outputs, state], feed_dict={ inputs: [[[1], [2], [3], [4]]] }) - if in_graph_mode: - self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]])) - self.assertEqual(state[0], 4) - self.assertAllEqual(state[1], np.array([[[1]], [[2]], [[3]], [[4]]])) - else: - self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]])) - self.assertEqual(state[0].numpy(), 4) - self.assertAllEqual(state[1].numpy(), - np.array([[[1]], [[2]], [[3]], [[4]]])) + self.assertAllEqual([[[1], [2], [3], [4]]], outputs) + self.assertAllEqual(4, state[0]) + self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1]) ######### Benchmarking RNN code diff --git a/tensorflow/python/kernel_tests/save_restore_ops_test.py b/tensorflow/python/kernel_tests/save_restore_ops_test.py index 1bdfa9ebd8e1a4495e67004f59adfb51bf3a6602..cb9aa1e34d6eb82efa94e60e7b56c26b181cef04 100644 --- a/tensorflow/python/kernel_tests/save_restore_ops_test.py +++ b/tensorflow/python/kernel_tests/save_restore_ops_test.py @@ -31,11 +31,10 @@ class ShardedFileOpsTest(test.TestCase): with session.Session( target="", config=config_pb2.ConfigProto(device_count={"CPU": 2})): self.assertEqual( - gen_io_ops._sharded_filename("foo", 4, 100).eval(), + gen_io_ops.sharded_filename("foo", 4, 100).eval(), b"foo-00004-of-00100") self.assertEqual( - gen_io_ops._sharded_filespec("foo", 100).eval(), - b"foo-?????-of-00100") + gen_io_ops.sharded_filespec("foo", 100).eval(), b"foo-?????-of-00100") class ShapeInferenceTest(test.TestCase): @@ -53,7 +52,7 @@ class ShapeInferenceTest(test.TestCase): [dtypes.float32, dtypes.float32]) def testRestoreSlice(self): - op = gen_io_ops._restore_slice("model", "var", "3 4 0,1:-", dtypes.float32) + op = gen_io_ops.restore_slice("model", "var", "3 4 0,1:-", dtypes.float32) self.assertEqual([1, 4], op.get_shape()) diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py index e65241981eac2d42207c1de7a261f7936e588f2a..0d8fd232946883ac1d95c4c2d9744af69175ab90 100644 --- a/tensorflow/python/kernel_tests/scalar_test.py +++ b/tensorflow/python/kernel_tests/scalar_test.py @@ -92,11 +92,11 @@ class ScalarTest(test.TestCase): self.check(array_ops.reshape, (7, 1), 'sizes input must be 1-D', [7]) def testShardedFilename(self): - self.check(gen_io_ops._sharded_filename, ('foo', 4, [100]), + self.check(gen_io_ops.sharded_filename, ('foo', 4, [100]), 'must be a scalar', b'foo-00004-of-00100') def testShardedFilespec(self): - self.check(gen_io_ops._sharded_filespec, ('foo', [100]), 'must be a scalar', + self.check(gen_io_ops.sharded_filespec, ('foo', [100]), 'must be a scalar', b'foo-?????-of-00100') def testUnsortedSegmentSum(self): diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index 7cdf11d88468cabaf32387b0a4bdda760b4af31e..c70a4ffce7be71effe3ea10faa9754ab2b3842ce 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -38,38 +38,100 @@ def _NumpyAdd(ref, indices, updates): ref[indx] += updates[i] +def _NumpyAddScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] += update + + def _NumpySub(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] -= updates[i] +def _NumpySubScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] -= update + + def _NumpyMul(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] *= updates[i] +def _NumpyMulScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] *= update + + def _NumpyDiv(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] /= updates[i] +def _NumpyDivScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] /= update + + +def _NumpyMin(ref, indices, updates): + for i, indx in np.ndenumerate(indices): + ref[indx] = np.minimum(ref[indx], updates[i]) + + +def _NumpyMinScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = np.minimum(ref[indx], update) + + +def _NumpyMax(ref, indices, updates): + for i, indx in np.ndenumerate(indices): + ref[indx] = np.maximum(ref[indx], updates[i]) + + +def _NumpyMaxScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = np.maximum(ref[indx], update) + + def _NumpyUpdate(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] = updates[i] +def _NumpyUpdateScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = update + + _TF_OPS_TO_NUMPY = { state_ops.scatter_update: _NumpyUpdate, state_ops.scatter_add: _NumpyAdd, state_ops.scatter_sub: _NumpySub, state_ops.scatter_mul: _NumpyMul, state_ops.scatter_div: _NumpyDiv, + state_ops.scatter_min: _NumpyMin, + state_ops.scatter_max: _NumpyMax, +} + +_TF_OPS_TO_NUMPY_SCALAR = { + state_ops.scatter_update: _NumpyUpdateScalar, + state_ops.scatter_add: _NumpyAddScalar, + state_ops.scatter_sub: _NumpySubScalar, + state_ops.scatter_mul: _NumpyMulScalar, + state_ops.scatter_div: _NumpyDivScalar, + state_ops.scatter_min: _NumpyMinScalar, + state_ops.scatter_max: _NumpyMaxScalar, } class ScatterTest(test.TestCase): - def _VariableRankTest(self, tf_scatter, vtype, itype, repeat_indices=False): + def _VariableRankTest(self, + tf_scatter, + vtype, + itype, + repeat_indices=False, + updates_are_scalar=False): np.random.seed(8) with self.test_session(use_gpu=True): for indices_shape in (), (2,), (3, 7), (3, 4, 7): @@ -89,8 +151,11 @@ class ScatterTest(test.TestCase): indices[np.random.randint(size // 2)]) np.random.shuffle(indices) indices = indices.reshape(indices_shape) - updates = _AsType( - np.random.randn(*(indices_shape + extra_shape)), vtype) + if updates_are_scalar: + updates = _AsType(np.random.randn(), vtype) + else: + updates = _AsType( + np.random.randn(*(indices_shape + extra_shape)), vtype) # Clips small values to avoid division by zero. def clip_small_values(x): @@ -101,7 +166,10 @@ class ScatterTest(test.TestCase): # Scatter via numpy new = old.copy() - np_scatter = _TF_OPS_TO_NUMPY[tf_scatter] + if updates_are_scalar: + np_scatter = _TF_OPS_TO_NUMPY_SCALAR[tf_scatter] + else: + np_scatter = _TF_OPS_TO_NUMPY[tf_scatter] np_scatter(new, indices, updates) # Scatter via tensorflow ref = variables.Variable(old) @@ -109,25 +177,35 @@ class ScatterTest(test.TestCase): tf_scatter(ref, indices, updates).eval() self.assertAllClose(ref.eval(), new) - def _VariableRankTests(self, tf_scatter, repeat_indices=False): + def _VariableRankTests(self, + tf_scatter, + repeat_indices=False, + updates_are_scalar=False): for vtype in (np.float32, np.float64): for itype in (np.int32, np.int64): - self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices) + self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices, + updates_are_scalar) def testVariableRankUpdate(self): - self._VariableRankTests(state_ops.scatter_update) + self._VariableRankTests(state_ops.scatter_update, False) def testVariableRankAdd(self): - self._VariableRankTests(state_ops.scatter_add) + self._VariableRankTests(state_ops.scatter_add, False) def testVariableRankSub(self): - self._VariableRankTests(state_ops.scatter_sub) + self._VariableRankTests(state_ops.scatter_sub, False) def testVariableRankMul(self): - self._VariableRankTests(state_ops.scatter_mul) + self._VariableRankTests(state_ops.scatter_mul, False) def testVariableRankDiv(self): - self._VariableRankTests(state_ops.scatter_div) + self._VariableRankTests(state_ops.scatter_div, False) + + def testVariableRankMin(self): + self._VariableRankTests(state_ops.scatter_min, False) + + def testVariableRankMax(self): + self._VariableRankTests(state_ops.scatter_max, False) def testRepeatIndicesAdd(self): self._VariableRankTests(state_ops.scatter_add, True) @@ -141,6 +219,51 @@ class ScatterTest(test.TestCase): def testRepeatIndicesDiv(self): self._VariableRankTests(state_ops.scatter_div, True) + def testRepeatIndicesMin(self): + self._VariableRankTests(state_ops.scatter_min, True) + + def testRepeatIndicesMax(self): + self._VariableRankTests(state_ops.scatter_max, True) + + def testVariableRankUpdateScalar(self): + self._VariableRankTests(state_ops.scatter_update, False, True) + + def testVariableRankAddScalar(self): + self._VariableRankTests(state_ops.scatter_add, False, True) + + def testVariableRankSubScalar(self): + self._VariableRankTests(state_ops.scatter_sub, False, True) + + def testVariableRankMulScalar(self): + self._VariableRankTests(state_ops.scatter_mul, False, True) + + def testVariableRankDivScalar(self): + self._VariableRankTests(state_ops.scatter_div, False, True) + + def testVariableRankMinScalar(self): + self._VariableRankTests(state_ops.scatter_min, False, True) + + def testVariableRankMaxScalar(self): + self._VariableRankTests(state_ops.scatter_max, False, True) + + def testRepeatIndicesAddScalar(self): + self._VariableRankTests(state_ops.scatter_add, True, True) + + def testRepeatIndicesSubScalar(self): + self._VariableRankTests(state_ops.scatter_sub, True, True) + + def testRepeatIndicesMulScalar(self): + self._VariableRankTests(state_ops.scatter_mul, True, True) + + def testRepeatIndicesDivScalar(self): + self._VariableRankTests(state_ops.scatter_div, True, True) + + def testRepeatIndicesMinScalar(self): + self._VariableRankTests(state_ops.scatter_min, True, True) + + def testRepeatIndicesMaxScalar(self): + self._VariableRankTests(state_ops.scatter_max, True, True) + def testBooleanScatterUpdate(self): if not test.is_gpu_available(): with self.test_session(use_gpu=False) as session: diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index bbce6b7d47325b8209815230426672ec6894147f..3bca5fadc42693f514911c7ffa8f078de8ef9bcd 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -542,6 +542,25 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper): tf_ans = s.eval() self.assertAllClose(np_ans, tf_ans) + def testWithEmptySegments(self): + tf_x = constant_op.constant([], shape=[0, 4], dtype=dtypes_lib.float32) + ops_list = [ + math_ops.sparse_segment_sum_with_num_segments, + math_ops.sparse_segment_mean_with_num_segments + ] + segment_indices = [] + tf_indices = [] + num_segments = 5 + with self.test_session(use_gpu=False): + for tf_op in ops_list: + s = tf_op( + data=tf_x, + indices=tf_indices, + segment_ids=segment_indices, + num_segments=num_segments) + tf_ans = s.eval() + self.assertAllClose(np.zeros([5, 4]), tf_ans) + def testSegmentIdsGreaterThanZero(self): tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32) ops_list = [(np.add, None, math_ops.sparse_segment_sum), ( diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index 4de5f4e4dbd38043557c54ede90fa47e43a1e26d..d2647088c5c2afda032482fb5cfd983cedb49a8f 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -71,6 +71,23 @@ class SelfAdjointEigTest(test.TestCase): self.assertAllEqual(val[4], val[5]) self.assertAllEqual(val[1], val[3]) + def testMatrixThatFailsWhenFlushingDenormsToZero(self): + # Test a 32x32 matrix which is known to fail if denorm floats are flushed to + # zero. + matrix = np.genfromtxt( + test.test_src_dir_path( + "python/kernel_tests/testdata/" + "self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32) + self.assertEqual(matrix.shape, (32, 32)) + matrix_tensor = constant_op.constant(matrix) + with self.test_session(use_gpu=True) as sess: + (e, v) = sess.run(linalg_ops.self_adjoint_eig(matrix_tensor)) + self.assertEqual(e.size, 32) + self.assertAllClose( + np.matmul(v, v.transpose()), np.eye(32, dtype=np.float32), atol=2e-3) + self.assertAllClose(matrix, + np.matmul(np.matmul(v, np.diag(e)), v.transpose())) + def SortEigenDecomposition(e, v): if v.ndim < 2: diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index 051a25080b826de05ee3e24a82fbcd1f47995544..5fc9bef21816e3a12f0d274bab1fc82a83546422 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -283,7 +283,7 @@ class SliceTest(test.TestCase): # unintended behavior is prevented. c = constant_op.constant(5.0) with self.assertRaisesWithPredicateMatch( - TypeError, lambda e: "`Tensor` objects are not iterable" in str(e)): + TypeError, lambda e: "Tensor objects are not iterable" in str(e)): for _ in c: pass diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py index 4d89831aae9a5e95210a8defb180e09c9d38f4d6..981f96b74d3058aa79a1ea10e1254e572d0e8b85 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -18,15 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import unittest import numpy as np -from tensorflow.python.framework import constant_op + from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging @test_util.with_c_api @@ -42,9 +44,10 @@ class SoftmaxTest(test.TestCase): features, axis=dim), one_only_on_dim)) softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim) if log: - return np.log(softmax) + res = np.log(softmax) else: - return softmax + res = softmax + return res def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False): # A previous version of the code checked the op name rather than the op type @@ -54,9 +57,9 @@ class SoftmaxTest(test.TestCase): np_softmax = self._npSoftmax(np_features, dim=dim, log=log) with self.test_session(use_gpu=use_gpu): if log: - tf_softmax = nn_ops.log_softmax(np_features, dim=dim, name=name) + tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name) else: - tf_softmax = nn_ops.softmax(np_features, dim=dim, name=name) + tf_softmax = nn_ops.softmax(np_features, axis=dim, name=name) out = tf_softmax.eval() self.assertAllCloseAccordingToType(np_softmax, out) self.assertShapeEqual(np_softmax, tf_softmax) @@ -118,10 +121,32 @@ class SoftmaxTest(test.TestCase): self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32)) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testFloatGPU(self): + if test.is_gpu_available(cuda_only=True): + rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)] + cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)] + for row, col in zip(rows, cols): + logging.info("Testing softmax float dtype in shape [%d, %d]", row, col) + data = np.random.rand(row, col) + self._testAll(data.astype(np.float32)) + def testHalf(self): self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16)) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testHalfGPU(self): + if test.is_gpu_available(cuda_only=True): + rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)] + cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)] + for row, col in zip(rows, cols): + logging.info("Testing softmax half dtype in shape [%d, %d]", row, col) + data = np.random.rand(row, col) + self._testAll(data.astype(np.float16)) + def testDouble(self): self._testSoftmax( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64)) @@ -166,11 +191,11 @@ class SoftmaxTest(test.TestCase): def testEmptyInput(self): with self.test_session(): - x = constant_op.constant([[]], shape=[0, 3]) + x = array_ops.placeholder(dtypes.float32, shape=[0, 3]) self.assertEqual(0, array_ops.size(x).eval()) # reshape would raise if logits is empty with self.assertRaises(errors_impl.InvalidArgumentError): - nn_ops.softmax(x, dim=0).eval() + nn_ops.softmax(x, axis=0).eval() def testDimTooLarge(self): with self.test_session(): @@ -178,7 +203,7 @@ class SoftmaxTest(test.TestCase): # inference error. dim = array_ops.placeholder_with_default(100, shape=[]) with self.assertRaises(errors_impl.InvalidArgumentError): - nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval() + nn_ops.softmax([1., 2., 3., 4.], axis=dim).eval() def testLargeDims(self): # Make sure that we properly handle large inputs. See diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py index b943dfa4e5f2a06eddcb3af03764e5e046b715f4..2a9232b6aecb66328f10a62f2251246c4fcec6e6 100644 --- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py +++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py @@ -86,11 +86,11 @@ class CppOpImpl(object): @staticmethod def space_to_batch(*args, **kwargs): - return gen_array_ops._space_to_batch(*args, **kwargs) + return gen_array_ops.space_to_batch(*args, **kwargs) @staticmethod def batch_to_space(*args, **kwargs): - return gen_array_ops._batch_to_space(*args, **kwargs) + return gen_array_ops.batch_to_space(*args, **kwargs) class SpaceToBatchTest(test.TestCase, PythonOpImpl): diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py index 3c98a685e07a1f2d55c3c1035a99ffaa593d35b3..cd90d16aacb4325ed426b0466266d9616b574401 100644 --- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py +++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py @@ -34,8 +34,8 @@ from tensorflow.python.platform import tf_logging class SpaceToDepthTest(test.TestCase): - def _testOne(self, inputs, block_size, outputs): - input_nhwc = math_ops.to_float(inputs) + def _testOne(self, inputs, block_size, outputs, dtype=dtypes.float32): + input_nhwc = math_ops.cast(inputs, dtype) with self.test_session(use_gpu=False): # test NHWC (default) on CPU x_tf = array_ops.space_to_depth(input_nhwc, block_size) @@ -58,6 +58,12 @@ class SpaceToDepthTest(test.TestCase): x_out = [[[[1, 2, 3, 4]]]] self._testOne(x_np, block_size, x_out) + def testBasicFloat16(self): + x_np = [[[[1], [2]], [[3], [4]]]] + block_size = 2 + x_out = [[[[1, 2, 3, 4]]]] + self._testOne(x_np, block_size, x_out, dtype=dtypes.float16) + # Tests for larger input dimensions. To make sure elements are # correctly ordered spatially. def testLargerInput2x2(self): @@ -126,6 +132,24 @@ class SpaceToDepthTest(test.TestCase): x_out = [batch_output_elt(i) for i in range(batch_size)] self._testOne(x_np, block_size, x_out) + def testBatchSize0(self): + block_size = 2 + batch_size = 0 + input_nhwc = array_ops.ones([batch_size, 4, 6, 3]) + x_out = array_ops.ones([batch_size, 2, 3, 12]) + + with self.test_session(use_gpu=False): + # test NHWC (default) on CPU + x_tf = array_ops.space_to_depth(input_nhwc, block_size) + self.assertAllEqual(x_tf.shape, x_out.shape) + x_tf.eval() + if test.is_gpu_available(): + with self.test_session(use_gpu=True): + # test NHWC (default) on GPU + x_tf = array_ops.space_to_depth(input_nhwc, block_size) + self.assertAllEqual(x_tf.shape, x_out.shape) + x_tf.eval() + # Tests for different width and height. def testNonSquare(self): x_np = [[[[1, 10], [2, 20]], [[3, 30], [4, 40]], [[5, 50], [6, 60]], diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py index cd5b711a0ed18aabff543aa4b6ecb1a885618caf..a841fe83a7f585a69ef33c437570359797484a4a 100644 --- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py @@ -64,7 +64,7 @@ class SparseXentTest(test.TestCase): def _testXent(self, np_features, np_labels): np_loss, np_backprop = self._npXent(np_features, np_labels) with self.test_session(use_gpu=True) as sess: - loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits( + loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( np_features, np_labels) tf_loss, tf_backprop = sess.run([loss, backprop]) self.assertAllCloseAccordingToType(np_loss, tf_loss) @@ -73,7 +73,7 @@ class SparseXentTest(test.TestCase): def testSingleClass(self): for label_dtype in np.int32, np.int64: with self.test_session(use_gpu=True) as sess: - loss, backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits( + loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( np.array([[1.], [-1.], [0.]]).astype(np.float32), np.array([0, 0, 0]).astype(label_dtype)) tf_loss, tf_backprop = sess.run([loss, backprop]) @@ -87,8 +87,9 @@ class SparseXentTest(test.TestCase): if test.is_built_with_cuda() and test.is_gpu_available(): with self.test_session(use_gpu=True) as sess: - loss, backprop = (gen_nn_ops._sparse_softmax_cross_entropy_with_logits( - features, labels)) + loss, backprop = ( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits( + features, labels)) tf_loss, tf_backprop = sess.run([loss, backprop]) self.assertAllClose( [[np.nan] * 4, [0.25, 0.25, 0.25, -0.75], @@ -100,8 +101,8 @@ class SparseXentTest(test.TestCase): [np.nan, 1.3862, 3.4420, np.nan], tf_loss, rtol=1e-3, atol=1e-3) with self.test_session(use_gpu=False) as sess: - loss, backprop = (gen_nn_ops._sparse_softmax_cross_entropy_with_logits( - features, labels)) + loss, backprop = ( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels)) with self.assertRaisesOpError("Received a label value of"): sess.run([loss, backprop]) diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index 6171793b148f8d8f195b9548a13df89d29c5e96e..8cfee3eb933afcea7a58d5632948b87b0c4c10df 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -336,6 +336,20 @@ class SplitOpTest(test.TestCase): for s in splits: self.assertEqual(None, s.get_shape().ndims) + def testNonexistentDimTensor(self): + x = array_ops.placeholder(dtypes.int32) + values = np.zeros([5, 30]) + splits = array_ops.placeholder(dtypes.int32) + with self.assertRaisesRegexp(ValueError, "Cannot infer"): + y = array_ops.split(values, splits, axis=x) + + splits = array_ops.placeholder(dtypes.int32, [3]) + y = array_ops.split(values, splits, axis=x) + with self.test_session(use_gpu=True) as sess: + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "must have exactly one element"): + sess.run(y, {x: np.array([], dtype=np.int32), splits: [4, 11, 15]}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/stack_ops_test.py b/tensorflow/python/kernel_tests/stack_ops_test.py index aa409336f5c50178e4d0ca946190119fb0e4188e..afd2eaffab992bca4b3ae7b4f65e0370f325b548 100644 --- a/tensorflow/python/kernel_tests/stack_ops_test.py +++ b/tensorflow/python/kernel_tests/stack_ops_test.py @@ -34,11 +34,11 @@ class StackOpTest(test.TestCase): def _testStackPushPop(self, use_gpu): with self.test_session(use_gpu=use_gpu): - h = gen_data_flow_ops._stack_v2( + h = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, [[4.0, 5.0]]) + c = gen_data_flow_ops.stack_push_v2(h, [[4.0, 5.0]]) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose([[4.0, 5.0]], c1.eval()) def testStackPushPop(self): @@ -49,11 +49,11 @@ class StackOpTest(test.TestCase): with self.test_session(use_gpu=use_gpu): a = np.arange(2000) x = constant_op.constant(a, dtype=dtypes.float32) - h = gen_data_flow_ops._stack_v2( + h = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True) + c = gen_data_flow_ops.stack_push_v2(h, x, swap_memory=True) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose(a, c1.eval()) def testStackPushPopSwap(self): @@ -63,7 +63,7 @@ class StackOpTest(test.TestCase): def _testStackWhileSwap(self, use_gpu): with self.test_session(use_gpu=use_gpu): n = constant_op.constant(0) - h = gen_data_flow_ops._stack_v2( + h = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") def c(x): @@ -72,7 +72,7 @@ class StackOpTest(test.TestCase): def b(x): with ops.control_dependencies([x]): a = constant_op.constant(np.ones(2000), dtype=dtypes.float32) - v = gen_data_flow_ops._stack_push_v2(h, a, swap_memory=True) + v = gen_data_flow_ops.stack_push_v2(h, a, swap_memory=True) with ops.control_dependencies([v]): return math_ops.add(x, 1) @@ -86,7 +86,7 @@ class StackOpTest(test.TestCase): def b1(x, y): nx = math_ops.subtract(x, 1) - ny = y + gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + ny = y + gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) return [nx, ny] _, ry = control_flow_ops.while_loop( @@ -99,16 +99,16 @@ class StackOpTest(test.TestCase): def _testMultiStack(self, use_gpu): with self.test_session(use_gpu=use_gpu): - h1 = gen_data_flow_ops._stack_v2( + h1 = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, 4.0) + c1 = gen_data_flow_ops.stack_push_v2(h1, 4.0) with ops.control_dependencies([c1]): - c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - h2 = gen_data_flow_ops._stack_v2( + c1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="bar") - c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) with ops.control_dependencies([c2]): - c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + c2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) r = c1 + c2 self.assertAllClose(9.0, r.eval()) @@ -119,17 +119,17 @@ class StackOpTest(test.TestCase): def _testSameNameStacks(self, use_gpu): """Different stacks with the same name do not interfere.""" with self.test_session(use_gpu=use_gpu) as sess: - h1 = gen_data_flow_ops._stack_v2( + h1 = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - h2 = gen_data_flow_ops._stack_v2( + h2 = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, 4.0) + c1 = gen_data_flow_ops.stack_push_v2(h1, 4.0) with ops.control_dependencies([c1]): - c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) with ops.control_dependencies([c2]): - pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) out1, out2 = sess.run([pop1, pop2]) self.assertAllClose(out1, 4.0) @@ -141,9 +141,9 @@ class StackOpTest(test.TestCase): def _testCloseStack(self, use_gpu): with self.test_session(use_gpu=use_gpu) as sess: - h = gen_data_flow_ops._stack_v2( + h = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_close_v2(h) + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1) def testCloseStack(self): @@ -152,11 +152,11 @@ class StackOpTest(test.TestCase): def _testPushCloseStack(self, use_gpu): with self.test_session(use_gpu=use_gpu) as sess: - h = gen_data_flow_ops._stack_v2( + h = gen_data_flow_ops.stack_v2( -1, elem_type=dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, [[4.0, 5.0]]) + c = gen_data_flow_ops.stack_push_v2(h, [[4.0, 5.0]]) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_close_v2(h) + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1) def testPushCloseStack(self): @@ -170,9 +170,9 @@ class StackOpRefTest(test.TestCase): def _testStackPushPop(self, use_gpu): with self.test_session(use_gpu=use_gpu): h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push(h, [[4.0, 5.0]]) + c = gen_data_flow_ops.stack_push(h, [[4.0, 5.0]]) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop(h, dtypes.float32) self.assertAllClose([[4.0, 5.0]], c1.eval()) def testStackPushPop(self): @@ -184,9 +184,9 @@ class StackOpRefTest(test.TestCase): a = np.arange(2000) x = constant_op.constant(a, dtype=dtypes.float32) h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push(h, x, swap_memory=True) + c = gen_data_flow_ops.stack_push(h, x, swap_memory=True) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop(h, dtypes.float32) self.assertAllClose(a, c1.eval()) def testStackPushPopSwap(self): @@ -196,13 +196,13 @@ class StackOpRefTest(test.TestCase): def _testMultiStack(self, use_gpu): with self.test_session(use_gpu=use_gpu): h1 = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push(h1, 4.0) + c1 = gen_data_flow_ops.stack_push(h1, 4.0) with ops.control_dependencies([c1]): - c1 = gen_data_flow_ops._stack_pop(h1, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop(h1, dtypes.float32) h2 = gen_data_flow_ops._stack(dtypes.float32, stack_name="bar") - c2 = gen_data_flow_ops._stack_push(h2, 5.0) + c2 = gen_data_flow_ops.stack_push(h2, 5.0) with ops.control_dependencies([c2]): - c2 = gen_data_flow_ops._stack_pop(h2, dtypes.float32) + c2 = gen_data_flow_ops.stack_pop(h2, dtypes.float32) r = c1 + c2 self.assertAllClose(9.0, r.eval()) @@ -217,7 +217,7 @@ class StackOpRefTest(test.TestCase): def b(x): with ops.control_dependencies([x]): a = constant_op.constant(np.ones(2000), dtype=dtypes.float32) - v = gen_data_flow_ops._stack_push(h, a, swap_memory=True) + v = gen_data_flow_ops.stack_push(h, a, swap_memory=True) with ops.control_dependencies([v]): return math_ops.add(x, 1) @@ -231,7 +231,7 @@ class StackOpRefTest(test.TestCase): def b1(x, y): nx = math_ops.subtract(x, 1) - ny = y + gen_data_flow_ops._stack_pop(h, dtypes.float32) + ny = y + gen_data_flow_ops.stack_pop(h, dtypes.float32) return [nx, ny] _, ry = control_flow_ops.while_loop( @@ -249,9 +249,9 @@ class StackOpRefTest(test.TestCase): def _testSameNameStacks(self, use_gpu): with self.test_session(use_gpu=use_gpu): h1 = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push(h1, 4.0) + c1 = gen_data_flow_ops.stack_push(h1, 4.0) h2 = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c2 = gen_data_flow_ops._stack_push(h2, 5.0) + c2 = gen_data_flow_ops.stack_push(h2, 5.0) _ = c1 + c2 self.assertNotEqual(h1.eval()[1], h2.eval()[1]) @@ -262,7 +262,7 @@ class StackOpRefTest(test.TestCase): def _testCloseStack(self, use_gpu): with self.test_session(use_gpu=use_gpu) as sess: h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_close(h) + c1 = gen_data_flow_ops.stack_close(h) sess.run(c1) def testCloseStack(self): @@ -272,9 +272,9 @@ class StackOpRefTest(test.TestCase): def _testPushCloseStack(self, use_gpu): with self.test_session(use_gpu=use_gpu) as sess: h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push(h, [[4.0, 5.0]]) + c = gen_data_flow_ops.stack_push(h, [[4.0, 5.0]]) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_close(h) + c1 = gen_data_flow_ops.stack_close(h) sess.run(c1) def testPushCloseStack(self): diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index a519b69b22cf51ab4f4173b215c21a71d83e9f99..1b935d5286729e9e802c56e90e2ae7ab72a6e080 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -356,6 +356,10 @@ class TemplateTest(test.TestCase): self.assertEqual("s1_1/nested/dummy:0", v5.name) self.assertEqual("s1_1/nested_1/dummy:0", v6.name) + self.assertEqual(2, len(tmpl1._checkpoint_dependencies)) + self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name) + self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name) + @test_util.run_in_graph_and_eager_modes() def test_nested_templates_with_defun(self): @@ -558,7 +562,7 @@ class TemplateTest(test.TestCase): outputs_b, _ = linear1(inputs) self.assertEquals("foo", linear1.variable_scope.name) self.assertEquals("foo/w:0", w1.name) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEquals("foo/add:0", outputs_a.name, "First application of template should get " "same name scope as variables.") @@ -573,7 +577,7 @@ class TemplateTest(test.TestCase): "New template gets a freshly uniquified variable scope " "because 'foo' is already taken.") self.assertEquals("foo_1/w:0", w2.name) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEquals("foo_1_1/add:0", outputs_c.name, "First application of template would get " "same name scope as variables, but 'foo_1' is already " @@ -588,7 +592,7 @@ class TemplateTest(test.TestCase): with variable_scope.variable_scope("foo"): # Create two templates with the same name, ensure scopes are made unique. ta = template.make_template("bar", variable_scoped_function, True) - if context.in_eager_mode(): + if context.executing_eagerly(): tb = template.make_template("s", function_with_side_create, trainable=False) else: diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index aad2443eea7ad87faf481973e91ca3df32ccfb44..a834675828b67aed4057d1857c546a586cee69c9 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -399,28 +399,14 @@ class TensorArrayTest(test.TestCase): def testTensorArrayWriteWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) - in_graph_mode = context.in_graph_mode() # Test writing the wrong datatype - if in_graph_mode: - with self.assertRaisesOpError( - "TensorArray dtype is float but Op is trying to write " - "dtype string"): - self.evaluate(ta.write(0, "wrong_type_scalar").flow) - else: - with self.assertRaisesOpError( - "TensorArray dtype is float32 but Op is trying to write " - "dtype string"): - self.evaluate(ta.write(0, "wrong_type_scalar").flow) + with self.assertRaisesOpError( + "TensorArray dtype is (float|float32) but Op is trying to write " + "dtype string"): + self.evaluate(ta.write(0, "wrong_type_scalar").flow) - if context.in_graph_mode(): - with self.assertRaisesOpError( - "Tried to write to index -1 but array is not " - "resizeable and size is: 3"): - self.evaluate(ta.write(-1, 3.0).flow) - else: - with self.assertRaisesOpError( - r"Writing to negative indices \(index -1\) is not allowed."): - self.evaluate(ta.write(-1, 3.0).flow) + with self.assertRaisesOpError("index -1"): + self.evaluate(ta.write(-1, 3.0).flow) # Test reading from too large an index with self.assertRaisesOpError( @@ -435,23 +421,17 @@ class TensorArrayTest(test.TestCase): w0 = ta.write(0, [[4.0, 5.0]]) - # Test reading wrong datatype, which is only possible in graph mode - if context.in_graph_mode(): - r0_bad = gen_data_flow_ops._tensor_array_read_v3( + # Test reading wrong datatype (only possible when constructing graphs). + if not context.executing_eagerly(): + r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) with self.assertRaisesOpError( "TensorArray dtype is float but Op requested dtype double."): r0_bad.eval() # Test reading from a negative index, which is not allowed - if context.in_graph_mode(): - with self.assertRaisesOpError( - r"Tried to read from index -1 but array size is: 3"): - self.evaluate(ta.read(-1)) - else: - with self.assertRaisesOpError( - r"Reading from negative indices \(index -1\) is not allowed."): - self.evaluate(ta.read(-1)) + with self.assertRaisesOpError("index -1"): + self.evaluate(ta.read(-1)) # Test reading from too large an index with self.assertRaisesOpError( @@ -467,10 +447,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError( "Could not write to TensorArray index 2 because " "it has already been written to."): - if context.in_graph_mode(): - self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow) - else: - self.evaluate(ta.write(2, 3.0).write(2, 3.0)) + self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow) @test_util.run_in_graph_and_eager_modes() def testTensorArrayConcatIncompatibleShapesFails(self): @@ -499,58 +476,40 @@ class TensorArrayTest(test.TestCase): w2 = w1.write(1, [4.0]) w3 = w2.write(2, [[3.0]]) - # The eager-mode implementation just passes up array_op.concat's error - # message. - if context.in_graph_mode(): - with self.assertRaisesOpError( - r"TensorArray has inconsistent shapes. Index 0 has " - r"\(excepting dimension 0\) shape: \[\] but index 2 has " - r"\(excepting dimension 0\) shape: \[1\]"): - self.evaluate(w3.concat()) - else: - with self.assertRaisesOpError( - r".*Ranks of all input tensors should match: shape\[0\] " - r"= \[1\] vs\. shape\[2\] = \[1,1\].*"): - self.evaluate(w3.concat()) + # The exact error messages differ between eager execution and graph + # construction as the former bubbles up the error from array_op.concat. + with self.assertRaisesOpError("shape"): + self.evaluate(w3.concat()) @test_util.run_in_graph_and_eager_modes() def testTensorArraySplitIncompatibleShapesFails(self): with self.test_session(use_gpu=True): - in_graph_mode = context.in_graph_mode() + in_eager_mode = context.executing_eagerly() ta = _make_ta(3, "foo") with self.assertRaisesOpError( r"Expected lengths to be a vector, received shape: \[\]"): - if in_graph_mode: + if in_eager_mode: + self.evaluate(ta.split([1.0, 2.0, 3.0], 1)) + else: lengths = array_ops.placeholder(dtypes.int64) ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) - else: - self.evaluate(ta.split([1.0, 2.0, 3.0], 1)) with self.assertRaisesOpError( r"Expected sum of lengths to be equal to values.shape\[0\], " r"but sum of lengths is 1 and value's shape is: \[3\]"): - if in_graph_mode: - self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow) - else: - self.evaluate(ta.split([1.0, 2.0, 3.0], [1])) + self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow) ta = _make_ta(1, "baz") with self.assertRaisesOpError( r"Expected value to be at least a vector, but received shape: \[\]"): - if in_graph_mode: - self.evaluate(ta.split(1.0, [1]).flow) - else: - self.evaluate(ta.split(1.0, [1])) + self.evaluate(ta.split(1.0, [1]).flow) ta = _make_ta(2, "buz") with self.assertRaisesOpError( r"TensorArray's size is not equal to the size of lengths " r"\(2 vs. 1\), and the TensorArray is not marked as " r"dynamically resizeable"): - if in_graph_mode: - self.evaluate(ta.split([1.0], [1]).flow) - else: - self.evaluate(ta.split([1.0], [1])) + self.evaluate(ta.split([1.0], [1]).flow) def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): with self.test_session(use_gpu=True): @@ -868,14 +827,14 @@ class TensorArrayTest(test.TestCase): vout = func(v0, state0, var) grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5) - if context.in_graph_mode(): + if context.executing_eagerly(): + grad_fn = backprop.gradients_function(func) + v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val) + else: v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0] state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0] var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0] variables.global_variables_initializer().run() - else: - grad_fn = backprop.gradients_function(func) - v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val) state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = ( self.evaluate( @@ -959,10 +918,10 @@ class TensorArrayTest(test.TestCase): return r x = constant_op.constant(2.0, name="x") - if context.in_graph_mode(): - grad = gradients_impl.gradients(loop(x), [x])[0] - else: + if context.executing_eagerly(): grad = backprop.gradients_function(loop)(x)[0] + else: + grad = gradients_impl.gradients(loop(x), [x])[0] self.assertAllClose(31.0, self.evaluate(grad)) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): @@ -1158,14 +1117,14 @@ class TensorArrayTest(test.TestCase): infer_shape=True) w0 = ta1.split(value, [1, 2]) r0 = w0.read(0) - if context.in_graph_mode(): + if context.executing_eagerly(): + self.assertEqual((1, 2), r0.get_shape()) + self.assertEqual((2, 2), w0.read(1).get_shape()) + else: self.assertEqual(r0.get_shape().ndims, None) self.assertEqual( tensor_shape.TensorShape( ta1.handle.op.get_attr("element_shape")).ndims, None) - else: - self.assertEqual((1, 2), r0.get_shape()) - self.assertEqual((2, 2), w0.read(1).get_shape()) def testWriteUnknownShape(self): with self.test_session(use_gpu=True): @@ -1297,13 +1256,13 @@ class TensorArrayTest(test.TestCase): g = func(values) grad_ys = [[[2.0, 3.0], [4.0, 5.0]]] # Test combined gradients + aggregation of read(0) - if context.in_graph_mode(): - grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys) - g_vals, grad_vals = session.run([[g], grad]) - else: + if context.executing_eagerly(): g_vals = [g] grad_vals = backprop.gradients_function(func)( values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32)) + else: + grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys) + g_vals, grad_vals = session.run([[g], grad]) # Gradients for 8 of the 10 unread components are zero. expected_grad = np.zeros((10, 2)) @@ -1453,13 +1412,13 @@ class TensorArrayTest(test.TestCase): # Tests correct properties on new TensorArrays. self.assertEqual(dtypes.float32, ta0.dtype) self.assertEqual(dtypes.int32, ta1.dtype) - if context.in_graph_mode(): - self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape()) + if context.executing_eagerly(): + self.assertEqual(tensor_shape.scalar(), read0.get_shape()) else: - self.assertEqual(tensor_shape.scalar(), read1.get_shape()) + self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape()) self.assertEqual(tensor_shape.scalar(), read1.get_shape()) - if context.in_graph_mode(): + if not context.executing_eagerly(): variables.global_variables_initializer().run() read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0, diff --git a/tensorflow/python/kernel_tests/testdata/BUILD b/tensorflow/python/kernel_tests/testdata/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..45264c773ac0089bbfed44bd115e73e848a8cc62 --- /dev/null +++ b/tensorflow/python/kernel_tests/testdata/BUILD @@ -0,0 +1,24 @@ +# Data files for kernel tests. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "self_adjoint_eig_op_test_files", + srcs = ["self_adjoint_eig_fail_if_denorms_flushed.txt"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/python/kernel_tests/testdata/self_adjoint_eig_fail_if_denorms_flushed.txt b/tensorflow/python/kernel_tests/testdata/self_adjoint_eig_fail_if_denorms_flushed.txt new file mode 100644 index 0000000000000000000000000000000000000000..d56a690a7928fafe39debc478db3e90ab953430b --- /dev/null +++ b/tensorflow/python/kernel_tests/testdata/self_adjoint_eig_fail_if_denorms_flushed.txt @@ -0,0 +1,32 @@ +2.60986303e-17 -9.66826148e-21 -1.68610775e-24 -9.16104778e-17 -1.1039539e-18 -1.66460338e-25 -2.12362492e-23 1.90946688e-21 -3.34190535e-22 1.2000634e-18 -7.31782583e-20 2.57851762e-20 -2.55509e-20 -9.54284927e-20 -1.04248315e-17 -5.32450516e-22 -1.81712853e-17 6.0044594e-18 3.96602716e-11 2.89077487e-25 -2.47461475e-25 1.77941757e-24 -7.30388687e-21 -3.84350041e-16 -3.88532388e-21 -4.29928618e-21 4.13551131e-16 -2.63408791e-25 -2.84830375e-21 -1.6450072e-16 -2.8585296e-21 -3.65413296e-21 +-9.66826148e-21 5.03939189e-22 9.17361108e-26 5.17304053e-20 1.99338895e-20 1.25259775e-28 -8.70441942e-26 9.91474109e-25 -5.80960164e-24 -1.19022314e-21 3.90467165e-22 -1.38179098e-22 1.79253406e-22 2.23977705e-22 1.1864143e-19 7.16291934e-24 4.10159639e-20 -2.16798529e-20 -4.95460504e-14 -2.6881406e-27 5.32861213e-27 -4.54567085e-28 1.99794328e-23 1.26854541e-17 -1.92916739e-23 8.60632417e-24 -1.04721097e-18 -7.00607669e-28 6.86771954e-23 8.65173173e-19 1.24469175e-22 6.03883081e-24 +-1.68610775e-24 9.17361108e-26 1.34889529e-26 2.65059e-22 2.39713735e-23 -2.00915344e-30 -1.135692e-27 -6.46049964e-26 -1.03607712e-26 -1.57623654e-23 -1.63805162e-24 -5.95741642e-25 3.24984759e-25 6.49561204e-24 2.28504969e-21 2.8319611e-25 3.96494845e-22 -2.1988623e-22 6.26027228e-16 1.2418479e-30 2.1016041e-30 6.22813846e-30 -1.0708067e-25 6.90778045e-21 1.86361622e-25 7.08789674e-26 -9.23628499e-21 1.65335067e-30 -1.12173032e-26 8.2257321e-22 -4.72686764e-27 -2.58501275e-26 +-9.16104778e-17 5.17304053e-20 2.65059e-22 2.69965968e-14 7.06005733e-17 1.69851446e-22 -2.75994304e-21 -6.61589523e-20 3.8682048e-20 -1.69253147e-17 -2.68580354e-18 -7.74994098e-19 -9.75466696e-19 2.13537585e-18 2.13185342e-16 6.89417478e-21 1.35805044e-16 -3.48309239e-16 1.0448622e-09 -2.17287918e-23 7.41749185e-24 -7.36683057e-23 -1.31083094e-20 1.574e-14 5.72646592e-19 -9.85673749e-21 -1.0654985e-14 2.70679318e-23 4.0943479e-20 -3.42938568e-15 8.57373804e-20 -2.18094505e-20 +-1.1039539e-18 1.99338895e-20 2.39713735e-23 7.06005733e-17 1.83801666e-17 1.09735975e-24 -5.73058223e-24 7.2227645e-22 -8.94843118e-22 -2.30558605e-19 -7.84892038e-20 -1.88692532e-20 -1.02217713e-20 2.95458834e-20 2.42873413e-17 8.89161401e-22 1.21669872e-17 -6.85317731e-18 -7.345906e-12 -3.1158751e-25 1.36359449e-24 -1.57981417e-24 3.89633371e-21 9.94580899e-16 1.45732115e-20 6.92065325e-22 -1.86114433e-16 6.00601346e-26 3.26844e-21 4.38573742e-17 1.06803444e-20 4.60203933e-22 +-1.66460338e-25 1.25259775e-28 -2.00915344e-30 1.69851446e-22 1.09735975e-24 5.75549306e-30 4.74050864e-29 -5.99239043e-28 -1.5784658e-27 -1.74631273e-25 -1.22702975e-25 -1.03371979e-26 -1.96967552e-26 -1.56446725e-26 -3.06462576e-25 -6.33857393e-28 -6.08829397e-24 -7.07478859e-24 -4.82614847e-18 -2.7324345e-31 1.23830207e-31 -7.96172e-31 -1.9034503e-27 -3.82709848e-22 -2.69257733e-26 -3.84934809e-27 -1.48572725e-22 4.14585761e-31 2.5611404e-28 -2.77402858e-24 3.10373361e-28 -5.09669241e-28 +-2.12362492e-23 -8.70441942e-26 -1.135692e-27 -2.75994304e-21 -5.73058223e-24 4.74050864e-29 6.28162e-26 -3.30076462e-25 -3.30065418e-25 -1.1370873e-23 -8.97722764e-24 -1.03190629e-24 -9.52908672e-25 -3.27285413e-24 1.36216664e-22 -8.0549564e-26 -1.94826821e-22 -3.64999226e-22 -2.92500975e-15 -3.00986528e-29 2.39712646e-29 -1.02470704e-28 -4.99034099e-25 -1.32277916e-19 -5.05595e-24 -3.04012473e-25 -1.44724215e-20 5.04614184e-30 -4.12370105e-26 4.20735765e-21 -1.02818953e-25 3.41267575e-26 +1.90946688e-21 9.91474109e-25 -6.46049964e-26 -6.61589523e-20 7.2227645e-22 -5.99239043e-28 -3.30076462e-25 1.8948059e-22 1.83367373e-23 1.06616038e-21 -2.81616502e-22 1.18347412e-22 8.3458038e-23 9.67703245e-24 -1.37445558e-20 2.11412652e-24 2.64820742e-21 8.02510339e-20 4.39926334e-13 9.58727772e-27 2.9838033e-28 1.29183353e-26 1.78626483e-22 3.03531056e-19 9.62612316e-23 1.33722715e-23 2.92905627e-18 -9.42286262e-28 3.23170971e-24 4.10885529e-19 -8.38673724e-25 -8.63732285e-25 +-3.34190535e-22 -5.80960164e-24 -1.03607712e-26 3.8682048e-20 -8.94843118e-22 -1.5784658e-27 -3.30065418e-25 1.83367373e-23 9.30693173e-23 1.48929558e-21 1.83278606e-21 1.08468362e-22 2.61703785e-22 4.42441537e-23 1.23906316e-20 2.55235433e-24 8.36323349e-20 1.2152038e-19 9.83332204e-14 5.14523933e-27 -3.28220159e-28 8.22099066e-27 3.34939233e-23 4.3309476e-19 5.82711129e-22 1.14299394e-22 3.25240717e-18 5.84184241e-28 -1.76991199e-24 5.5568966e-20 -2.80294941e-24 4.59071175e-24 +1.2000634e-18 -1.19022314e-21 -1.57623654e-23 -1.69253147e-17 -2.30558605e-19 -1.74631273e-25 -1.1370873e-23 1.06616038e-21 1.48929558e-21 2.05547703e-18 2.01471341e-20 2.65473229e-20 1.36331708e-20 -2.19777252e-20 -3.09825792e-18 -1.93365673e-22 -2.25608735e-18 7.98997246e-18 1.45582661e-11 6.29004356e-25 -1.14866332e-25 -5.51419319e-26 2.97082139e-21 -2.39052259e-16 1.48920411e-20 1.28589326e-21 4.27717466e-16 -4.44694851e-26 -1.80270052e-22 3.29932795e-18 -5.11645591e-22 5.53091711e-23 +-7.31782583e-20 3.90467165e-22 -1.63805162e-24 -2.68580354e-18 -7.84892038e-20 -1.22702975e-25 -8.97722764e-24 -2.81616502e-22 1.83278606e-21 2.01471341e-20 4.38037939e-19 -4.46678177e-21 3.48516266e-20 7.32592348e-21 1.11928135e-18 8.58541052e-23 8.80645183e-18 4.80109643e-21 -1.7163557e-11 1.92262335e-26 -2.78003951e-26 5.48322572e-25 8.95330117e-23 -1.11570766e-17 3.13666242e-20 4.47195205e-21 -1.09014604e-17 7.69340111e-26 1.64649306e-22 1.71054085e-17 1.33471053e-23 6.40747815e-22 +2.57851762e-20 -1.38179098e-22 -5.95741642e-25 -7.74994098e-19 -1.88692532e-20 -1.03371979e-26 -1.03190629e-24 1.18347412e-22 1.08468362e-22 2.65473229e-20 -4.46678177e-21 5.22731861e-21 1.06412616e-21 -8.0508039e-22 -1.68829721e-19 -2.7699538e-23 -2.15173717e-19 7.46895651e-19 1.71858101e-12 5.41956e-26 -6.15013064e-27 1.54884457e-26 2.54028029e-22 -1.50009535e-18 1.11920465e-21 1.05890428e-22 3.6487132e-17 -2.06798384e-27 -5.5143889e-23 -1.71529414e-18 -7.38099094e-23 -6.5250472e-24 +-2.55509e-20 1.79253406e-22 3.24984759e-25 -9.75466696e-19 -1.02217713e-20 -1.96967552e-26 -9.52908672e-25 8.3458038e-23 2.61703785e-22 1.36331708e-20 3.48516266e-20 1.06412616e-21 4.08927657e-20 -2.76503659e-21 -6.81059804e-20 5.13487959e-23 1.80612902e-18 5.32462054e-19 -3.89327199e-12 3.60012729e-26 -2.5575456e-26 3.14316426e-25 4.56614351e-22 -1.24545392e-17 9.14707146e-21 7.97421952e-22 2.84371096e-17 2.98359736e-26 1.33439467e-23 1.00242743e-17 -4.94476664e-23 3.28816461e-22 +-9.54284927e-20 2.23977705e-22 6.49561204e-24 2.13537585e-18 2.95458834e-20 -1.56446725e-26 -3.27285413e-24 9.67703245e-24 4.42441537e-23 -2.19777252e-20 7.32592348e-21 -8.0508039e-22 -2.76503659e-21 5.02409342e-20 1.57549297e-18 2.63027228e-22 6.11241908e-19 -2.71906856e-19 1.41003203e-12 2.66730019e-26 2.25679315e-26 1.00596535e-25 3.02875382e-22 3.85539387e-17 6.79708607e-22 1.60452617e-22 -2.08440846e-17 -5.40071056e-28 4.56236979e-23 -1.00868521e-17 1.22265047e-22 -1.81997389e-23 +-1.04248315e-17 1.1864143e-19 2.28504969e-21 2.13185342e-16 2.42873413e-17 -3.06462576e-25 1.36216664e-22 -1.37445558e-20 1.23906316e-20 -3.09825792e-18 1.11928135e-18 -1.68829721e-19 -6.81059804e-20 1.57549297e-18 2.5311263e-15 9.97996576e-20 2.26115975e-16 -3.86907114e-17 3.68487445e-12 8.23669787e-24 1.00324064e-23 3.38722042e-24 8.64234911e-21 2.46521189e-15 1.72823337e-19 9.24995431e-20 -3.16903295e-15 5.94130048e-25 1.73965082e-20 1.17371651e-15 2.26718703e-20 4.16709318e-21 +-5.32450516e-22 7.16291934e-24 2.8319611e-25 6.89417478e-21 8.89161401e-22 -6.33857393e-28 -8.0549564e-26 2.11412652e-24 2.55235433e-24 -1.93365673e-22 8.58541052e-23 -2.7699538e-23 5.13487959e-23 2.63027228e-22 9.97996576e-20 2.88326168e-23 1.35358898e-20 5.43364968e-21 4.24011412e-14 1.88486064e-27 8.93106076e-29 4.5748278e-27 2.48573168e-24 5.81165621e-19 1.96505062e-23 5.84813631e-24 -2.46866108e-20 1.912471e-29 2.0243857e-24 -2.88983463e-20 1.35761502e-24 1.40424791e-27 +-1.81712853e-17 4.10159639e-20 3.96494845e-22 1.35805044e-16 1.21669872e-17 -6.08829397e-24 -1.94826821e-22 2.64820742e-21 8.36323349e-20 -2.25608735e-18 8.80645183e-18 -2.15173717e-19 1.80612902e-18 6.11241908e-19 2.26115975e-16 1.35358898e-20 3.66013906e-15 1.35652384e-17 -1.97764849e-09 4.16586597e-24 1.28936031e-24 6.96597122e-23 2.43147439e-21 -1.25627342e-15 1.52711738e-18 2.61025243e-19 -2.00782109e-15 9.75835691e-24 4.0203e-21 1.40790259e-15 -7.8869e-21 8.51983e-20 +6.0044594e-18 -2.16798529e-20 -2.1988623e-22 -3.48309239e-16 -6.85317731e-18 -7.07478859e-24 -3.64999226e-22 8.02510339e-20 1.2152038e-19 7.98997246e-18 4.80109643e-21 7.46895651e-19 5.32462054e-19 -2.71906856e-19 -3.86907114e-17 5.43364968e-21 1.35652384e-17 1.19795414e-15 1.18472676e-09 2.74214961e-23 -7.6305178e-26 1.25969175e-23 1.68466447e-19 1.33873166e-15 1.0739288e-18 1.02533716e-19 2.73480291e-14 -1.87024011e-24 -9.73944425e-21 2.74769918e-16 -1.48632788e-20 1.69142815e-21 +3.96602716e-11 -4.95460504e-14 6.26027228e-16 1.0448622e-09 -7.345906e-12 -4.82614847e-18 -2.92500975e-15 4.39926334e-13 9.83332204e-14 1.45582661e-11 -1.7163557e-11 1.71858101e-12 -3.89327199e-12 1.41003203e-12 3.68487445e-12 4.24011412e-14 -1.97764849e-09 1.18472676e-09 0.0257282555 5.64106473e-17 5.83845666e-18 -1.72409096e-16 1.02886027e-12 1.42563525e-08 -1.57067415e-12 -4.61972799e-13 3.30651737e-08 -5.20615037e-17 -1.71347193e-14 2.87764201e-10 5.03749196e-14 -1.97989316e-13 +2.89077487e-25 -2.6881406e-27 1.2418479e-30 -2.17287918e-23 -3.1158751e-25 -2.7324345e-31 -3.00986528e-29 9.58727772e-27 5.14523933e-27 6.29004356e-25 1.92262335e-26 5.41956e-26 3.60012729e-26 2.66730019e-26 8.23669787e-24 1.88486064e-27 4.16586597e-24 2.74214961e-23 5.64106473e-17 1.2555855e-29 -1.30304595e-31 8.42884087e-31 1.75222077e-26 -2.89058862e-23 3.0225144e-26 6.67962117e-27 8.54181718e-22 -1.2385176e-32 -5.78078369e-28 3.34704626e-23 -2.00599605e-27 2.05674681e-28 +-2.47461475e-25 5.32861213e-27 2.1016041e-30 7.41749185e-24 1.36359449e-24 1.23830207e-31 2.39712646e-29 2.9838033e-28 -3.28220159e-28 -1.14866332e-25 -2.78003951e-26 -6.15013064e-27 -2.5575456e-26 2.25679315e-26 1.00324064e-23 8.93106076e-29 1.28936031e-24 -7.6305178e-26 5.83845666e-18 -1.30304595e-31 2.26490979e-30 -4.25637053e-31 1.40697e-27 5.91197152e-22 -2.08475892e-26 -5.64982671e-28 -3.97199197e-23 -5.06794406e-32 1.11993943e-27 -2.94280711e-23 2.65858181e-27 -2.23093754e-28 +1.77941757e-24 -4.54567085e-28 6.22813846e-30 -7.36683057e-23 -1.57981417e-24 -7.96172e-31 -1.02470704e-28 1.29183353e-26 8.22099066e-27 -5.51419319e-26 5.48322572e-25 1.54884457e-26 3.14316426e-25 1.00596535e-25 3.38722042e-24 4.5748278e-27 6.96597122e-23 1.25969175e-23 -1.72409096e-16 8.42884087e-31 -4.25637053e-31 1.40764294e-28 1.38735442e-26 -1.93810515e-22 1.93660175e-25 1.97417449e-26 1.62145272e-22 2.52533191e-31 -3.42833345e-28 6.34130774e-22 -2.01859e-27 6.1781768e-27 +-7.30388687e-21 1.99794328e-23 -1.0708067e-25 -1.31083094e-20 3.89633371e-21 -1.9034503e-27 -4.99034099e-25 1.78626483e-22 3.34939233e-23 2.97082139e-21 8.95330117e-23 2.54028029e-22 4.56614351e-22 3.02875382e-22 8.64234911e-21 2.48573168e-24 2.43147439e-21 1.68466447e-19 1.02886027e-12 1.75222077e-26 1.40697e-27 1.38735442e-26 1.18400807e-21 1.40670976e-18 2.40320429e-22 3.69528133e-23 4.81603371e-18 -1.49322683e-27 -2.70670724e-25 1.59463723e-19 6.40406749e-24 1.17170599e-23 +-3.84350041e-16 1.26854541e-17 6.90778045e-21 1.574e-14 9.94580899e-16 -3.82709848e-22 -1.32277916e-19 3.03531056e-19 4.3309476e-19 -2.39052259e-16 -1.11570766e-17 -1.50009535e-18 -1.24545392e-17 3.85539387e-17 2.46521189e-15 5.81165621e-19 -1.25627342e-15 1.33873166e-15 1.42563525e-08 -2.89058862e-23 5.91197152e-22 -1.93810515e-22 1.40670976e-18 4.40677789e-12 7.86017934e-19 7.73466606e-19 1.96690791e-15 -1.65941347e-22 2.63659933e-18 -3.0624544e-14 5.87194631e-18 -3.46291098e-19 +-3.88532388e-21 -1.92916739e-23 1.86361622e-25 5.72646592e-19 1.45732115e-20 -2.69257733e-26 -5.05595e-24 9.62612316e-23 5.82711129e-22 1.48920411e-20 3.13666242e-20 1.11920465e-21 9.14707146e-21 6.79708607e-22 1.72823337e-19 1.96505062e-23 1.52711738e-18 1.0739288e-18 -1.57067415e-12 3.0225144e-26 -2.08475892e-26 1.93660175e-25 2.40320429e-22 7.86017934e-19 1.80741048e-20 9.85491491e-22 5.08456938e-17 1.08072265e-26 -1.75036654e-23 4.36436952e-18 -1.77728563e-23 1.01268548e-22 +-4.29928618e-21 8.60632417e-24 7.08789674e-26 -9.85673749e-21 6.92065325e-22 -3.84934809e-27 -3.04012473e-25 1.33722715e-23 1.14299394e-22 1.28589326e-21 4.47195205e-21 1.05890428e-22 7.97421952e-22 1.60452617e-22 9.24995431e-20 5.84813631e-24 2.61025243e-19 1.02533716e-19 -4.61972799e-13 6.67962117e-27 -5.64982671e-28 1.97417449e-26 3.69528133e-23 7.73466606e-19 9.85491491e-22 3.68332283e-22 1.76753773e-18 2.6167718e-27 3.55918682e-25 1.95786374e-19 -2.60077304e-24 1.84790635e-23 +4.13551131e-16 -1.04721097e-18 -9.23628499e-21 -1.0654985e-14 -1.86114433e-16 -1.48572725e-22 -1.44724215e-20 2.92905627e-18 3.25240717e-18 4.27717466e-16 -1.09014604e-17 3.6487132e-17 2.84371096e-17 -2.08440846e-17 -3.16903295e-15 -2.46866108e-20 -2.00782109e-15 2.73480291e-14 3.30651737e-08 8.54181718e-22 -3.97199197e-23 1.62145272e-22 4.81603371e-18 1.96690791e-15 5.08456938e-17 1.76753773e-18 1.57092991e-12 -4.31425852e-23 -3.78241e-19 -1.15899865e-14 -7.61890782e-19 -1.15344546e-19 +-2.63408791e-25 -7.00607669e-28 1.65335067e-30 2.70679318e-23 6.00601346e-26 4.14585761e-31 5.04614184e-30 -9.42286262e-28 5.84184241e-28 -4.44694851e-26 7.69340111e-26 -2.06798384e-27 2.98359736e-26 -5.40071056e-28 5.94130048e-25 1.912471e-29 9.75835691e-24 -1.87024011e-24 -5.20615037e-17 -1.2385176e-32 -5.06794406e-32 2.52533191e-31 -1.49322683e-27 -1.65941347e-22 1.08072265e-26 2.6167718e-27 -4.31425852e-23 1.5576233e-30 -6.14697676e-29 -5.39097603e-24 -8.01112167e-29 1.81063126e-27 +-2.84830375e-21 6.86771954e-23 -1.12173032e-26 4.0943479e-20 3.26844e-21 2.5611404e-28 -4.12370105e-26 3.23170971e-24 -1.76991199e-24 -1.80270052e-22 1.64649306e-22 -5.5143889e-23 1.33439467e-23 4.56236979e-23 1.73965082e-20 2.0243857e-24 4.0203e-21 -9.73944425e-21 -1.71347193e-14 -5.78078369e-28 1.11993943e-27 -3.42833345e-28 -2.70670724e-25 2.63659933e-18 -1.75036654e-23 3.55918682e-25 -3.78241e-19 -6.14697676e-29 2.71732416e-23 2.4136621e-19 2.38938648e-23 1.21468477e-24 +-1.6450072e-16 8.65173173e-19 8.2257321e-22 -3.42938568e-15 4.38573742e-17 -2.77402858e-24 4.20735765e-21 4.10885529e-19 5.5568966e-20 3.29932795e-18 1.71054085e-17 -1.71529414e-18 1.00242743e-17 -1.00868521e-17 1.17371651e-15 -2.88983463e-20 1.40790259e-15 2.74769918e-16 2.87764201e-10 3.34704626e-23 -2.94280711e-23 6.34130774e-22 1.59463723e-19 -3.0624544e-14 4.36436952e-18 1.95786374e-19 -1.15899865e-14 -5.39097603e-24 2.4136621e-19 2.10373291e-13 4.84257897e-20 2.71571227e-19 +-2.8585296e-21 1.24469175e-22 -4.72686764e-27 8.57373804e-20 1.06803444e-20 3.10373361e-28 -1.02818953e-25 -8.38673724e-25 -2.80294941e-24 -5.11645591e-22 1.33471053e-23 -7.38099094e-23 -4.94476664e-23 1.22265047e-22 2.26718703e-20 1.35761502e-24 -7.8869e-21 -1.48632788e-20 5.03749196e-14 -2.00599605e-27 2.65858181e-27 -2.01859e-27 6.40406749e-24 5.87194631e-18 -1.77728563e-23 -2.60077304e-24 -7.61890782e-19 -8.01112167e-29 2.38938648e-23 4.84257897e-20 7.77486414e-23 -7.38542574e-25 +-3.65413296e-21 6.03883081e-24 -2.58501275e-26 -2.18094505e-20 4.60203933e-22 -5.09669241e-28 3.41267575e-26 -8.63732285e-25 4.59071175e-24 5.53091711e-23 6.40747815e-22 -6.5250472e-24 3.28816461e-22 -1.81997389e-23 4.16709318e-21 1.40424791e-27 8.51983e-20 1.69142815e-21 -1.97989316e-13 2.05674681e-28 -2.23093754e-28 6.1781768e-27 1.17170599e-23 -3.46291098e-19 1.01268548e-22 1.84790635e-23 -1.15344546e-19 1.81063126e-27 1.21468477e-24 2.71571227e-19 -7.38542574e-25 3.49516247e-23 diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py index 4498fd9fe9986c134b92aed192a6de6f06109bd9..bbc040dc13fc151b970f130eeb76fa1639245416 100644 --- a/tensorflow/python/kernel_tests/unique_op_test.py +++ b/tensorflow/python/kernel_tests/unique_op_test.py @@ -66,9 +66,9 @@ class UniqueTest(test.TestCase): for dtype in [np.int32, np.int64]: x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]]) with self.test_session() as sess: - y0, idx0 = gen_array_ops._unique_v2(x, axis=np.array([0], dtype)) + y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype)) tf_y0, tf_idx0 = sess.run([y0, idx0]) - y1, idx1 = gen_array_ops._unique_v2(x, axis=np.array([1], dtype)) + y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype)) tf_y1, tf_idx1 = sess.run([y1, idx1]) self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]])) self.assertAllEqual(tf_idx0, np.array([0, 0, 1])) @@ -80,7 +80,7 @@ class UniqueTest(test.TestCase): # by default, the axis will be wrapped to allow `axis=None`. x = np.random.randint(2, high=10, size=7000) with self.test_session() as sess: - y, idx = gen_array_ops._unique_v2(x, axis=np.array([], np.int32)) + y, idx = gen_array_ops.unique_v2(x, axis=np.array([], np.int32)) tf_y, tf_idx = sess.run([y, idx]) self.assertEqual(len(x), len(tf_idx)) @@ -137,10 +137,10 @@ class UniqueWithCountsTest(test.TestCase): for dtype in [np.int32, np.int64]: x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]]) with self.test_session() as sess: - y0, idx0, count0 = gen_array_ops._unique_with_counts_v2( + y0, idx0, count0 = gen_array_ops.unique_with_counts_v2( x, axis=np.array([0], dtype)) tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0]) - y1, idx1, count1 = gen_array_ops._unique_with_counts_v2( + y1, idx1, count1 = gen_array_ops.unique_with_counts_v2( x, axis=np.array([1], dtype)) tf_y1, tf_idx1, tf_count1 = sess.run([y1, idx1, count1]) self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]])) @@ -155,7 +155,7 @@ class UniqueWithCountsTest(test.TestCase): # by default, the axis will be wrapped to allow `axis=None`. x = np.random.randint(2, high=10, size=7000) with self.test_session() as sess: - y, idx, count = gen_array_ops._unique_with_counts_v2( + y, idx, count = gen_array_ops.unique_with_counts_v2( x, axis=np.array([], np.int32)) tf_y, tf_idx, tf_count = sess.run([y, idx, count]) diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py index 79071029fd42374964d12f513e9c510bdc7400eb..cf369c071813120fef685b7220292d50b966cf11 100644 --- a/tensorflow/python/kernel_tests/variable_ops_test.py +++ b/tensorflow/python/kernel_tests/variable_ops_test.py @@ -165,26 +165,26 @@ class VariableOpTest(test.TestCase): def testTemporaryVariable(self): with self.test_session(use_gpu=True): - var = gen_state_ops._temporary_variable( + var = gen_state_ops.temporary_variable( [1, 2], dtypes.float32, var_name="foo") var = state_ops.assign(var, [[4.0, 5.0]]) var = state_ops.assign_add(var, [[6.0, 7.0]]) - final = gen_state_ops._destroy_temporary_variable(var, var_name="foo") + final = gen_state_ops.destroy_temporary_variable(var, var_name="foo") self.assertAllClose([[10.0, 12.0]], final.eval()) def testDestroyNonexistentTemporaryVariable(self): with self.test_session(use_gpu=True): - var = gen_state_ops._temporary_variable([1, 2], dtypes.float32) - final = gen_state_ops._destroy_temporary_variable(var, var_name="bad") + var = gen_state_ops.temporary_variable([1, 2], dtypes.float32) + final = gen_state_ops.destroy_temporary_variable(var, var_name="bad") with self.assertRaises(errors.NotFoundError): final.eval() def testDuplicateTemporaryVariable(self): with self.test_session(use_gpu=True): - var1 = gen_state_ops._temporary_variable( + var1 = gen_state_ops.temporary_variable( [1, 2], dtypes.float32, var_name="dup") var1 = state_ops.assign(var1, [[1.0, 2.0]]) - var2 = gen_state_ops._temporary_variable( + var2 = gen_state_ops.temporary_variable( [1, 2], dtypes.float32, var_name="dup") var2 = state_ops.assign(var2, [[3.0, 4.0]]) final = var1 + var2 @@ -193,25 +193,25 @@ class VariableOpTest(test.TestCase): def testDestroyTemporaryVariableTwice(self): with self.test_session(use_gpu=True): - var = gen_state_ops._temporary_variable([1, 2], dtypes.float32) - val1 = gen_state_ops._destroy_temporary_variable(var, var_name="dup") - val2 = gen_state_ops._destroy_temporary_variable(var, var_name="dup") + var = gen_state_ops.temporary_variable([1, 2], dtypes.float32) + val1 = gen_state_ops.destroy_temporary_variable(var, var_name="dup") + val2 = gen_state_ops.destroy_temporary_variable(var, var_name="dup") final = val1 + val2 with self.assertRaises(errors.NotFoundError): final.eval() def testTemporaryVariableNoLeak(self): with self.test_session(use_gpu=True): - var = gen_state_ops._temporary_variable( + var = gen_state_ops.temporary_variable( [1, 2], dtypes.float32, var_name="bar") final = array_ops.identity(var) final.eval() def testTwoTemporaryVariablesNoLeaks(self): with self.test_session(use_gpu=True): - var1 = gen_state_ops._temporary_variable( + var1 = gen_state_ops.temporary_variable( [1, 2], dtypes.float32, var_name="var1") - var2 = gen_state_ops._temporary_variable( + var2 = gen_state_ops.temporary_variable( [1, 2], dtypes.float32, var_name="var2") final = var1 + var2 final.eval() diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 8527f116f9541942e52ba2ab635ca1212ea38583..86ab9fbb70b5efcf06cc064617df14deb18c1f98 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import gc +import threading import numpy @@ -166,12 +167,10 @@ class VariableScopeTest(test.TestCase): self.evaluate(variables_lib.variables_initializer([w])) self.assertAllClose(self.evaluate(w.value()), [1, 2, 3]) - if context.in_graph_mode(): - with self.assertRaises(TypeError): - variable_scope.get_variable("x4", initializer={}) - else: - with self.assertRaises(ValueError): - variable_scope.get_variable("x4", initializer={}) + # A quirk to be revisited? + error = ValueError if context.executing_eagerly() else TypeError + with self.assertRaises(error): + variable_scope.get_variable("x4", initializer={}) @test_util.run_in_graph_and_eager_modes() def testInitFromNonInitializer(self): @@ -267,7 +266,7 @@ class VariableScopeTest(test.TestCase): self.assertAllClose(self.evaluate(losses[2]), 0.5) with variable_scope.variable_scope("foo", reuse=True): # reuse=True is for now only supported when eager execution is disabled. - if context.in_graph_mode(): + if not context.executing_eagerly(): v = variable_scope.get_variable("v", []) # "v" is alredy there, reused losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) @@ -374,7 +373,7 @@ class VariableScopeTest(test.TestCase): v = variable_scope.get_variable("v", []) self.evaluate(variables_lib.variables_initializer([v])) self.assertAllClose(self.evaluate(v.value()), 0.3) - if context.in_graph_mode(): + if not context.executing_eagerly(): # Check that we can set reuse. variable_scope.get_variable_scope().reuse_variables() with self.assertRaises(ValueError): # Fail, w does not exist yet. @@ -408,7 +407,7 @@ class VariableScopeTest(test.TestCase): with variable_scope.variable_scope("tower") as tower: with ops.name_scope("scope2") as sc2: self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/") - if context.in_graph_mode(): + if not context.executing_eagerly(): with variable_scope.variable_scope( tower): # Re-entering acts like another "tower". with ops.name_scope("scope2") as sc2: @@ -422,7 +421,7 @@ class VariableScopeTest(test.TestCase): with variable_scope.variable_scope("tower"): with ops.name_scope("scope2") as sc2: self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/") - if context.in_graph_mode(): + if not context.executing_eagerly(): with variable_scope.variable_scope(tower): with ops.name_scope("scope2") as sc2: self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/") @@ -903,17 +902,15 @@ class VariableScopeTest(test.TestCase): "w", [], collections=["foo"]) self.assertEqual(local_var.name, "outer/w:0") - # Since variable is local, it should be in the local variable collection - # but not the trainable collection. - if context.in_graph_mode(): + if not context.executing_eagerly(): + # Since variable is local, it should be in the local variable collection + # but not the trainable collection. self.assertIn(local_var, ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) self.assertIn(local_var, ops.get_collection("foo")) self.assertNotIn(local_var, ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - - # Check that local variable respects `reuse`. - if context.in_graph_mode(): + # Check that local variable respects `reuse`. with variable_scope.variable_scope(outer, "default", reuse=True): self.assertEqual( variable_scope.get_local_variable("w", []).name, "outer/w:0") @@ -1353,5 +1350,91 @@ class PartitionInfoTest(test.TestCase): self.assertEqual(0, partition_info.single_slice_dim([2, 3])) +class VariableScopeMultithreadedTest(test.TestCase): + + def testTwoThreadsDisjointScopeEntry(self): + + def thread_fn(i, graph): + with graph.as_default(): + with variable_scope.variable_scope("foo"): + if i == 0: + v = variable_scope.get_variable("v", []) + self.assertEquals("foo/v:0", v.name) + else: + # Any thread after the first one should fail to create variable + # with the same name. + with self.assertRaises(ValueError): + variable_scope.get_variable("v", []) + + graph = ops.get_default_graph() + threads = [ + threading.Thread(target=thread_fn, args=(i, graph,)) for i in range(2)] + + threads[0].start() + # Allow thread 0 to finish before starting thread 1. + threads[0].join() + threads[1].start() + threads[1].join() + + def testTwoThreadsNestedScopeEntry(self): + + def thread_fn(i, graph, run_event, pause_event): + with graph.as_default(): + with variable_scope.variable_scope("foo"): + if i == 0: + v = variable_scope.get_variable("v", []) + self.assertEquals("foo/v:0", v.name) + else: + # Any thread after the first one should fail to create variable + # with the same name. + with self.assertRaises(ValueError): + variable_scope.get_variable("v", []) + pause_event.set() + run_event.wait() + + graph = ops.get_default_graph() + run_events = [threading.Event() for _ in range(2)] + pause_events = [threading.Event() for _ in range(2)] + threads = [ + threading.Thread( + target=thread_fn, args=(i, graph, run_events[i], pause_events[i])) + for i in range(2) + ] + + # Start first thread. + threads[0].start() + pause_events[0].wait() + # Start next thread once the first thread has paused. + threads[1].start() + pause_events[1].wait() + # Resume both threads. + run_events[0].set() + run_events[1].set() + threads[0].join() + threads[1].join() + + def testReenterMainScope(self): + + def thread_fn(graph, main_thread_scope): + with graph.as_default(): + # Variable created with main scope will have prefix "main". + with variable_scope.variable_scope(main_thread_scope): + with variable_scope.variable_scope("foo"): + v = variable_scope.get_variable("v", []) + self.assertEquals("main/foo/v:0", v.name) + + # Variable created outside main scope will not have prefix "main". + with variable_scope.variable_scope("bar"): + v = variable_scope.get_variable("v", []) + self.assertEquals("bar/v:0", v.name) + + graph = ops.get_default_graph() + with variable_scope.variable_scope("main") as main_thread_scope: + thread = threading.Thread( + target=thread_fn, args=(graph, main_thread_scope)) + thread.start() + thread.join() + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index b16c8c002c98a0351d1fc55fce061695327a18c9..27599868b74be323189b872c2147c6a33f84d170 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -687,7 +687,7 @@ class VariableContainerTest(test.TestCase): v1 = variables.Variable([1]) with ops.container("l2"): v2 = variables.Variable([2]) - special_v = gen_state_ops._variable( + special_v = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="VariableInL3", diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index e152f02d8e983364603053dc5c8d14b5dfaf3605..60c726d54ceeb65ddf52af9b6aad685501214c24 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -18,10 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools +import sys + import numpy as np +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -48,7 +54,7 @@ class XentTest(test.TestCase): def _testXent(self, np_features, np_labels, use_gpu=False): np_loss, np_backprop = self._npXent(np_features, np_labels) with self.test_session(use_gpu=use_gpu) as sess: - loss, backprop = gen_nn_ops._softmax_cross_entropy_with_logits( + loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits( np_features, np_labels) tf_loss, tf_backprop = sess.run([loss, backprop]) self.assertAllCloseAccordingToType(np_loss, tf_loss) @@ -71,7 +77,7 @@ class XentTest(test.TestCase): def _testSingleClass(self, use_gpu=False): for dtype in np.float16, np.float32: with self.test_session(use_gpu=use_gpu) as sess: - loss, backprop = gen_nn_ops._softmax_cross_entropy_with_logits( + loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits( np.array([[1.], [-1.], [0.]]).astype(dtype), np.array([[-1.], [0.], [1.]]).astype(dtype)) tf_loss, tf_backprop = sess.run([loss, backprop]) @@ -88,8 +94,8 @@ class XentTest(test.TestCase): 4.]]]).astype(dtype) np_labels = np.array([[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(dtype) - self.assertRaisesRegexp(ValueError, "must be rank 2", - gen_nn_ops._softmax_cross_entropy_with_logits, + self.assertRaisesRegexp(ValueError, "rank 2, but is rank 3", + gen_nn_ops.softmax_cross_entropy_with_logits, np_features, np_labels) def testNpXent(self): @@ -128,17 +134,35 @@ class XentTest(test.TestCase): self.assertAllClose( np.array([1.3862, 1.9401]), np_loss, rtol=1.e-3, atol=1.e-3) + def testShapeBroadcast(self): + np_f = np.array([[1., 2., 3., 4.], + [1., 2., 3., 4.]]).astype(np.float32) + np_l = np.array([[0., 0., 0., 1.], + [0., .5, .5, 0.]]).astype(np.float32) + np_loss, np_backprop = self._npXent(np_f, np_l) + tf_f = constant_op.constant( + np.array([[1., 2., 3., 4.]]).astype(np.float32)) + tf_l = constant_op.constant( + np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32)) + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu) as sess: + loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits( + tf_f, tf_l) + tf_loss, tf_backprop = sess.run([loss, backprop]) + self.assertAllCloseAccordingToType(np_loss, tf_loss) + self.assertAllCloseAccordingToType(np_backprop, tf_backprop) + def testShapeMismatch(self): with self.test_session(): with self.assertRaises(ValueError): - gen_nn_ops._softmax_cross_entropy_with_logits( + gen_nn_ops.softmax_cross_entropy_with_logits( [[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]]) def testNotMatrix(self): with self.test_session(): with self.assertRaises(ValueError): - gen_nn_ops._softmax_cross_entropy_with_logits([0., 1., 2., 3.], - [0., 1., 0., 1.]) + gen_nn_ops.softmax_cross_entropy_with_logits([0., 1., 2., 3.], + [0., 1., 0., 1.]) def testHalf(self): self._testAll( @@ -260,5 +284,60 @@ class XentTest(test.TestCase): self.assertAllEqual(np_loss, tf_loss) +class XentBenchmark(test.Benchmark): + + def benchmarkZeroDimension(self): + for (m, n, p, use_gpu) in itertools.product( + [128], + [10, 100, 1000, 10000, 100000], + [0.001, 0.01, 0.5, 0.99, 1.0], + [False]): + k = int(p * n) + if k == 0: + continue + name = "zero_dimension_m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu) + device = "/%s:0" % ("gpu" if use_gpu else "cpu") + with ops.Graph().as_default(): + with ops.device(device): + labels = array_ops.zeros([0, 2, 4], dtype=dtypes.float32) + logits = array_ops.zeros([0, 2, 4], dtype=dtypes.float32) + op = nn_ops.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + with session.Session() as sess: + r = self.run_op_benchmark(sess, op, min_iters=100, name=name) + gb_processed_input = m * n / 1.0e9 + throughput = gb_processed_input / r["wall_time"] + print("Benchmark: %s \t wall_time: %0.03g s \t " + "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput)) + sys.stdout.flush() + + def benchmarkSingleClass(self): + for (m, n, p, use_gpu) in itertools.product( + [128], + [10, 100, 1000, 10000, 100000], + [0.001, 0.01, 0.5, 0.99, 1.0], + [False]): + k = int(p * n) + if k == 0: + continue + name = "single_class_m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu) + device = "/%s:0" % ("gpu" if use_gpu else "cpu") + with ops.Graph().as_default(): + with ops.device(device): + labels = constant_op.constant([[1.], [-1.], [0.]], + dtype=dtypes.float32) + logits = constant_op.constant([[-1.], [0.], [1.]], + dtype=dtypes.float32) + op = nn_ops.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + with session.Session() as sess: + r = self.run_op_benchmark(sess, op, min_iters=100, name=name) + gb_processed_input = m * n / 1.0e9 + throughput = gb_processed_input / r["wall_time"] + print("Benchmark: %s \t wall_time: %0.03g s \t " + "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput)) + sys.stdout.flush() + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 8314c4aa87a5b54effc44c371703267517ffa07d..1e5f26a77f4c923871f780ca31dac1763ddd144c 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -36,12 +36,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @tf_export('layers.Layer') -class Layer(object): +class Layer(checkpointable.CheckpointableBase): """Base layer class. This is the class from which all layers inherit, implementing common @@ -114,7 +115,7 @@ class Layer(object): # Provides information about which inputs are compatible with the layer. self.input_spec = None - if activity_regularizer and context.in_eager_mode(): + if activity_regularizer and context.executing_eagerly(): raise ValueError( ('Activity regularization is not supported when executing eagerly. ' 'Got activity_regularizer=%s') % (activity_regularizer,)) @@ -126,12 +127,12 @@ class Layer(object): # return tensors. When using graph execution, _losses is a list of ops. self._losses = [] self._reuse = kwargs.get('_reuse') - self._graph = ops.get_default_graph() + self._graph = None # Will be set at build time. self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name - call_fn_args = estimator_util.fn_args(self.call) - self._compute_previous_mask = ('mask' in call_fn_args or + self._call_fn_args = estimator_util.fn_args(self.call) + self._compute_previous_mask = ('mask' in self._call_fn_args or hasattr(self, 'compute_mask')) - self._call_has_scope_arg = 'scope' in call_fn_args + self._call_has_scope_arg = 'scope' in self._call_fn_args # These lists will be filled via successive calls # to self._add_inbound_node(). @@ -227,7 +228,7 @@ class Layer(object): @property def updates(self): - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('Layer.updates not supported in Eager mode.') if not self.trainable and not self.stateful: return [] @@ -259,7 +260,7 @@ class Layer(object): have is available at runtime. A step counter might fall into this category. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return # Updates already applied when in eager mode. updates = _to_list(updates) @@ -285,7 +286,7 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('`get_updates_for()` not supported in Eager mode.') # Updates disabled if layer is not trainable and not explicitly stateful. @@ -316,7 +317,7 @@ class Layer(object): Returns: A list of tensors. """ - if context.in_eager_mode(): + if context.executing_eagerly(): # _losses may only contain variable regularization losses when executing # eagerly, and they have been saved as lambdas to be executed when # requested. @@ -354,7 +355,7 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): + if context.executing_eagerly(): # TODO(fchollet): it should be possible (and highly desirable) to support # `add_loss` in eager mode. This allows great convenience and flexibility # in defining custom losses on the fly (e.g. in VAEs). @@ -388,7 +389,7 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('Layer.get_losses_for not supported in Eager mode.') if inputs is None: @@ -508,7 +509,7 @@ class Layer(object): # will occur; it should be None if and only if initialization will take # place in the eager context. init_graph = None - if context.in_graph_mode(): + if not context.executing_eagerly(): default_graph = ops.get_default_graph() if default_graph.building_function: with ops.init_scope(): @@ -516,7 +517,7 @@ class Layer(object): # will be lifted; if initialization ops will be lifted into # the eager context, then there is nothing to retrieve, since variable # collections are not supported when eager execution is enabled. - if context.in_graph_mode(): + if not context.executing_eagerly(): init_graph = ops.get_default_graph() existing_variables = set(tf_variables.global_variables()) else: @@ -532,13 +533,17 @@ class Layer(object): with vs.variable_scope( self._scope, reuse=reuse, auxiliary_name_scope=False) as scope: with ops.name_scope(self._name_scope_name(scope)): - variable = vs.get_variable(name, - shape=shape, - initializer=initializer, - dtype=dtypes.as_dtype(dtype), - constraint=constraint, - trainable=trainable and self.trainable, - partitioner=partitioner) + variable = self._add_variable_with_custom_getter( + name=name, + shape=shape, + getter=vs.get_variable, + # Manage errors in Layer rather than Checkpointable. + overwrite=True, + initializer=initializer, + dtype=dtypes.as_dtype(dtype), + constraint=constraint, + trainable=trainable and self.trainable, + partitioner=partitioner) if init_graph is not None: # pylint: disable=protected-access # The variable was created and initialized in a graph. @@ -573,7 +578,7 @@ class Layer(object): if isinstance(variable, tf_variables.PartitionedVariable): raise RuntimeError( 'Partitioned variable regularization is not yet ' - 'supported when executing eagerly. File a feature request' + 'supported when executing eagerly. File a feature request ' 'if this is important to you.') # Save a zero-argument lambda which runs the regularizer on the # variable, to be executed when `Layer.losses` is requested. @@ -619,16 +624,17 @@ class Layer(object): self._set_scope(kwargs.pop('scope', None)) input_list = nest.flatten(inputs) - in_graph_mode = context.in_graph_mode() + build_graph = not context.executing_eagerly() in_deferred_mode = isinstance(input_list[0], _DeferredTensor) # Ensure the Layer, if being reused, is working with inputs from # the same graph as where it was created. - if in_graph_mode: + if build_graph: try: - ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access + # Set layer's "graph" at build time + self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access except ValueError as e: raise ValueError('Input graph and Layer graph are not the same: %s' % e) - if in_graph_mode or in_deferred_mode: + if build_graph or in_deferred_mode: user_kwargs = copy.copy(kwargs) # Handle Keras mask propagation from previous layer to current layer. @@ -636,8 +642,9 @@ class Layer(object): if (not hasattr(self, '_compute_previous_mask') or self._compute_previous_mask): previous_mask = _collect_previous_mask(inputs) - if ('mask' in estimator_util.fn_args(self.call) and - 'mask' not in kwargs and + if not hasattr(self, '_call_fn_args'): + self._call_fn_args = estimator_util.fn_args(self.call) + if ('mask' in self._call_fn_args and 'mask' not in kwargs and not _is_all_none(previous_mask)): # The previous layer generated a mask, and mask was not explicitly pass # to __call__, hence we set previous_mask as the default value. @@ -662,13 +669,14 @@ class Layer(object): with scope_context_manager as scope: with ops.name_scope(self._name_scope_name(scope)): if not self.built: - if not in_graph_mode: + if not build_graph: # Activity regularization is currently unsupported in Eager mode. if self._activity_regularizer: - raise ValueError('activity_regularizer currently unsupported in ' - 'Eager mode. Found an activity_regularizer in ' - '%s(%s).' % (self.__class__.__name__, self)) - if not in_graph_mode and not in_deferred_mode: + raise ValueError( + 'activity_regularizer currently unsupported with ' + 'eager execution enabled. Found an activity_regularizer in ' + '%s(%s).' % (self.__class__.__name__, self)) + if not build_graph and not in_deferred_mode: # TODO(agarwal): support _keras_history in Eager mode. for x in input_list: if hasattr(x, '_keras_history'): @@ -693,11 +701,13 @@ class Layer(object): # TODO(agarwal): Fix the sub-classes and avoid this complexity. call_has_scope_arg = self._call_has_scope_arg except AttributeError: - call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call) + self._call_fn_args = estimator_util.fn_args(self.call) + self._call_has_scope_arg = 'scope' in self._call_fn_args + call_has_scope_arg = self._call_has_scope_arg if call_has_scope_arg: kwargs['scope'] = scope # Check input assumptions set after layer building, e.g. input shape. - if in_graph_mode or in_deferred_mode: + if build_graph or in_deferred_mode: self._assert_input_compatibility(inputs) if not in_deferred_mode: @@ -721,7 +731,7 @@ class Layer(object): if len(outputs) == 1: outputs = outputs[0] - if in_graph_mode: + if build_graph: # Apply activity regularization. # Note that it should be applied every time the layer creates a new # output, since it is output-specific. @@ -743,7 +753,7 @@ class Layer(object): else: outputs._keras_mask = output_mask # pylint: disable=protected-access - if in_graph_mode: + if build_graph: # If all input tensors have history metadata, # we update the output tensors # with corresponding history metadata, thus eventually allowing to use @@ -766,7 +776,7 @@ class Layer(object): # Update global default collections. _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) - if in_deferred_mode or in_graph_mode: + if in_deferred_mode or build_graph: if _have_all_keras_metadata(inputs): # Add an inbound node to the layer, so it can keep track of this call. # This updates the layer history of the output tensor(s). @@ -778,7 +788,7 @@ class Layer(object): @property def graph(self): - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('Layer.graph not supported in Eager mode.') return self._graph @@ -882,7 +892,6 @@ class Layer(object): mode. ValueError: If the index provided does not match any node. """ - assert context.in_graph_mode() if not self._inbound_nodes: raise RuntimeError('The layer has never been called ' 'and thus has no defined ' + attr_name + '.') @@ -912,9 +921,6 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError( - 'Layer.get_input_shape_at not supported in Eager mode.') return self._get_node_attribute_at_index(node_index, 'input_shapes', 'input shape') @@ -934,7 +940,7 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( 'Layer.get_output_shape_at not supported in Eager mode.') return self._get_node_attribute_at_index(node_index, 'output_shapes', @@ -955,7 +961,7 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('Layer.get_input_at not supported in Eager mode.') return self._get_node_attribute_at_index(node_index, 'input_tensors', 'input') @@ -975,8 +981,6 @@ class Layer(object): Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError('Layer.get_output_at not supported in Eager mode.') return self._get_node_attribute_at_index(node_index, 'output_tensors', 'output') @@ -998,8 +1002,6 @@ class Layer(object): RuntimeError: If called in Eager mode. AttributeError: If no inbound nodes are found. """ - if context.in_eager_mode(): - raise RuntimeError('Layer.input not supported in Eager mode.') if not self._inbound_nodes: raise AttributeError('Layer ' + self.name + ' is not connected, no input to return.') @@ -1020,8 +1022,6 @@ class Layer(object): layers. RuntimeError: if called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError('Layer.output not supported in Eager mode.') if not self._inbound_nodes: raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') return self._get_node_attribute_at_index(0, 'output_tensors', 'output') @@ -1042,8 +1042,6 @@ class Layer(object): AttributeError: if the layer has no defined input_shape. RuntimeError: if called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError('Layer.input_shape not supported in Eager mode.') if not self._inbound_nodes: raise AttributeError('The layer has never been called ' 'and thus has no defined input shape.') @@ -1103,8 +1101,6 @@ class Layer(object): AttributeError: if the layer has no defined output shape. RuntimeError: if called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError('Layer.output_shape not supported in Eager mode.') if not self._inbound_nodes: raise AttributeError('The layer has never been called ' 'and thus has no defined output shape.') @@ -1461,7 +1457,7 @@ def _to_list(x): def _add_elements_to_collection(elements, collection_list): - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('Using collections from Layers not supported in Eager ' 'mode. Tried to add %s to %s' % (elements, collection_list)) diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 91b8988d31c1f04be8134733e5e919c738ccb74f..9ed4afeaba931c47d2a1e65f08489773f0b9eb1b 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -44,7 +44,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer.variables, []) self.assertEqual(layer.trainable_variables, []) self.assertEqual(layer.non_trainable_variables, []) - if context.in_graph_mode(): + if not context.executing_eagerly(): # updates, losses only supported in GRAPH mode self.assertEqual(layer.updates, []) self.assertEqual(layer.losses, []) @@ -63,7 +63,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer.variables, [variable]) self.assertEqual(layer.trainable_variables, [variable]) self.assertEqual(layer.non_trainable_variables, []) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( layer.variables, ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) @@ -77,7 +77,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer.variables, [variable, variable_2]) self.assertEqual(layer.trainable_variables, [variable]) self.assertEqual(layer.non_trainable_variables, [variable_2]) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1) @@ -161,7 +161,7 @@ class BaseLayerTest(test.TestCase): inputs = random_ops.random_uniform((5,), seed=1) outputs = layer.apply(inputs) self.assertEqual(layer.built, True) - if context.in_graph_mode(): + if not context.executing_eagerly(): # op is only supported in GRAPH mode self.assertEqual(outputs.op.name, 'my_layer/Square') @@ -210,7 +210,7 @@ class BaseLayerTest(test.TestCase): inputs = random_ops.random_uniform((5,), seed=1) outputs = layer.apply(inputs) self.assertEqual(layer.built, True) - if context.in_graph_mode(): + if not context.executing_eagerly(): # op only supported in GRAPH mode. self.assertEqual(outputs.op.name, 'my_layer/Square') @@ -280,7 +280,7 @@ class BaseLayerTest(test.TestCase): def call(self, inputs): return inputs - if context.in_graph_mode(): + if not context.executing_eagerly(): layer = CustomerLayer() with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): layer.apply(array_ops.placeholder('int32')) @@ -307,7 +307,7 @@ class BaseLayerTest(test.TestCase): def call(self, inputs): return inputs - if context.in_graph_mode(): + if not context.executing_eagerly(): layer = CustomerLayer() with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): layer.apply(array_ops.placeholder('int32')) @@ -335,7 +335,7 @@ class BaseLayerTest(test.TestCase): def call(self, inputs): return inputs - if context.in_graph_mode(): + if not context.executing_eagerly(): layer = CustomerLayer() with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): layer.apply(array_ops.placeholder('int32')) @@ -430,7 +430,7 @@ class BaseLayerTest(test.TestCase): layer.apply(constant_op.constant(1)) # Works - if context.in_graph_mode(): + if not context.executing_eagerly(): layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32', shape=(2, 3))) @@ -453,13 +453,7 @@ class BaseLayerTest(test.TestCase): return {'l' + key: inputs[key] for key in inputs} layer = DictLayer() - if context.in_graph_mode(): - i1 = array_ops.placeholder('int32') - i2 = array_ops.placeholder('float32') - result = layer.apply({'abel': i1, 'ogits': i2}) - self.assertTrue(isinstance(result, dict)) - self.assertEqual(set(['label', 'logits']), set(result.keys())) - else: + if context.executing_eagerly(): i1 = constant_op.constant(3) i2 = constant_op.constant(4.0) result = layer.apply({'abel': i1, 'ogits': i2}) @@ -467,6 +461,12 @@ class BaseLayerTest(test.TestCase): self.assertEqual(set(['label', 'logits']), set(result.keys())) self.assertEqual(3, result['label'].numpy()) self.assertEqual(4.0, result['logits'].numpy()) + else: + i1 = array_ops.placeholder('int32') + i2 = array_ops.placeholder('float32') + result = layer.apply({'abel': i1, 'ogits': i2}) + self.assertTrue(isinstance(result, dict)) + self.assertEqual(set(['label', 'logits']), set(result.keys())) def testActivityRegularizer(self): regularizer = math_ops.reduce_sum @@ -643,6 +643,16 @@ class BaseLayerTest(test.TestCase): self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1) self.assertEqual(len(layer.get_losses_for([outputs])), 0) + def testLayerGraphSetInFirstApply(self): + with ops.Graph().as_default(): + layer = core_layers.Dense(1) # Graph at construction time is ignored + with ops.Graph().as_default(): + layer.apply(constant_op.constant([[1]])) + # layer is now bound to second Graph + with ops.Graph().as_default(), self.assertRaisesRegexp( + ValueError, 'Input graph and Layer graph are not the same'): + layer.apply(constant_op.constant([[1]])) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index bb10fe5e8bfd26e4877fb6aef73980a30f62bb5d..2d99b1688f1b2736c0660ba2ac914018b21bf9ed 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -180,6 +180,8 @@ class _Conv(base.Layer): # bias_add when computing gradients. To use bias_add, we collapse Z # and Y into a single dimension to obtain a 4D input tensor. outputs_shape = outputs.shape.as_list() + if outputs_shape[0] is None: + outputs_shape[0] = -1 outputs_4d = array_ops.reshape(outputs, [outputs_shape[0], outputs_shape[1], outputs_shape[2] * outputs_shape[3], @@ -1664,7 +1666,7 @@ class Conv2DTranspose(Conv2D): padding=self.padding.upper(), data_format=utils.convert_data_format(self.data_format, ndim=4)) - if context.in_graph_mode(): + if not context.executing_eagerly(): # Infer the static output shape: out_shape = inputs.get_shape().as_list() out_shape[c_axis] = self.filters @@ -1969,7 +1971,7 @@ class Conv3DTranspose(Conv3D): data_format=utils.convert_data_format(self.data_format, ndim=5), padding=self.padding.upper()) - if context.in_graph_mode(): + if not context.executing_eagerly(): # Infer the static output shape: out_shape = inputs.get_shape().as_list() out_shape[c_axis] = self.filters diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py index 160e732b6798697d05815e13a7b1c399070f0783..cdb42f5bd18292cad9d8536e88ea1c58c1d7d777 100644 --- a/tensorflow/python/layers/convolutional_test.py +++ b/tensorflow/python/layers/convolutional_test.py @@ -325,6 +325,12 @@ class ConvTest(test.TestCase): self.assertEqual(conv3d.kernel_constraint, k_constraint) self.assertEqual(conv3d.bias_constraint, b_constraint) + def testConv3DChannelsFirst(self): + # Test case for GitHub issue 15655 + images = array_ops.placeholder( + dtype=dtypes.float32, shape=[None, 1, 32, 32, 32]) + conv_layers.conv3d(images, 32, 9, data_format='channels_first') + @test_util.with_c_api class SeparableConv1DTest(test.TestCase): diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 6970bf9234f5a31ee8093069ac1c933bcdb6f103..e598d9f83ab21f2dd5fabb3dd37fa0bfb5f003a4 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -35,6 +35,7 @@ from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.ops import standard_ops @@ -155,11 +156,11 @@ class Dense(base.Layer): outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1], [0]]) # Reshape the output back to the original ndim of the input. - if context.in_graph_mode(): + if not context.executing_eagerly(): output_shape = shape[:-1] + [self.units] outputs.set_shape(output_shape) else: - outputs = standard_ops.matmul(inputs, self.kernel) + outputs = gen_math_ops.mat_mul(inputs, self.kernel) if self.use_bias: outputs = nn.bias_add(outputs, self.bias) if self.activation is not None: @@ -373,7 +374,7 @@ class Flatten(base.Layer): def call(self, inputs): outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1)) - if context.in_graph_mode(): + if not context.executing_eagerly(): outputs.set_shape(self.compute_output_shape(inputs.get_shape())) return outputs diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index 15ce6cba21fcc78126f7db58ab18934db69c15fd..cf45b07637108422f1c612390bb01efdad6d5bcf 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -67,7 +67,7 @@ class DenseTest(test.TestCase): variables.global_variables_initializer().run() self.assertAllEqual(x.eval(), [[0.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testCall(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 4), seed=1) @@ -77,12 +77,20 @@ class DenseTest(test.TestCase): self.assertListEqual(dense.trainable_variables, [dense.kernel, dense.bias]) self.assertListEqual(dense.non_trainable_variables, []) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2) self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') self.assertEqual(dense.bias.name, 'my_dense/bias:0') + @test_util.assert_no_new_pyobjects_executing_eagerly + def testNoEagerLeak(self): + # Tests that repeatedly constructing and building a Layer does not leak + # Python objects. + inputs = random_ops.random_uniform((5, 4), seed=1) + core_layers.Dense(5)(inputs) + core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs) + @test_util.run_in_graph_and_eager_modes() def testCallTensorDot(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') @@ -98,7 +106,7 @@ class DenseTest(test.TestCase): self.assertListEqual(dense.variables, [dense.kernel]) self.assertListEqual(dense.trainable_variables, [dense.kernel]) self.assertListEqual(dense.non_trainable_variables, []) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1) self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') @@ -113,7 +121,7 @@ class DenseTest(test.TestCase): self.assertListEqual(dense.non_trainable_variables, [dense.kernel, dense.bias]) self.assertListEqual(dense.trainable_variables, []) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0) @@ -162,13 +170,13 @@ class DenseTest(test.TestCase): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') inputs = random_ops.random_uniform((5, 3), seed=1) outputs = dense(inputs) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual(outputs.op.name, 'dense1/Relu') dense = core_layers.Dense(2, name='dense2') inputs = random_ops.random_uniform((5, 3), seed=1) outputs = dense(inputs) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual(outputs.op.name, 'dense2/BiasAdd') def testActivityRegularizer(self): @@ -374,7 +382,7 @@ class DropoutTest(test.TestCase): dp = core_layers.Dropout(0.5) inputs = array_ops.ones((5, 3)) dropped = dp.apply(inputs, training=True) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) np_output = self.evaluate(dropped) self.assertAlmostEqual(0., np_output.min()) diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index d83292b80963d942023b5d086a089af53008efe0..29fb92ccb59aef83448cff8fd1bd759c4fda5abf 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -319,7 +319,6 @@ class BatchNormalization(base.Layer): initializer=self.moving_variance_initializer, trainable=False) - self._one_minus_decay = 1.0 - self.momentum if self.renorm: # Create variables to maintain the moving mean and standard deviation. # These are used in training and thus are different from the moving @@ -338,8 +337,9 @@ class BatchNormalization(base.Layer): return var with ops.device(None): - device = ((lambda _: self.moving_mean.device) - if context.in_graph_mode() else self.moving_mean.device) + device = ( + self.moving_mean.device if context.executing_eagerly() else + (lambda _: self.moving_mean.device)) with ops.device(device): self.renorm_mean = _renorm_variable('renorm_mean', param_shape) self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) @@ -347,8 +347,9 @@ class BatchNormalization(base.Layer): # renorm_stddev_weight. This allows us to (1) mix the average # stddev with the minibatch stddev early in training, and (2) compute # the unbiased average stddev by dividing renorm_stddev by the weight. - device = ((lambda _: self.moving_variance.device) - if context.in_graph_mode() else self.moving_variance.device) + device = ( + self.moving_variance.device if context.executing_eagerly() else + (lambda _: self.moving_variance.device)) with ops.device(device): self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) self.renorm_stddev_weight = _renorm_variable( @@ -358,20 +359,15 @@ class BatchNormalization(base.Layer): self._scope.set_partitioner(partitioner) self.built = True - def _assign_moving_average(self, variable, value, one_minus_decay): + def _assign_moving_average(self, variable, value, momentum): with ops.name_scope(None, 'AssignMovingAvg', - [variable, value, one_minus_decay]) as scope: + [variable, value, momentum]) as scope: with ops.colocate_with(variable): - update_delta = math_ops.multiply( - math_ops.subtract(variable.read_value(), value), - one_minus_decay) - if isinstance(variable, resource_variable_ops.ResourceVariable): - # state_ops.assign_sub does an extra read_variable_op after the - # assign. We avoid that here. - return gen_resource_variable_ops.assign_sub_variable_op( - variable.handle, update_delta, name=scope) - else: - return state_ops.assign_sub(variable, update_delta, name=scope) + decay = ops.convert_to_tensor(1.0 - momentum, name='decay') + if decay.dtype != variable.dtype.base_dtype: + decay = math_ops.cast(decay, variable.dtype.base_dtype) + update_delta = (variable - value) * decay + return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training): """Returns the output of fused batch norm.""" @@ -410,22 +406,16 @@ class BatchNormalization(base.Layer): training_value = utils.constant_value(training) if training_value is None: - one_minus_decay = utils.smart_cond(training, - lambda: self._one_minus_decay, - lambda: 0.) + momentum = utils.smart_cond(training, lambda: self.momentum, lambda: 1.0) else: - one_minus_decay = ops.convert_to_tensor(self._one_minus_decay) + momentum = ops.convert_to_tensor(self.momentum) if training_value or training_value is None: mean_update = self._assign_moving_average(self.moving_mean, mean, - one_minus_decay) + momentum) variance_update = self._assign_moving_average(self.moving_variance, - variance, one_minus_decay) - if context.in_graph_mode(): - # Note that in Eager mode, the updates are already executed when running - # assign_moving_averages. So we do not need to put them into - # collections. - self.add_update(mean_update, inputs=inputs) - self.add_update(variance_update, inputs=inputs) + variance, momentum) + self.add_update(mean_update, inputs=inputs) + self.add_update(variance_update, inputs=inputs) return output @@ -462,6 +452,7 @@ class BatchNormalization(base.Layer): """Updates a moving average and weight, returns the unbiased value.""" value = array_ops.identity(value) def _do_update(): + """Updates the var and weight, returns their updated ratio.""" # Update the variables without zero debiasing. The debiasing will be # accomplished by dividing the exponential moving average by the weight. # For example, after a single update, the moving average would be @@ -470,11 +461,14 @@ class BatchNormalization(base.Layer): # Make sure the weight is not updated until before r and d computation. with ops.control_dependencies([value]): weight_value = array_ops.constant(1., dtype=weight.dtype) - new_var = moving_averages.assign_moving_average( - var, value, self.renorm_momentum, zero_debias=False) - new_weight = moving_averages.assign_moving_average( - weight, weight_value, self.renorm_momentum, zero_debias=False) + new_var = self._assign_moving_average(var, value, self.renorm_momentum) + new_weight = self._assign_moving_average(weight, weight_value, + self.renorm_momentum) + # TODO(yuefengz): the updates to var and weighted can not be batched + # together if we fetch their updated values here. Consider calculating + # new values and delaying the updates. return new_var / new_weight + def _fake_update(): return array_ops.identity(var) return utils.smart_cond(training, _do_update, _fake_update) @@ -493,7 +487,7 @@ class BatchNormalization(base.Layer): return (r, d, new_mean, new_variance) def call(self, inputs, training=False): - in_eager_mode = context.in_eager_mode() + in_eager_mode = context.executing_eagerly() if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation @@ -599,8 +593,7 @@ class BatchNormalization(base.Layer): if in_eager_mode and not self.trainable: return - return moving_averages.assign_moving_average( - var, value, self.momentum, zero_debias=False) + return self._assign_moving_average(var, value, self.momentum) mean_update = utils.smart_cond( training, @@ -610,7 +603,7 @@ class BatchNormalization(base.Layer): training, lambda: _do_update(self.moving_variance, new_variance), lambda: self.moving_variance) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.add_update(mean_update, inputs=inputs) self.add_update(variance_update, inputs=inputs) @@ -671,9 +664,16 @@ def batch_normalization(inputs, Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they - need to be added as a dependency to the `train_op`. For example: + need to be added as a dependency to the `train_op`. Also, be sure to add + any batch_normalization ops before getting the update_ops collection. + Otherwise, update_ops will be empty, and training/inference will not work + properly. For example: ```python + x_norm = tf.layers.batch_normalization(x, training=training) + + # ... + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 484c6fc466558dc274740955594cc279a175d638..3b156c36a2ff35fb9e05af1406d7b3f6cf883394 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -24,6 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond as smart_module from tensorflow.python.framework import tensor_util from tensorflow.python.util import nest @@ -201,7 +202,7 @@ def smart_cond(pred, true_fn=None, false_fn=None, name=None): if isinstance(pred, variables.Variable): return control_flow_ops.cond( pred, true_fn=true_fn, false_fn=false_fn, name=name) - return control_flow_ops.smart_cond( + return smart_module.smart_cond( pred, true_fn=true_fn, false_fn=false_fn, name=name) @@ -228,7 +229,7 @@ def constant_value(pred): if isinstance(pred, variables.Variable): return None - return control_flow_ops.smart_constant_value(pred) + return smart_module.smart_constant_value(pred) def object_list_uid(object_list): diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 994af69386b278f6b88c051f898cd6a9dc607f3f..a07e305ffbe8b4c4736c3231f6d1d7872d91e04e 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -267,7 +267,9 @@ gtl::InlinedVector GetPyArrayDimensionsForTensor( const int ndims = TF_NumDims(tensor); gtl::InlinedVector dims(ndims); if (TF_TensorType(tensor) == TF_RESOURCE) { - dims[0] = TF_TensorByteSize(tensor); + CHECK_EQ(ndims, 0) + << "Fetching of non-scalar resource tensors is not supported."; + dims.push_back(TF_TensorByteSize(tensor)); *nelems = dims[0]; } else { *nelems = 1; diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index e0422ef80add42307268be2743e668eb8c8acb68..22317a348c9d5472486ad118d865341ffb6ad829 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -79,10 +79,11 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { const Tensor& t = call->ins[i]; if (call->eager) { if (call->gpu) { - arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device)); + arg = EagerTensorFromHandle( + new TFE_TensorHandle(t, call->device, call->device)); } else { // TFE_TensorHandle assumes that CPU is identified by `nullptr`. - arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr)); + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr, nullptr)); } if (arg == nullptr) { return errors::Internal("Unable to procure EagerTensor from Tensor."); @@ -163,9 +164,9 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. -void ExtractTensorFromEagerTensor(const PyObject* eager_tensor, - Tensor* output_tensor) { - *output_tensor = EagerTensor_Handle(eager_tensor)->t; +tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + const Tensor** output_tensor) { + return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor); } // Calls the registered py function through the trampoline. @@ -219,7 +220,9 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { if (call->eager) { const PyObject* item = PyList_GetItem(result, i); if (EagerTensor_CheckExact(item)) { - ExtractTensorFromEagerTensor(item, &t); + const Tensor* tensor = nullptr; + s = ExtractTensorFromEagerTensor(item, &tensor); + if (s.ok()) t = *tensor; } else { s = errors::FailedPrecondition( "Expected EagerTensor, found PyObject of type: ", @@ -237,10 +240,10 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } else if (EagerTensor_CheckExact(result) || result == Py_None) { // result is an `EagerTensor` or `None`. DCHECK(call->eager); - Tensor t; if (result != Py_None) { - ExtractTensorFromEagerTensor(result, &t); - call->out.push_back(t); + const Tensor* t = nullptr; + s = ExtractTensorFromEagerTensor(result, &t); + if (s.ok()) call->out.push_back(*t); } } else if (PyArray_Check(result)) { // `result` is a NumPy array. diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 317bdc2e14747583f372808f48a5928273f5570a..8247d354db62532c10c5acc9875cc08289cd31bf 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -84,6 +84,7 @@ bool IsPyDimension(PyObject* obj) { } Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { + std::vector refs_to_clean; while (true) { // We test strings first, in case a string is considered a sequence. if (IsPyString(obj)) { @@ -93,6 +94,7 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { if (length > 0) { shape->AddDim(length); obj = PySequence_GetItem(obj, 0); + refs_to_clean.push_back(make_safe(obj)); continue; } else if (length == 0) { shape->AddDim(length); @@ -167,14 +169,15 @@ const char ErrorFoundFloat[] = if (shape.dims() > 1) { \ /* Iterate over outer dim, and recursively convert each element. */ \ const int64 s = shape.dim_size(0); \ - if (TF_PREDICT_FALSE(s != PySequence_Length(obj))) { \ + Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \ + if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \ return ErrorRectangular; \ } \ TensorShape rest = shape; \ rest.RemoveDim(0); \ for (int64 i = 0; i < s; ++i) { \ - const char* error = \ - FUNCTION##Helper(PySequence_GetItem(obj, i), rest, buf); \ + const char* error = FUNCTION##Helper( \ + PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf); \ if (TF_PREDICT_FALSE(error != nullptr)) return error; \ } \ } else { \ diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 2635694e23c07dd8e75d4bb0cfb9e83a2042d921..00cbf0c532cf80d3bb27afe168ecde963ba3591d 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -41,6 +41,55 @@ const char* ClassName(PyObject* py) { } // end namespace +// Returns a PyObject containing a string, or null +void TryAppendTraceback(PyObject* ptype, PyObject* pvalue, PyObject* ptraceback, + string* out) { + // The "traceback" module is assumed to be imported already by script_ops.py. + PyObject* tb_module = PyImport_AddModule("traceback"); + + if (!tb_module) { + return; + } + + PyObject* format_exception = + PyObject_GetAttrString(tb_module, "format_exception"); + + if (!format_exception) { + return; + } + + if (!PyCallable_Check(format_exception)) { + Py_DECREF(format_exception); + return; + } + + PyObject* ret_val = PyObject_CallFunctionObjArgs(format_exception, ptype, + pvalue, ptraceback, nullptr); + Py_DECREF(format_exception); + + if (!ret_val) { + return; + } + + if (!PyList_Check(ret_val)) { + Py_DECREF(ret_val); + return; + } + + Py_ssize_t n = PyList_GET_SIZE(ret_val); + for (Py_ssize_t i = 0; i < n; ++i) { + PyObject* v = PyList_GET_ITEM(ret_val, i); +#if PY_MAJOR_VERSION < 3 + strings::StrAppend(out, PyString_AS_STRING(v), "\n"); +#else + strings::StrAppend(out, PyUnicode_AsUTF8(v), "\n"); +#endif + } + + // Iterate through ret_val. + Py_DECREF(ret_val); +} + string PyExceptionFetch() { CHECK(PyErr_Occurred()) << "Must only call PyExceptionFetch after an exception."; @@ -52,14 +101,20 @@ string PyExceptionFetch() { string err = ClassName(ptype); if (pvalue) { PyObject* str = PyObject_Str(pvalue); + if (str) { #if PY_MAJOR_VERSION < 3 - strings::StrAppend(&err, ": ", PyString_AS_STRING(str)); + strings::StrAppend(&err, ": ", PyString_AS_STRING(str), "\n"); #else - strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str)); + strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str), "\n"); #endif Py_DECREF(str); + } else { + strings::StrAppend(&err, "(unknown error message)\n"); } + + TryAppendTraceback(ptype, pvalue, ptraceback, &err); + Py_DECREF(pvalue); } Py_DECREF(ptype); diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py index a751607aaa1f47ca7c08674eca2b27ee0cafa3d2..223858edfa84eaa1c7879a9774dcc836de4f4672 100644 --- a/tensorflow/python/lib/io/file_io_test.py +++ b/tensorflow/python/lib/io/file_io_test.py @@ -485,6 +485,11 @@ class FileIoTest(test.TestCase): f.flush() self.assertEqual(content, f.read(len(content) + 1)) + def testUTF8StringPathExists(self): + file_path = os.path.join(self._base_dir, "UTF8测试_file_exist") + file_io.write_string_to_file(file_path, "testing") + v = file_io.file_exists(file_path) + self.assertEqual(v, True) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 48ea107a146c2714f7b59f53abbcd8b60dbf2fd4..6fcf9c91d831e3a89552b522040e8e8647114a2f 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -75,14 +75,16 @@ def tf_record_iterator(path, options=None): if reader is None: raise IOError("Could not open %s." % path) - while True: - try: - with errors.raise_exception_on_not_ok_status() as status: - reader.GetNext(status) - except errors.OutOfRangeError: - break - yield reader.record() - reader.Close() + try: + while True: + try: + with errors.raise_exception_on_not_ok_status() as status: + reader.GetNext(status) + except errors.OutOfRangeError: + break + yield reader.record() + finally: + reader.Close() @tf_export("python_io.TFRecordWriter") diff --git a/tensorflow/python/ops/accumulate_n_benchmark.py b/tensorflow/python/ops/accumulate_n_benchmark.py index c58d36f39705ecf0f24214ce4ba4574e70a93e77..a709066cae4da2811b3e98d2e93bf44ec12dcee6 100644 --- a/tensorflow/python/ops/accumulate_n_benchmark.py +++ b/tensorflow/python/ops/accumulate_n_benchmark.py @@ -39,7 +39,7 @@ from tensorflow.python.platform import test class AccumulateNBenchmark(test.Benchmark): def _AccumulateNTemplate(self, inputs, init, shape, validate_shape): - var = gen_state_ops._temporary_variable( + var = gen_state_ops.temporary_variable( shape=shape, dtype=inputs[0].dtype.base_dtype) ref = state_ops.assign(var, init, validate_shape=validate_shape) update_ops = [ @@ -47,8 +47,7 @@ class AccumulateNBenchmark(test.Benchmark): ref, tensor, use_locking=True).op for tensor in inputs ] with ops.control_dependencies(update_ops): - return gen_state_ops._destroy_temporary_variable( - ref, var_name=var.op.name) + return gen_state_ops.destroy_temporary_variable(ref, var_name=var.op.name) def _AccumulateNInitializedWithFirst(self, inputs): return self._AccumulateNTemplate( @@ -60,7 +59,7 @@ class AccumulateNBenchmark(test.Benchmark): def _AccumulateNInitializedWithMerge(self, inputs): return self._AccumulateNTemplate( inputs, - init=array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0]), + init=array_ops.zeros_like(gen_control_flow_ops.merge(inputs)[0]), shape=tensor_shape.vector(0), validate_shape=False) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 9745d38dc23dba806a2d0dd2ef588a5a950aa05c..3c6a5c9e562ff9765c2ef47555871c94cd6feb1e 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -80,7 +80,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): def _ExtractInputShapes(inputs): """Extract the shapes of a set of input tensors.""" - if not context.in_graph_mode(): + if context.executing_eagerly(): return array_ops.shape_n(inputs) sizes = [] fully_known = True @@ -106,7 +106,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): out_grads = [] if isinstance(grad, ops.Tensor): - if context.in_eager_mode(): + if context.executing_eagerly(): # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. non_neg_concat_dim = ( @@ -139,7 +139,6 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of # cases when switching implementations at N=16, but it is possible that # there will be a small number of performance regressions. - # pylint: disable=protected-access if len(sizes) > 16: # extract the size of each input along the concat dimension sizes = array_ops.squeeze( @@ -148,10 +147,9 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): [1, -1])) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) else: - offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes) + offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes) for (begin, size) in zip(offset, sizes): out_grads.append(array_ops.slice(grad, begin, size)) - # pylint: enable=protected-access elif isinstance(grad, ops.IndexedSlices): # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. @@ -430,7 +428,7 @@ def _GatherV2Grad(op, grad): # For axis 0 gathers, build an appropriately shaped IndexedSlices. if axis_static == 0: - if context.in_eager_mode(): + if context.executing_eagerly(): params_tail_shape = params_shape.cpu()[1:] else: params_tail_shape = params_shape[1:] @@ -580,7 +578,7 @@ def _TileGrad(op, grad): axes = math_ops.range(0, array_ops.size(split_shape), 2) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) # Fix shape inference - if context.in_graph_mode(): + if not context.executing_eagerly(): input_grad.set_shape(op.inputs[0].get_shape()) return [input_grad, None] @@ -627,9 +625,7 @@ def _ReverseSequenceGrad(op, grad): @ops.RegisterGradient("Reverse") def _ReverseGrad(op, grad): reverse_dims = op.inputs[1] - # pylint: disable=protected-access - return gen_array_ops._reverse(grad, reverse_dims), None - # pylint: enable=protected-access + return gen_array_ops.reverse(grad, reverse_dims), None @ops.RegisterGradient("ReverseV2") @@ -700,17 +696,13 @@ ops.NotDifferentiable("OneHot") @ops.RegisterGradient("MirrorPad") def _MirrorPadGrad(op, grad): mode = op.get_attr("mode") - # pylint: disable=protected-access - return [gen_array_ops._mirror_pad_grad(grad, op.inputs[1], mode=mode), None] - # pylint: enable=protected-access + return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None] @ops.RegisterGradient("MirrorPadGrad") def _MirrorPadGradGrad(op, grad): mode = op.get_attr("mode") - # pylint: disable=protected-access - return [gen_array_ops._mirror_pad(grad, op.inputs[1], mode=mode), None] - # pylint: enable=protected-access + return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None] @ops.RegisterGradient("QuantizeAndDequantize") diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 14824962eabaecdc2cca348b27ce7f9f8868af82..9106461c6001e3a843bb694e389693236fbd442f 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -128,15 +128,18 @@ def identity(input, name=None): # pylint: disable=redefined-builtin Returns: A `Tensor`. Has the same type as `input`. """ - if context.in_graph_mode(): - return gen_array_ops.identity(input, name=name) - else: + if context.executing_eagerly(): input = ops.convert_to_tensor(input) in_device = input.device # TODO(ashankar): Does 'identity' need to invoke execution callbacks? - if context.context().device_name != in_device: + context_device = context.context().device_name + if not context_device: + context_device = "/job:localhost/replica:0/task:0/device:CPU:0" + if context_device != in_device: return input._copy() # pylint: disable=protected-access return input + else: + return gen_array_ops.identity(input, name=name) # pylint: disable=redefined-builtin,protected-access @@ -195,7 +198,7 @@ def expand_dims(input, axis=None, name=None, dim=None): if axis is not None: raise ValueError("can't specify both 'dim' and 'axis'") axis = dim - return gen_array_ops._expand_dims(input, axis, name) + return gen_array_ops.expand_dims(input, axis, name) # pylint: enable=redefined-builtin,protected-access @@ -208,28 +211,25 @@ def expand_dims(input, axis=None, name=None, dim=None): "This op will be removed after the deprecation date. " "Please switch to tf.setdiff1d().") def listdiff(x, y, out_idx=None, name=None): - return gen_array_ops._list_diff(x, y, out_idx, name) + return gen_array_ops.list_diff(x, y, out_idx, name) -listdiff.__doc__ = gen_array_ops._list_diff.__doc__ + "\n" + listdiff.__doc__ +listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__ # pylint: enable=protected-access -# pylint: disable=undefined-variable,protected-access +# pylint: disable=undefined-variable @tf_export("setdiff1d") def setdiff1d(x, y, index_dtype=dtypes.int32, name=None): - return gen_array_ops._list_diff(x, y, index_dtype, name) - + return gen_array_ops.list_diff(x, y, index_dtype, name) -setdiff1d.__doc__ = gen_array_ops._list_diff.__doc__ -# pylint: enable=protected-access +setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__ @tf_export("broadcast_dynamic_shape") def broadcast_dynamic_shape(shape_x, shape_y): - # pylint: disable=protected-access """Returns the broadcasted dynamic shape between `shape_x` and `shape_y`. Args: @@ -239,8 +239,7 @@ def broadcast_dynamic_shape(shape_x, shape_y): Returns: A rank 1 integer `Tensor` representing the broadcasted shape. """ - return gen_array_ops._broadcast_args(shape_x, shape_y) - # pylint: enable=protected-access + return gen_array_ops.broadcast_args(shape_x, shape_y) @tf_export("broadcast_static_shape") @@ -306,7 +305,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32): sparse_tensor.SparseTensorValue)): return gen_math_ops.cast(input.dense_shape, out_type) else: - if context.in_graph_mode(): + if not context.executing_eagerly(): input_tensor = ops.convert_to_tensor(input) input_shape = input_tensor.get_shape() if optimize and input_shape.is_fully_defined(): @@ -331,7 +330,7 @@ def shape_n(input, out_type=dtypes.int32, name=None): """ output = gen_array_ops.shape_n(input, out_type=out_type, name=name) - if context.in_graph_mode(): + if not context.executing_eagerly(): for i, input_tensor in enumerate(input): input_tensor = ops.convert_to_tensor(input_tensor) input_shape = input_tensor.get_shape() @@ -386,23 +385,22 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): Returns: A `Tensor` of type `out_type`. Defaults to `tf.int32`. """ - if context.in_eager_mode() and not isinstance( - input, (sparse_tensor.SparseTensor, - sparse_tensor.SparseTensorValue)): - size_ = 1 - for dim in ops.convert_to_tensor(input)._shape_tuple(): # pylint: disable=protected-access - size_ *= dim - return size_ + if context.executing_eagerly() and not isinstance( + input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access with ops.name_scope(name, "Size", [input]) as name: if isinstance(input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): - return gen_math_ops._prod( + return gen_math_ops.prod( gen_math_ops.cast(input.dense_shape, out_type), 0, name=name) else: input_tensor = ops.convert_to_tensor(input) input_shape = input_tensor.get_shape() - if optimize and input_shape.is_fully_defined(): - return constant(input_shape.num_elements(), out_type, name=name) + if optimize: + if input_shape.is_fully_defined(): + return constant(input_shape.num_elements(), out_type, name=name) + if input_shape.dims and any(dim == 0 for dim in input_shape.dims): + return constant(0, out_type, name=name) return gen_array_ops.size(input, name=name, out_type=out_type) @@ -784,7 +782,7 @@ def strided_slice(input_, new_axis_mask=new_axis_mask, shrink_axis_mask=shrink_axis_mask) - if context.in_graph_mode(): + if not context.executing_eagerly(): # TODO(apassos) In eager mode assignment will be done by overriding # __setitem__ instead. op.assign = assign @@ -795,8 +793,8 @@ def _SliceHelperVar(var, slice_spec): """Creates a slice helper object given a variable. This allows creating a sub-tensor from part of the current contents - of a variable. See ${tf.Tensor$`Tensor.__getitem__`} - for detailed examples of slicing. + of a variable. See @{tf.Tensor.__getitem__} for detailed examples + of slicing. This function in addition also allows assignment to a sliced range. This is similar to `__setitem__` functionality in Python. However, @@ -886,7 +884,7 @@ def parallel_stack(values, name="parallel_stack"): output_shape = tensor_shape.TensorShape([len(values)]) output_shape = output_shape.concatenate(value_shape) # expand_dims converts concat to stack. - return gen_array_ops._parallel_concat( + return gen_array_ops.parallel_concat( [expand_dims(value, 0) for value in values], shape=output_shape) @@ -944,7 +942,7 @@ def stack(values, axis=0, name="stack"): raise ValueError("axis = %d not in [%d, %d)" % (axis, -expanded_num_dims, expanded_num_dims)) - return gen_array_ops._pack(values, axis=axis, name=name) + return gen_array_ops.pack(values, axis=axis, name=name) # pylint: disable=invalid-name @@ -988,7 +986,7 @@ def _autopacking_helper(list_or_tuple, dtype, name): # convertible-to-tensor types, such as numpy arrays. elems_as_tensors.append( constant_op.constant(elem, dtype=dtype, name=str(i))) - return gen_array_ops._pack(elems_as_tensors, name=scope) + return gen_array_ops.pack(elems_as_tensors, name=scope) else: return converted_elems @@ -1083,7 +1081,7 @@ def unstack(value, num=None, axis=0, name="unstack"): num = value_shape[axis].value if num is None: raise ValueError("Cannot infer num from shape %s" % value_shape) - return gen_array_ops._unpack(value, num=num, axis=axis, name=name) + return gen_array_ops.unpack(value, num=num, axis=axis, name=name) @tf_export("concat") @@ -1180,7 +1178,7 @@ def concat(values, axis, name="concat"): dtype=dtypes.int32).get_shape().assert_is_compatible_with( tensor_shape.scalar()) return identity(values[0], name=scope) - return gen_array_ops._concat_v2(values=values, axis=axis, name=name) + return gen_array_ops.concat_v2(values=values, axis=axis, name=name) @tf_export("boolean_mask") @@ -1248,8 +1246,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None): axis = 0 if axis is None else axis shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask) - leading_size = gen_math_ops._prod( - shape(tensor)[axis:axis + ndims_mask], [0]) + leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0]) tensor = reshape(tensor, concat([ shape(tensor)[:axis], [leading_size], @@ -1313,10 +1310,10 @@ def unique(x, out_idx=dtypes.int32, name=None): # period (3 weeks) pass. # TODO(yongtang): The documentation should also # be updated when switch to v2. - return gen_array_ops._unique(x, out_idx, name) + return gen_array_ops.unique(x, out_idx, name) -unique.__doc__ = gen_array_ops._unique.__doc__ +unique.__doc__ = gen_array_ops.unique.__doc__ @tf_export("unique_with_counts") @@ -1325,10 +1322,10 @@ def unique_with_counts(x, out_idx=dtypes.int32, name=None): # period (3 weeks) pass. # TODO(yongtang): The documentation should also # be updated when switch to v2. - return gen_array_ops._unique_with_counts(x, out_idx, name) + return gen_array_ops.unique_with_counts(x, out_idx, name) -unique_with_counts.__doc__ = gen_array_ops._unique_with_counts.__doc__ +unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__ @tf_export("split") @@ -1382,20 +1379,18 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"): """ size_splits = ops.convert_to_tensor(num_or_size_splits) if size_splits._rank() == 0 and size_splits.dtype.is_integer: - return gen_array_ops._split( + return gen_array_ops.split( axis=axis, num_split=num_or_size_splits, value=value, name=name) if num is None: - num = size_splits._shape_tuple()[0] + size_splits_shape = size_splits._shape_tuple() + if size_splits_shape: + num = size_splits_shape[0] if num is None: raise ValueError("Cannot infer num from shape %s" % num_or_size_splits) - return gen_array_ops._split_v( - value=value, - size_splits=size_splits, - axis=axis, - num_split=num, - name=name) + return gen_array_ops.split_v( + value=value, size_splits=size_splits, axis=axis, num_split=num, name=name) @tf_export("transpose") @@ -1465,7 +1460,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False): """ with ops.name_scope(name, "transpose", [a]) as name: transpose_fn = ( - gen_array_ops._conjugate_transpose + gen_array_ops.conjugate_transpose if (conjugate and a.dtype.is_complex) else gen_array_ops.transpose) if perm is None: rank = gen_array_ops.rank(a) @@ -1473,7 +1468,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False): ret = transpose_fn(a, perm, name=name) # NOTE(mrry): Setting the shape explicitly because # reverse is not handled by the shape function. - if context.in_graph_mode(): + if not context.executing_eagerly(): input_shape = ret.op.inputs[0].get_shape().dims if input_shape is not None: ret.set_shape(input_shape[::-1]) @@ -1638,12 +1633,12 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): with ops.name_scope(name, "zeros_like", [tensor]) as name: tensor = ops.convert_to_tensor(tensor, name="tensor") - if context.in_eager_mode(): + if context.executing_eagerly(): if dtype is not None and dtype != tensor.dtype: return zeros( shape_internal(tensor, optimize=optimize), dtype=dtype, name=name) with ops.device(tensor.device): - return gen_array_ops._zeros_like(tensor, name=name) + return gen_array_ops.zeros_like(tensor, name=name) # For now, variant types must be created via zeros_like; as we need to # pass the input variant object to the proper zeros callback. @@ -1658,7 +1653,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): return zeros( shape_internal(tensor, optimize=optimize), dtype=dtype, name=name) else: - return gen_array_ops._zeros_like(tensor, name=name) + return gen_array_ops.zeros_like(tensor, name=name) @tf_export("ones_like") @@ -1694,7 +1689,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): if dtype is None: dtype = tensor.dtype ret = ones(ones_shape, dtype=dtype, name=name) - if context.in_graph_mode(): + if not context.executing_eagerly(): ret.set_shape(tensor.get_shape()) return ret @@ -1775,11 +1770,11 @@ def placeholder(dtype, shape=None, name=None): Raises: RuntimeError: if eager execution is enabled """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("tf.placeholder() is not compatible with " "eager execution.") - return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name) + return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name) # pylint: disable=redefined-outer-name @@ -1838,7 +1833,7 @@ def sparse_placeholder(dtype, shape=None, name=None): Raises: RuntimeError: if eager execution is enabled """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("tf.placeholder() is not compatible with " "eager execution.") @@ -1923,21 +1918,21 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl # TODO(rjryan): Once the forward compatibility period (3 weeks) have passed # remove the "Pad" fallback here. if constant_values != 0: - result = gen_array_ops._pad_v2( + result = gen_array_ops.pad_v2( tensor, paddings, constant_values, name=name) else: - result = gen_array_ops._pad(tensor, paddings, name=name) + result = gen_array_ops.pad(tensor, paddings, name=name) elif mode == "REFLECT": - result = gen_array_ops._mirror_pad( + result = gen_array_ops.mirror_pad( tensor, paddings, mode="REFLECT", name=name) elif mode == "SYMMETRIC": - result = gen_array_ops._mirror_pad( + result = gen_array_ops.mirror_pad( tensor, paddings, mode="SYMMETRIC", name=name) else: raise ValueError("Unknown padding mode: %s" % mode) # Restore shape information where possible. - if context.in_graph_mode(): + if not context.executing_eagerly(): paddings_constant = tensor_util.constant_value( result.op.inputs[1], partial=True) input_shape = result.op.inputs[0].shape @@ -2161,7 +2156,7 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): sparse_tensor.SparseTensorValue)): raise TypeError("Truth must be a SparseTensor.") - return gen_array_ops._edit_distance( + return gen_array_ops.edit_distance( hypothesis.indices, hypothesis.values, hypothesis.dense_shape, @@ -2298,7 +2293,7 @@ def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=r return result -space_to_batch.__doc__ = gen_array_ops._space_to_batch.__doc__ +space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__ @tf_export("space_to_depth") @@ -2328,7 +2323,7 @@ def batch_to_space(input, crops, block_size, name=None): # pylint: disable=rede return result -batch_to_space.__doc__ = gen_array_ops._batch_to_space.__doc__ +batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__ @tf_export("one_hot") @@ -2472,8 +2467,8 @@ def one_hot(indices, raise TypeError("dtype {0} of on_value does not match " "dtype {1} of off_value".format(on_dtype, off_dtype)) - return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis, - name) + return gen_array_ops.one_hot(indices, depth, on_value, off_value, axis, + name) def _all_dimensions(x): @@ -2601,7 +2596,7 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None): axis = squeeze_dims if np.isscalar(axis): axis = [axis] - return gen_array_ops._squeeze(input, axis, name) + return gen_array_ops.squeeze(input, axis, name) @tf_export("where") @@ -2652,7 +2647,7 @@ def where(condition, x=None, y=None, name=None): condition, preferred_dtype=dtypes.bool, name="condition") return gen_array_ops.where(condition=condition, name=name) elif x is not None and y is not None: - return gen_math_ops._select(condition=condition, x=x, y=y, name=name) + return gen_math_ops.select(condition=condition, x=x, y=y, name=name) else: raise ValueError("x and y must both be non-None or both be None.") @@ -2696,12 +2691,17 @@ reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring( @tf_export("gather") def gather(params, indices, validate_indices=None, name=None, axis=0): - # TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward - # compatibility 3 week period has passed. - if axis == 0: - return gen_array_ops.gather( - params, indices, validate_indices=validate_indices, name=name) - else: + del validate_indices + if axis != 0: + # Note that we do a sparse_read here to avoid snapshotting the entire + # resource variable and doing a gather, which can be inefficient and lead to + # subtle race conditions. TODO(apassos) implement axis != 0 on sparse_read + return gen_array_ops.gather_v2(params, indices, axis, name=name) + try: + # TODO(apassos) find a less bad way of detecting resource variables without + # introducing a circular dependency. + return params.sparse_read(indices, name=name) + except AttributeError: return gen_array_ops.gather_v2(params, indices, axis, name=name) diff --git a/tensorflow/python/ops/batch_norm_benchmark.py b/tensorflow/python/ops/batch_norm_benchmark.py index c2ee2b383231333239c6e2d4e874a0ad1cdf493e..5d68b47aeaef3a90973387ecd5b265eef1e96a5f 100644 --- a/tensorflow/python/ops/batch_norm_benchmark.py +++ b/tensorflow/python/ops/batch_norm_benchmark.py @@ -41,9 +41,8 @@ def batch_norm_op(tensor, mean, variance, beta, gamma, scale): # _batch_norm_with_global_normalization is deprecated in v9 ops.get_default_graph().graph_def_versions.producer = 8 # pylint: disable=protected-access - return gen_nn_ops._batch_norm_with_global_normalization(tensor, mean, - variance, beta, gamma, - 0.001, scale) + return gen_nn_ops._batch_norm_with_global_normalization( + tensor, mean, variance, beta, gamma, 0.001, scale) # pylint: enable=protected-access diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index 220ef1754d2e1a2d54a8962148b47806df48e98f..9ea1ea9c92c9b016a3f9126c89ee4dc1e73c9f27 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -77,7 +77,7 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, of each of `sampled_candidates`. """ seed1, seed2 = random_seed.get_seed(seed) - return gen_candidate_sampling_ops._uniform_candidate_sampler( + return gen_candidate_sampling_ops.uniform_candidate_sampler( true_classes, num_true, num_sampled, unique, range_max, seed=seed1, seed2=seed2, name=name) @@ -136,7 +136,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, of each of `sampled_candidates`. """ seed1, seed2 = random_seed.get_seed(seed) - return gen_candidate_sampling_ops._log_uniform_candidate_sampler( + return gen_candidate_sampling_ops.log_uniform_candidate_sampler( true_classes, num_true, num_sampled, unique, range_max, seed=seed1, seed2=seed2, name=name) @@ -193,7 +193,7 @@ def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, """ seed1, seed2 = random_seed.get_seed(seed) - return gen_candidate_sampling_ops._learned_unigram_candidate_sampler( + return gen_candidate_sampling_ops.learned_unigram_candidate_sampler( true_classes, num_true, num_sampled, unique, range_max, seed=seed1, seed2=seed2, name=name) @@ -283,7 +283,7 @@ def fixed_unigram_candidate_sampler(true_classes, """ seed1, seed2 = random_seed.get_seed(seed) - return gen_candidate_sampling_ops._fixed_unigram_candidate_sampler( + return gen_candidate_sampling_ops.fixed_unigram_candidate_sampler( true_classes, num_true, num_sampled, unique, range_max, vocab_file=vocab_file, distortion=distortion, num_reserved_ids=num_reserved_ids, num_shards=num_shards, shard=shard, @@ -321,7 +321,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique, of each of `sampled_candidates`. All returned values are 1.0. """ seed1, seed2 = random_seed.get_seed(seed) - return gen_candidate_sampling_ops._all_candidate_sampler( + return gen_candidate_sampling_ops.all_candidate_sampler( true_classes, num_true, num_sampled, unique, seed=seed1, seed2=seed2, name=name) @@ -370,6 +370,6 @@ def compute_accidental_hits(true_classes, sampled_candidates, num_true, """ seed1, seed2 = random_seed.get_seed(seed) - return gen_candidate_sampling_ops._compute_accidental_hits( + return gen_candidate_sampling_ops.compute_accidental_hits( true_classes, sampled_candidates, num_true, seed=seed1, seed2=seed2, name=name) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 64567ac54ae43acf6f8b674c46525db7a6c4fab7..9cea3e91f7760034d2ab7649709e62dbf1987701 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -169,7 +169,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_negative', [x, data]): x = ops.convert_to_tensor(x, name='x') if data is None: - if context.in_eager_mode(): + if context.executing_eagerly(): name = _shape_and_dtype_str(x) else: name = x.name @@ -210,7 +210,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_positive', [x, data]): x = ops.convert_to_tensor(x, name='x') if data is None: - if context.in_eager_mode(): + if context.executing_eagerly(): name = _shape_and_dtype_str(x) else: name = x.name @@ -251,7 +251,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_non_negative', [x, data]): x = ops.convert_to_tensor(x, name='x') if data is None: - if context.in_eager_mode(): + if context.executing_eagerly(): name = _shape_and_dtype_str(x) else: name = x.name @@ -293,7 +293,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_non_positive', [x, data]): x = ops.convert_to_tensor(x, name='x') if data is None: - if context.in_eager_mode(): + if context.executing_eagerly(): name = _shape_and_dtype_str(x) else: name = x.name @@ -343,7 +343,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): + if context.executing_eagerly(): eq = math_ops.equal(x, y) condition = math_ops.reduce_all(eq) if not condition: @@ -363,27 +363,30 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): (x_sum, x_np[:x_sum], y_sum, y_np[:y_sum])) - # Get the values that actually differed and their indices. - mask = math_ops.logical_not(eq) - indices = array_ops.where(mask) - indices_np = indices.numpy() - x_vals = array_ops.boolean_mask(x, mask) - y_vals = array_ops.boolean_mask(y, mask) - summarize = min(summarize, indices_np.shape[0]) + index_and_values_str = '' + if x.shape == y.shape: + # If the shapes of x and y are the same, + # Get the values that actually differed and their indices. + # If shapes are different this information is more confusing + # than useful. + mask = math_ops.logical_not(eq) + indices = array_ops.where(mask) + indices_np = indices.numpy() + x_vals = array_ops.boolean_mask(x, mask) + y_vals = array_ops.boolean_mask(y, mask) + summarize = min(summarize, indices_np.shape[0]) + index_and_values_str = ( + 'Indices of first %s different values:\n%s\n' + 'Corresponding x values:\n%s\n' + 'Corresponding y values:\n%s\n' % + (summarize, indices_np[:summarize], + x_vals.numpy().reshape((-1,))[:summarize], + y_vals.numpy().reshape((-1,))[:summarize])) raise errors.InvalidArgumentError( node_def=None, op=None, - message=('%s\nCondition x == y did not hold.\n' - 'Indices of first %s different values:\n%s\n' - 'Corresponding x values:\n%s\n' - 'Corresponding y values:\n%s\n' - '%s' - % - (message or '', - summarize, indices_np[:summarize], - x_vals.numpy().reshape((-1,))[:summarize], - y_vals.numpy().reshape((-1,))[:summarize], - summary_msg))) + message=('%s\nCondition x == y did not hold.\n%s%s' % + (message or '', index_and_values_str, summary_msg))) return if data is None: @@ -435,7 +438,7 @@ def assert_none_equal( with ops.name_scope(name, 'assert_none_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): + if context.executing_eagerly(): x_name = _shape_and_dtype_str(x) y_name = _shape_and_dtype_str(y) else: @@ -512,7 +515,7 @@ def assert_near( rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) - if context.in_eager_mode(): + if context.executing_eagerly(): x_name = _shape_and_dtype_str(x) y_name = _shape_and_dtype_str(y) else: @@ -562,7 +565,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_less', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): + if context.executing_eagerly(): x_name = _shape_and_dtype_str(x) y_name = _shape_and_dtype_str(y) else: @@ -610,7 +613,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_less_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): + if context.executing_eagerly(): x_name = _shape_and_dtype_str(x) y_name = _shape_and_dtype_str(y) else: @@ -658,7 +661,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_greater', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): + if context.executing_eagerly(): x_name = _shape_and_dtype_str(x) y_name = _shape_and_dtype_str(y) else: @@ -708,7 +711,7 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None, with ops.name_scope(name, 'assert_greater_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): + if context.executing_eagerly(): x_name = _shape_and_dtype_str(x) y_name = _shape_and_dtype_str(y) else: @@ -808,7 +811,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): static_condition = lambda actual_rank, given_rank: actual_rank == given_rank dynamic_condition = math_ops.equal - if context.in_eager_mode(): + if context.executing_eagerly(): name = '' else: name = x.name @@ -873,7 +876,7 @@ def assert_rank_at_least( static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank dynamic_condition = math_ops.greater_equal - if context.in_eager_mode(): + if context.executing_eagerly(): name = '' else: name = x.name @@ -1001,7 +1004,7 @@ def assert_rank_in( ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) message = message or '' - if context.in_eager_mode(): + if context.executing_eagerly(): name = '' else: name = x.name @@ -1054,7 +1057,7 @@ def assert_integer(x, message=None, name=None): with ops.name_scope(name, 'assert_integer', [x]): x = ops.convert_to_tensor(x, name='x') if not x.dtype.is_integer: - if context.in_eager_mode(): + if context.executing_eagerly(): name = 'tensor' else: name = x.name @@ -1087,12 +1090,11 @@ def assert_type(tensor, tf_type, message=None, name=None): with ops.name_scope(name, 'assert_type', [tensor]): tensor = ops.convert_to_tensor(tensor, name='tensor') if tensor.dtype != tf_type: - if context.in_graph_mode(): - raise TypeError( - '%s %s must be of type %s' % (message, tensor.name, tf_type)) + if context.executing_eagerly(): + raise TypeError('%s tensor must be of type %s' % (message, tf_type)) else: - raise TypeError( - '%s tensor must be of type %s' % (message, tf_type)) + raise TypeError('%s %s must be of type %s' % (message, tensor.name, + tf_type)) return control_flow_ops.no_op('statically_determined_correct_type') @@ -1240,7 +1242,7 @@ def assert_scalar(tensor, name=None): tensor = ops.convert_to_tensor(tensor, name=name_scope) shape = tensor.get_shape() if shape.ndims != 0: - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError('Expected scalar shape, saw shape: %s.' % (shape,)) else: diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py index 97b57177b29986a006df992f4c0c2b79e11467aa..45955554cab130597e106660ff1fb4cdf7e9aeb1 100644 --- a/tensorflow/python/ops/control_flow_grad.py +++ b/tensorflow/python/ops/control_flow_grad.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import,undefined-variable from tensorflow.python.ops.control_flow_ops import * -from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import @@ -143,6 +142,7 @@ def _ExitGrad(op, grad): """Gradients for an exit op are calculated using an Enter op.""" graph = ops.get_default_graph() # pylint: disable=protected-access + op_ctxt = op._get_control_flow_context() grad_ctxt = graph._get_control_flow_context() # pylint: enable=protected-access if not grad_ctxt.back_prop: @@ -151,10 +151,8 @@ def _ExitGrad(op, grad): # no gradient computation. return None - # pylint: disable=protected-access - if op._get_control_flow_context().grad_state: + if op_ctxt.grad_state: raise TypeError("Second-order gradient for while loops not supported.") - # pylint: enable=protected-access if isinstance(grad, ops.Tensor): grad_ctxt.AddName(grad.name) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 8218e60b53450a500df71719e533fb0c1cbeb5b5..1278768d8bdc9f039f19cf032f8ee09442ea34a9 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -23,7 +23,6 @@ See the @{$python/control_flow_ops} guide. @@no_op @@count_up_to @@cond -@@smart_cond @@case @@while_loop @@logical_and @@ -153,7 +152,7 @@ def Assert(condition, data, summarize=None, name=None): @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition` is not true """ - if context.in_eager_mode(): + if context.executing_eagerly(): if not condition: xs = ops.convert_n_to_tensor(data) data_str = [_summarize_eager(x, summarize) for x in xs] @@ -179,6 +178,8 @@ def Assert(condition, data, summarize=None, name=None): condition, data, summarize, name="Assert") guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard") + if context.executing_eagerly(): + return return guarded_assert.op @@ -195,7 +196,7 @@ def _Identity(data, name=None): data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): if data.dtype._is_ref_dtype: # pylint: disable=protected-access - return gen_array_ops._ref_identity(data, name=name) + return gen_array_ops.ref_identity(data, name=name) else: return array_ops.identity(data, name=name) else: @@ -263,10 +264,10 @@ def _Enter(data, data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access - result = gen_control_flow_ops._ref_enter( + result = gen_control_flow_ops.ref_enter( data, frame_name, is_constant, parallel_iterations, name=name) else: - result = gen_control_flow_ops._enter( + result = gen_control_flow_ops.enter( data, frame_name, is_constant, parallel_iterations, name=name) if use_input_shape: result.set_shape(data.get_shape()) @@ -281,7 +282,7 @@ def _Enter(data, parallel_iterations=parallel_iterations, use_input_shape=use_input_shape, name=name) - indices = gen_control_flow_ops._enter( + indices = gen_control_flow_ops.enter( data.indices, frame_name, is_constant, @@ -292,7 +293,7 @@ def _Enter(data, if isinstance(data, ops.IndexedSlices): dense_shape = data.dense_shape if dense_shape is not None: - dense_shape = gen_control_flow_ops._enter( + dense_shape = gen_control_flow_ops.enter( dense_shape, frame_name, is_constant, @@ -302,7 +303,7 @@ def _Enter(data, dense_shape.set_shape(data.dense_shape.get_shape()) return ops.IndexedSlices(values, indices, dense_shape) else: - dense_shape = gen_control_flow_ops._enter( + dense_shape = gen_control_flow_ops.enter( data.dense_shape, frame_name, is_constant, @@ -328,7 +329,7 @@ def exit(data, name=None): # pylint: disable=redefined-builtin data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): if data.dtype._is_ref_dtype: # pylint: disable=protected-access - return gen_control_flow_ops._ref_exit(data, name) + return gen_control_flow_ops.ref_exit(data, name) else: return gen_control_flow_ops._exit(data, name) else: @@ -370,17 +371,17 @@ def switch(data, pred, dtype=None, name=None): data, dtype=dtype, name="data", as_ref=True) pred = ops.convert_to_tensor(pred, name="pred") if isinstance(data, ops.Tensor): - return gen_control_flow_ops._switch(data, pred, name=name) + return gen_control_flow_ops.switch(data, pred, name=name) else: if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)): raise TypeError("Type %s not supported" % type(data)) val, ind = data.values, data.indices - val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name) - ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices") + val_f, val_t = gen_control_flow_ops.switch(val, pred, name=name) + ind_f, ind_t = gen_control_flow_ops.switch(ind, pred, name="indices") if isinstance(data, ops.IndexedSlices): dense_shape = data.dense_shape if dense_shape is not None: - dense_shape_f, dense_shape_t = gen_control_flow_ops._switch( + dense_shape_f, dense_shape_t = gen_control_flow_ops.switch( dense_shape, pred, name="dense_shape") else: dense_shape_f, dense_shape_t = None, None @@ -388,7 +389,7 @@ def switch(data, pred, dtype=None, name=None): ops.IndexedSlices(val_t, ind_t, dense_shape_t)) else: dense_shape = data.dense_shape - dense_shape_f, dense_shape_t = gen_control_flow_ops._switch( + dense_shape_f, dense_shape_t = gen_control_flow_ops.switch( data.dense_shape, pred, name="dense_shape") return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f), sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t)) @@ -472,15 +473,15 @@ def merge(inputs, name=None): ] if all([isinstance(v, ops.Tensor) for v in inputs]): if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access - return gen_control_flow_ops._ref_merge(inputs, name) + return gen_control_flow_ops.ref_merge(inputs, name) else: - return gen_control_flow_ops._merge(inputs, name) + return gen_control_flow_ops.merge(inputs, name) elif all([isinstance(v, sparse_tensor.SparseTensor) for v in inputs]): # Only handle the case when all inputs are SparseTensor. values, _ = merge([inp.values for inp in inputs], name=name) - indices, chosen_index = gen_control_flow_ops._merge( + indices, chosen_index = gen_control_flow_ops.merge( [inp.indices for inp in inputs], name="indices") - dense_shape, _ = gen_control_flow_ops._merge( + dense_shape, _ = gen_control_flow_ops.merge( [inp.dense_shape for inp in inputs], name="dense_shape") return (sparse_tensor.SparseTensor(indices, values, dense_shape), chosen_index) @@ -488,13 +489,13 @@ def merge(inputs, name=None): # For now convert all the inputs as IndexedSlices. inputs = math_ops._as_indexed_slices_list(inputs, optimize=False) values, _ = merge([inp.values for inp in inputs], name=name) - indices, chosen_index = gen_control_flow_ops._merge( + indices, chosen_index = gen_control_flow_ops.merge( [inp.indices for inp in inputs], name="indices") if any(inp.dense_shape is not None for inp in inputs): if any(inp.dense_shape is None for inp in inputs): raise ValueError("Either all merged IndexedSlices must have a " "dense_shape, or none must have a dense_shape.") - dense_shape, _ = gen_control_flow_ops._merge( + dense_shape, _ = gen_control_flow_ops.merge( [inp.dense_shape for inp in inputs], name="dense_shape") else: dense_shape = None @@ -1014,10 +1015,8 @@ class GradLoopState(object): else: max_size = GetMaxSizeFromNestedMaximumIterations( value, self.forward_context) - # pylint: disable=protected-access - acc = gen_data_flow_ops._stack_v2( + acc = gen_data_flow_ops.stack_v2( max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") - # pylint: enable=protected-access if curr_ctxt: curr_ctxt.Exit() @@ -1030,10 +1029,8 @@ class GradLoopState(object): if value_ctxt == self.forward_context: # value is not nested in the forward context. self.forward_context.Enter() - # pylint: disable=protected-access - push = gen_data_flow_ops._stack_push_v2( + push = gen_data_flow_ops.stack_push_v2( enter_acc, value, swap_memory=swap_enabled) - # pylint: enable=protected-access self.forward_context.Exit() # Protect stack push and order it before forward_index. self.forward_index.op._add_control_input(push.op) @@ -1045,18 +1042,14 @@ class GradLoopState(object): # The special case for creating a zero tensor for a dead # branch of a switch. See ControlFlowState.ZerosLike(). value_ctxt.outer_context.Enter() - # pylint: disable=protected-access - push = gen_data_flow_ops._stack_push_v2( + push = gen_data_flow_ops.stack_push_v2( enter_acc, value, swap_memory=swap_enabled) - # pylint: enable=protected-access value_ctxt.outer_context.Exit() push.op._set_control_flow_context(value_ctxt) else: value_ctxt.Enter() - # pylint: disable=protected-access - push = gen_data_flow_ops._stack_push_v2( + push = gen_data_flow_ops.stack_push_v2( enter_acc, value, swap_memory=swap_enabled) - # pylint: enable=protected-access value_ctxt.Exit() # Protect stack push and order it before forward_sync. self.forward_sync._add_control_input(push.op) @@ -1103,10 +1096,8 @@ class GradLoopState(object): pred = cond_ctxt.pred branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch history_value = _SwitchRefOrTensor(history_value, pred)[branch] - # pylint: disable=protected-access - pop = gen_data_flow_ops._stack_pop_v2(history_value, - value.dtype.base_dtype) - # pylint: enable=protected-access + pop = gen_data_flow_ops.stack_pop_v2(history_value, + value.dtype.base_dtype) pop.set_shape(value.get_shape()) self.grad_context.Exit() parallel_iterations = self.grad_context.parallel_iterations @@ -1476,7 +1467,10 @@ def ZerosLikeOutsideLoop(op, index): branch = op_ctxt.branch switch_val = switch(op.inputs[0], pred)[1 - branch] zeros_shape = array_ops.shape_internal(switch_val, optimize=False) - return array_ops.zeros(zeros_shape, dtype=val.dtype) + # Ensure ops created within array_ops.zeros are dominated by switch in + # cond context. + with ops.control_dependencies([switch_val]): + return array_ops.zeros(zeros_shape, dtype=val.dtype) else: return array_ops.zeros_like(val, optimize=False) @@ -1508,9 +1502,11 @@ class ControlFlowContext(object): if values_def: self._init_values_from_proto(values_def, import_scope=import_scope) else: - # Values that have been already seen in this context. + # The names of tensors that have been already seen in this context. self._values = set() - # Values referenced by but external to this context. + # The keys are the names of tensors referenced by but external to this + # context. Each value is the Tensor that should be used by this context to + # access the key value (e.g. a switch output guarding a cond input value). self._external_values = {} def _init_values_from_proto(self, values_def, import_scope=None): @@ -1697,9 +1693,12 @@ class CondContext(ControlFlowContext): self._pivot = pivot # The predicate tensor in this branch self._branch = branch # 0 or 1 representing this branch - # Values considered to have been already seen in this context. + # Values considered to have been already seen in this context. They are + # not included in this context. self._values.add(pred.name) + self._external_values[pred.name] = pred self._values.add(pivot.name) + self._external_values[pivot.name] = pivot def _init_from_proto(self, context_def, import_scope=None): """Creates a new `CondContext` from protocol buffer. @@ -1717,8 +1716,8 @@ class CondContext(ControlFlowContext): self._pivot = g.as_graph_element( ops.prepend_name_scope(context_def.pivot_name, import_scope)) self._branch = context_def.branch - super(CondContext, self).__init__( - values_def=context_def.values_def, import_scope=import_scope) + super(CondContext, self).__init__(values_def=context_def.values_def, + import_scope=import_scope) @property def pred(self): @@ -1766,13 +1765,9 @@ class CondContext(ControlFlowContext): context_def.branch = self._branch context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def( export_scope)) - # TODO(b/72868227): enable this once the corresponding control_flow.proto - # changes have been checked in (they aren't checked in and this is - # disabled for now to ensure forwards compatibility). - if False: # pylint: disable=using-constant-test - for nested in self._nested_contexts: - nested_def = context_def.nested_contexts.add() - nested.to_control_flow_context_def(nested_def) + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -1784,14 +1779,10 @@ class CondContext(ControlFlowContext): ret = CondContext(context_def=context_def, import_scope=import_scope) - # TODO(b/72868227): remove "if hasattr(...)" once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is here for now to ensure forwards compatibility). - if hasattr(context_def, "nested_contexts"): - ret.Enter() - for nested_def in context_def.nested_contexts: - from_control_flow_context_def(nested_def) - ret.Exit() + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def, import_scope=import_scope) + ret.Exit() return ret def to_control_flow_context_def(self, context_def, export_scope=None): @@ -1810,6 +1801,7 @@ class CondContext(ControlFlowContext): if self._outer_context: result = self._outer_context.AddValue(val) self._values.add(result.name) + self._external_values[result.name] = result with ops.control_dependencies(None): result = _SwitchRefOrTensor(result, self._pred)[self._branch] if self._outer_context: @@ -1874,6 +1866,7 @@ class CondContext(ControlFlowContext): if self._outer_context: real_val = self._outer_context.AddValue(val) self._values.add(real_val.name) + self._external_values[real_val.name] = real_val real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch] self._external_values[val.name] = real_val else: @@ -2035,7 +2028,7 @@ def cond(pred, raise TypeError("false_fn must be callable.") with ops.name_scope(name, "cond", [pred]): - if context.in_eager_mode(): + if context.executing_eagerly(): if pred: return _UnpackIfSingleton(true_fn()) return _UnpackIfSingleton(false_fn()) @@ -2109,10 +2102,7 @@ def cond(pred, # Only add non-nested conds to the collection. Any nested control flow will # be encapsulated in the root context. assert context_t.outer_context == context_f.outer_context - # TODO(b/72868227): remove "if True..." once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if True or context_t.outer_context is None: + if context_t.outer_context is None: ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) @@ -2128,61 +2118,6 @@ def cond(pred, # pylint: enable=redefined-outer-name -def smart_cond(pred, true_fn=None, false_fn=None, name=None): - """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. - - If `pred` is a bool or has a constant value, we return either `true_fn()` - or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. - - Arguments: - pred: A scalar determining whether to return the result of `true_fn` or - `false_fn`. - true_fn: The callable to be performed if pred is true. - false_fn: The callable to be performed if pred is false. - name: Optional name prefix when using `tf.cond`. - - Returns: - Tensors returned by the call to either `true_fn` or `false_fn`. - - Raises: - TypeError: If `true_fn` or `false_fn` is not callable. - """ - if not callable(true_fn): - raise TypeError("`true_fn` must be callable.") - if not callable(false_fn): - raise TypeError("`false_fn` must be callable.") - - pred_value = smart_constant_value(pred) - if pred_value is not None: - if pred_value: - return true_fn() - else: - return false_fn() - else: - return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) - - -def smart_constant_value(pred): - """Return the bool value for `pred`, or None if `pred` had a dynamic value. - - Arguments: - pred: A scalar, either a Python bool or tensor. - - Returns: - True or False if `pred` has a constant boolean value, None otherwise. - - Raises: - TypeError: If `pred` is not a Tensor or bool. - """ - if isinstance(pred, bool): - pred_value = pred - elif isinstance(pred, ops.Tensor): - pred_value = tensor_util.constant_value(pred) - else: - raise TypeError("`pred` must be a Tensor or a Python bool.") - return pred_value - - def _resource_safe_shape(t): """Returns the shape of t or the variable it points to.""" if t.dtype == dtypes.resource: @@ -2390,13 +2325,9 @@ class WhileContext(ControlFlowContext): context_def.values_def.MergeFrom( super(WhileContext, self)._to_values_def( export_scope=export_scope)) - # TODO(b/72868227): remove "if True..." once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if False: # pylint: disable=using-constant-test - for nested in self._nested_contexts: - nested_def = context_def.nested_contexts.add() - nested.to_control_flow_context_def(nested_def) + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -2418,14 +2349,10 @@ class WhileContext(ControlFlowContext): """ ret = WhileContext(context_def=context_def, import_scope=import_scope) - # TODO(b/72868227): remove "if hasattr(...)" once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if hasattr(context_def, "nested_contexts"): - ret.Enter() - for nested_def in context_def.nested_contexts: - from_control_flow_context_def(nested_def, import_scope=import_scope) - ret.Exit() + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def, import_scope=import_scope) + ret.Exit() return ret def GetWhileContext(self): @@ -3009,8 +2936,11 @@ class WhileContext(ControlFlowContext): loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) try: self.Enter() - original_body_result, exit_vars = self._BuildLoop( - pred, body, original_loop_vars, loop_vars, shape_invariants) + # _BuildLoop calls _update_input in several places. _lock ensures a + # Session.run call cannot occur between creating and mutating new ops. + with ops.get_default_graph()._lock: # pylint: disable=protected-access + original_body_result, exit_vars = self._BuildLoop( + pred, body, original_loop_vars, loop_vars, shape_invariants) finally: self.Exit() @@ -3250,7 +3180,7 @@ def while_loop(cond, math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) body = lambda i, lv: (i + 1, orig_body(*lv)) - if context.in_eager_mode(): + if context.executing_eagerly(): while cond(*loop_vars): loop_vars = body(*loop_vars) if maximum_iterations is not None: @@ -3270,10 +3200,7 @@ def while_loop(cond, swap_memory=swap_memory) # Only add non-nested loops to the collection. Any nested control flow will # be encapsulated in the root context. - # TODO(b/72868227): enable condition once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if True or loop_context.outer_context is None: + if loop_context.outer_context is None: ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) if maximum_iterations is not None: @@ -3347,7 +3274,7 @@ def with_dependencies(dependencies, output_tensor, name=None): Raises: TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return output_tensor with ops.name_scope(name, "control_dependency", list(dependencies) + [output_tensor]) as name: @@ -3392,7 +3319,7 @@ def group(*inputs, **kwargs): Raises: ValueError: If an unknown keyword argument is provided. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return None name = kwargs.pop("name", None) if kwargs: @@ -3472,7 +3399,7 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined objects. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return tensors with ops.name_scope(name, "tuple", tensors) as name: tensors = [t if (isinstance(t, ops.Operation) @@ -3560,15 +3487,17 @@ def _case_create_default_action(predicates, actions): return default_action, other_predicates, other_actions -def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name): +def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, + allow_python_preds): """Verifies input arguments for the case function. Args: - pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a - callable which returns a list of tensors. + pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, + and a callable which returns a list of tensors. exclusive: True iff at most one predicate is allowed to evaluate to `True`. name: A name for the case operation. - + allow_python_preds: if true, pred_fn_pairs may contain Python bools in + addition to boolean Tensors Raises: TypeError: If `pred_fn_pairs` is not a list/dictionary. TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. @@ -3593,14 +3522,69 @@ def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name): if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") pred, fn = pred_fn_pair - if pred.dtype != dtypes.bool: - raise TypeError("pred must be of type bool: %s", pred.name) + + if isinstance(pred, ops.Tensor): + if pred.dtype != dtypes.bool: + raise TypeError("pred must be Tensor of type bool: %s" % pred.name) + elif not allow_python_preds: + raise TypeError("pred must be a Tensor, got: %s" % pred) + elif not isinstance(pred, bool): + raise TypeError("pred must be a Tensor or bool, got: %s" % pred) + if not callable(fn): raise TypeError("fn for pred %s must be callable." % pred.name) + predicates, actions = zip(*pred_fn_pairs) return predicates, actions +def _case_helper(cond_fn, pred_fn_pairs, default, + exclusive, name, allow_python_preds=False, **cond_kwargs): + """Implementation of case that allows for different cond functions. + + Args: + cond_fn: method that has signature and semantics of `cond` above. + pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a + callable which returns a list of tensors. + default: Optional callable that returns a list of tensors. + exclusive: True iff at most one predicate is allowed to evaluate to `True`. + name: A name for this operation (optional). + allow_python_preds: if true, pred_fn_pairs may contain Python bools in + addition to boolean Tensors + **cond_kwargs: keyword arguments that will be passed to `cond_fn`. + + Returns: + The tensors returned by the first pair whose predicate evaluated to True, or + those returned by `default` if none does. + + Raises: + TypeError: If `pred_fn_pairs` is not a list/dictionary. + TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. + TypeError: If `fns[i]` is not callable for any i, or `default` is not + callable. + """ + predicates, actions = _case_verify_and_canonicalize_args( + pred_fn_pairs, exclusive, name, allow_python_preds) + with ops.name_scope(name, "case", [predicates]): + if default is None: + default, predicates, actions = _case_create_default_action( + predicates, actions) + fn = default + # To eval conditions in direct order we create nested conditions in reverse: + # cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...)) + for predicate, action in reversed(list(zip(predicates, actions))): + fn = functools.partial( + cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs) + if exclusive: + with ops.control_dependencies([ + _assert_at_most_n_true( + predicates, n=1, msg="Input error: exclusive=True") + ]): + return fn() + else: + return fn() + + @tf_export("case") def case(pred_fn_pairs, default=None, @@ -3691,26 +3675,8 @@ def case(pred_fn_pairs, TypeError: If `fns[i]` is not callable for any i, or `default` is not callable. """ - predicates, actions = _case_verify_and_canonicalize_args( - pred_fn_pairs, exclusive, name) - with ops.name_scope(name, "case", [predicates]): - if default is None: - default, predicates, actions = _case_create_default_action( - predicates, actions) - fn = default - # To eval conditions in direct order we create nested conditions in reverse: - # cond(c[0], true_fn=.., false_fn=cond(c[1], ...)) - for predicate, action in reversed(list(zip(predicates, actions))): - fn = functools.partial( - cond, predicate, true_fn=action, false_fn=fn, strict=strict) - if exclusive: - with ops.control_dependencies([ - _assert_at_most_n_true( - predicates, n=1, msg="Input error: exclusive=True") - ]): - return fn() - else: - return fn() + return _case_helper(cond, pred_fn_pairs, default, exclusive, name, + allow_python_preds=False, strict=strict) class XLAControlFlowContext(ControlFlowContext): diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index adc8c51e11191c4dbf29ac8294f3390bff37bc6c..f22f3059d139d1bb7c7db57a2939184f1089f397 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -349,42 +349,6 @@ class SwitchTestCase(test_util.TensorFlowTestCase): self.assertEquals(grad_x_false.eval(), 0.) -@test_util.with_c_api -class SmartCondTest(test_util.TensorFlowTestCase): - - def testSmartCondTrue(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(2) - y = constant_op.constant(5) - z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16), - lambda: math_ops.multiply(y, 5)) - self.assertEqual(z.eval(), 32) - - def testSmartCondFalse(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(4) - y = constant_op.constant(3) - z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16), - lambda: math_ops.multiply(y, 3)) - self.assertEqual(z.eval(), 9) - - def testSmartCondMissingArg1(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.smart_cond(True, false_fn=lambda: x) - - def testSmartCondMissingArg2(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.smart_cond(True, lambda: x) - - @test_util.with_c_api class CondTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 83da6739db673644f59fda3044769b18b2138fbc..4b57e2de790af13499bc73cfcfa98e999eab1603 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -148,7 +148,7 @@ def ctc_loss(labels, inputs, sequence_length, if not time_major: inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N) - loss, _ = gen_ctc_ops._ctc_loss( + loss, _ = gen_ctc_ops.ctc_loss( inputs, labels.indices, labels.values, @@ -224,7 +224,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): sequence found, the negative of the sum of the greatest logit at each timeframe. """ - outputs = gen_ctc_ops._ctc_greedy_decoder( + outputs = gen_ctc_ops.ctc_greedy_decoder( inputs, sequence_length, merge_repeated=merge_repeated) (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val, decoded_shape)], @@ -272,7 +272,7 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, """ decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = ( - gen_ctc_ops._ctc_beam_search_decoder( + gen_ctc_ops.ctc_beam_search_decoder( inputs, sequence_length, beam_width=beam_width, top_paths=top_paths, merge_repeated=merge_repeated)) diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py new file mode 100644 index 0000000000000000000000000000000000000000..9eacac1b3704c43cbeb5ecd0cbe827cac3a7cc8b --- /dev/null +++ b/tensorflow/python/ops/custom_gradient.py @@ -0,0 +1,134 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Decorator to overrides the gradient for a function.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.eager import tape +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.util import nest +from tensorflow.python.util import tf_decorator +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("custom_gradient") +def custom_gradient(f): + """Decorator to define a function with a custom gradient. + + This decorator allows fine grained control over the gradients of a sequence + for operations. This may be useful for multiple reasons, including providing + a more efficient or numerically stable gradient for a sequence of operations. + + For example, consider the following function that commonly occurs in the + computation of cross entropy and log likelihoods: + + ```python + def log1pexp(x): + return tf.log(1 + tf.exp(x)) + ``` + + Due to numerical instability, the gradient this function evaluated at x=100 is + NaN. For example: + + ```python + x = tf.constant(100.) + y = log1pexp(x) + dy = tf.gradients(y, x) # Will be NaN when evaluated. + ``` + + The gradient expression can be analytically simplified to provide numerical + stability: + + ```python + @tf.custom_gradient + def log1pexp(x): + e = tf.exp(x) + def grad(dy): + return dy * (1 - 1 / (1 + e)) + return tf.log(1 + e), grad + ``` + + With this definition, the gradient at x=100 will be correctly evaluated as + 1.0. + + See also @{tf.RegisterGradient} which registers a gradient function for a + primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows + for fine grained control over the gradient computation of a sequence of + operations. + + Args: + f: function `f(x)` that returns a tuple `(y, grad_fn)` where: + - `x` is a `Tensor` or sequence of `Tensor` inputs to the function. + - `y` is a `Tensor` or sequence of `Tensor` outputs of applying + TensorFlow + operations in `f` to `x`. + - `grad_fn` is a function with the signature `g(grad_ys)` which returns + a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect + to the `Tensor`s in `x. `grad_ys` is a `Tensor` or sequence of + `Tensor`s the same size as `y` holding the initial value gradients for + each `Tensor` in `y`. + + Returns: + A function `h(x)` which returns the same value as `f(x)[0]` and whose + gradient (as calculated by @{tf.gradients}) is determined by `f(x)[1]`. + """ + + def decorated(*args, **kwargs): + """Decorated function with custom gradient.""" + if not context.executing_eagerly(): + if kwargs: + raise ValueError( + "The custom_gradient decorator currently suports keywords " + "arguments only when eager execution is enabled.") + name = "CustomGradient-%s" % ops.uid() + args = [ops.convert_to_tensor(x) for x in args] + result, grad_fn = f(*args) + flat_result = nest.flatten(result) + all_tensors = flat_result + args + + @ops.RegisterGradient(name) + def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable + gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)])) + # Need to return one value per input to the IdentityN, so pad the + # gradients of the inputs of the custom_gradient function with the + # gradients of the outputs as well. + return ([None] * len(flat_result)) + gradients + + with ops.get_default_graph().gradient_override_map({"IdentityN": name}): + all_tensors = array_ops.identity_n(all_tensors) + return nest.pack_sequence_as( + structure=result, flat_sequence=all_tensors[:len(flat_result)]) + + input_tensors = [ops.convert_to_tensor(x) for x in args] + + result, grad_fn = f(*args, **kwargs) + flat_result = nest.flatten(result) + # TODO(apassos) consider removing the identity below. + flat_result = [gen_array_ops.identity(x) for x in flat_result] + + def actual_grad_fn(*outputs): + return nest.flatten(grad_fn(*outputs)) + + tape.record_operation(f.__name__, flat_result, input_tensors, + actual_grad_fn) + flat_result = list(flat_result) + return nest.pack_sequence_as(result, flat_result) + + return tf_decorator.make_decorator(f, decorated) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 03ed537cfcf27151a0200d7a17f63b1a2bc7ba1a..d2cc87555f6321432261b32f08431c23ce707eff 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -159,7 +159,7 @@ class QueueBase(object): ValueError: If one of the arguments is invalid. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "Queues are not supported when eager execution is enabled. " "Instead, please use tf.data to get data into your model.") @@ -177,10 +177,10 @@ class QueueBase(object): else: self._names = None self._queue_ref = queue_ref - if context.in_graph_mode(): - self._name = self._queue_ref.op.name.split("/")[-1] - else: + if context.executing_eagerly(): self._name = context.context().scope_name + else: + self._name = self._queue_ref.op.name.split("/")[-1] @staticmethod def from_list(index, queues): @@ -231,9 +231,9 @@ class QueueBase(object): @property def name(self): """The name of the underlying queue.""" - if context.in_graph_mode(): - return self._queue_ref.op.name - return self._name + if context.executing_eagerly(): + return self._name + return self._queue_ref.op.name @property def dtypes(self): @@ -342,10 +342,10 @@ class QueueBase(object): val.get_shape().assert_is_compatible_with(shape) if self._queue_ref.dtype == _dtypes.resource: - return gen_data_flow_ops._queue_enqueue_v2( + return gen_data_flow_ops.queue_enqueue_v2( self._queue_ref, vals, name=scope) else: - return gen_data_flow_ops._queue_enqueue( + return gen_data_flow_ops.queue_enqueue( self._queue_ref, vals, name=scope) def enqueue_many(self, vals, name=None): @@ -387,7 +387,7 @@ class QueueBase(object): val.get_shape().with_rank_at_least(1)[0]) val.get_shape()[1:].assert_is_compatible_with(shape) - return gen_data_flow_ops._queue_enqueue_many_v2( + return gen_data_flow_ops.queue_enqueue_many_v2( self._queue_ref, vals, name=scope) def _dequeue_return_value(self, tensors): @@ -436,15 +436,15 @@ class QueueBase(object): if name is None: name = "%s_Dequeue" % self._name if self._queue_ref.dtype == _dtypes.resource: - ret = gen_data_flow_ops._queue_dequeue_v2( + ret = gen_data_flow_ops.queue_dequeue_v2( self._queue_ref, self._dtypes, name=name) else: - ret = gen_data_flow_ops._queue_dequeue( + ret = gen_data_flow_ops.queue_dequeue( self._queue_ref, self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to # the `QueueBase` object. - if context.in_graph_mode(): + if not context.executing_eagerly(): op = ret[0].op for output, shape in zip(op.values(), self._shapes): output.set_shape(shape) @@ -479,12 +479,12 @@ class QueueBase(object): if name is None: name = "%s_DequeueMany" % self._name - ret = gen_data_flow_ops._queue_dequeue_many_v2( + ret = gen_data_flow_ops.queue_dequeue_many_v2( self._queue_ref, n=n, component_types=self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to # the Queue object. - if context.in_graph_mode(): + if not context.executing_eagerly(): op = ret[0].op batch_dim = tensor_shape.Dimension( tensor_util.constant_value(op.inputs[1])) @@ -523,12 +523,12 @@ class QueueBase(object): if name is None: name = "%s_DequeueUpTo" % self._name - ret = gen_data_flow_ops._queue_dequeue_up_to_v2( + ret = gen_data_flow_ops.queue_dequeue_up_to_v2( self._queue_ref, n=n, component_types=self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to # the Queue object. - if context.in_graph_mode(): + if not context.executing_eagerly(): op = ret[0].op for output, shape in zip(op.values(), self._shapes): output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) @@ -560,12 +560,12 @@ class QueueBase(object): if name is None: name = "%s_Close" % self._name if self._queue_ref.dtype == _dtypes.resource: - return gen_data_flow_ops._queue_close_v2( + return gen_data_flow_ops.queue_close_v2( self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, name=name) else: - return gen_data_flow_ops._queue_close( + return gen_data_flow_ops.queue_close( self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, name=name) @@ -601,9 +601,9 @@ class QueueBase(object): if name is None: name = "%s_Size" % self._name if self._queue_ref.dtype == _dtypes.resource: - return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name) + return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name) else: - return gen_data_flow_ops._queue_size(self._queue_ref, name=name) + return gen_data_flow_ops.queue_size(self._queue_ref, name=name) @tf_export("RandomShuffleQueue") @@ -683,7 +683,7 @@ class RandomShuffleQueue(QueueBase): # the id of the last op created.) string = (str(seed1) + shared_name).encode("utf-8") seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF - queue_ref = gen_data_flow_ops._random_shuffle_queue_v2( + queue_ref = gen_data_flow_ops.random_shuffle_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, @@ -748,7 +748,7 @@ class FIFOQueue(QueueBase): dtypes = _as_type_list(dtypes) shapes = _as_shape_list(shapes, dtypes) names = _as_name_list(names, dtypes) - queue_ref = gen_data_flow_ops._fifo_queue_v2( + queue_ref = gen_data_flow_ops.fifo_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, @@ -827,7 +827,7 @@ class PaddingFIFOQueue(QueueBase): "but received %d dtypes and %d shapes." % (len(dtypes), len(shapes))) - queue_ref = gen_data_flow_ops._padding_fifo_queue_v2( + queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, @@ -895,7 +895,7 @@ class PriorityQueue(QueueBase): types = _as_type_list(types) shapes = _as_shape_list(shapes, types) - queue_ref = gen_data_flow_ops._priority_queue_v2( + queue_ref = gen_data_flow_ops.priority_queue_v2( component_types=types, shapes=shapes, capacity=capacity, @@ -985,15 +985,15 @@ class Barrier(object): else: self._shapes = [tensor_shape.unknown_shape() for _ in self._types] - self._barrier_ref = gen_data_flow_ops._barrier( + self._barrier_ref = gen_data_flow_ops.barrier( component_types=self._types, shapes=self._shapes, shared_name=shared_name, name=name) - if context.in_graph_mode(): - self._name = self._barrier_ref.op.name.split("/")[-1] - else: + if context.executing_eagerly(): self._name = context.context().scope_name + else: + self._name = self._barrier_ref.op.name.split("/")[-1] @property def barrier_ref(self): @@ -1003,9 +1003,9 @@ class Barrier(object): @property def name(self): """The name of the underlying barrier.""" - if context.in_graph_mode(): - return self._barrier_ref.op.name - return self._name + if context.executing_eagerly(): + return self._name + return self._barrier_ref.op.name def insert_many(self, component_index, keys, values, name=None): """For each key, assigns the respective value to the specified component. @@ -1026,7 +1026,7 @@ class Barrier(object): """ if name is None: name = "%s_BarrierInsertMany" % self._name - return gen_data_flow_ops._barrier_insert_many( + return gen_data_flow_ops.barrier_insert_many( self._barrier_ref, keys, values, component_index, name=name) def take_many(self, @@ -1073,7 +1073,7 @@ class Barrier(object): """ if name is None: name = "%s_BarrierTakeMany" % self._name - ret = gen_data_flow_ops._barrier_take_many( + ret = gen_data_flow_ops.barrier_take_many( self._barrier_ref, num_elements, self._types, @@ -1083,7 +1083,7 @@ class Barrier(object): # NOTE(mrry): Not using a shape function because we need access to # the Barrier object. - if context.in_graph_mode(): + if not context.executing_eagerly(): op = ret[0].op if allow_small_batch: batch_dim = None @@ -1122,7 +1122,7 @@ class Barrier(object): """ if name is None: name = "%s_BarrierClose" % self._name - return gen_data_flow_ops._barrier_close( + return gen_data_flow_ops.barrier_close( self._barrier_ref, cancel_pending_enqueues=cancel_pending_enqueues, name=name) @@ -1139,7 +1139,7 @@ class Barrier(object): """ if name is None: name = "%s_BarrierReadySize" % self._name - return gen_data_flow_ops._barrier_ready_size(self._barrier_ref, name=name) + return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name) def incomplete_size(self, name=None): """Compute the number of incomplete elements in the given barrier. @@ -1153,7 +1153,7 @@ class Barrier(object): """ if name is None: name = "%s_BarrierIncompleteSize" % self._name - return gen_data_flow_ops._barrier_incomplete_size( + return gen_data_flow_ops.barrier_incomplete_size( self._barrier_ref, name=name) @@ -1183,10 +1183,10 @@ class ConditionalAccumulatorBase(object): else: self._shape = tensor_shape.unknown_shape() self._accumulator_ref = accumulator_ref - if context.in_graph_mode(): - self._name = self._accumulator_ref.op.name.split("/")[-1] - else: + if context.executing_eagerly(): self._name = context.context().scope_name + else: + self._name = self._accumulator_ref.op.name.split("/")[-1] @property def accumulator_ref(self): diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index 4071e50e815b01d30f3e24ba4677cc37b325f24d..7c43bf54fc783815127f03cc287ab0fc4349beb5 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -338,6 +338,27 @@ class Distribution(_BaseDistribution): cum_prob_invalid = u.cdf([4.0, 5.0, 6.0]) ``` + #### Shapes + + There are three important concepts associated with TensorFlow Distributions + shapes: + - Event shape describes the shape of a single draw from the distribution; + it may be dependent across dimensions. For scalar distributions, the event + shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is + `[5]`. + - Batch shape describes independent, not identically distributed draws, aka a + "collection" or "bunch" of distributions. + - Sample shape describes independent, identically distributed draws of batches + from the distribution family. + + The event shape and the batch shape are properties of a Distribution object, + whereas the sample shape is associated with a specific call to `sample` or + `log_prob`. + + For detailed usage examples of TensorFlow Distributions shapes, see + [this tutorial]( + https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding%20TensorFlow%20Distributions%20Shapes.ipynb) + #### Parameter values leading to undefined statistics or distributions. Some distributions do not have well-defined statistics for all initialization @@ -593,7 +614,7 @@ class Distribution(_BaseDistribution): Returns: batch_shape: `TensorShape`, possibly unknown. """ - return self._batch_shape() + return tensor_shape.as_shape(self._batch_shape()) def _event_shape_tensor(self): raise NotImplementedError("event_shape_tensor is not implemented") @@ -626,7 +647,7 @@ class Distribution(_BaseDistribution): Returns: event_shape: `TensorShape`, possibly unknown. """ - return self._event_shape() + return tensor_shape.as_shape(self._event_shape()) def is_scalar_event(self, name="is_scalar_event"): """Indicates that `event_shape == []`. @@ -1105,6 +1126,34 @@ class Distribution(_BaseDistribution): with self._name_scope(name): return self._kl_divergence(other) + def __str__(self): + return ("tf.distributions.{type_name}(" + "\"{self_name}\"" + "{maybe_batch_shape}" + "{maybe_event_shape}" + ", dtype={dtype})".format( + type_name=type(self).__name__, + self_name=self.name, + maybe_batch_shape=(", batch_shape={}".format(self.batch_shape) + if self.batch_shape.ndims is not None + else ""), + maybe_event_shape=(", event_shape={}".format(self.event_shape) + if self.event_shape.ndims is not None + else ""), + dtype=self.dtype.name)) + + def __repr__(self): + return ("".format( + type_name=type(self).__name__, + self_name=self.name, + batch_shape=self.batch_shape, + event_shape=self.event_shape, + dtype=self.dtype.name)) + @contextlib.contextmanager def _name_scope(self, name=None, values=None): """Helper function to standardize op scope.""" diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py index 8fb218be3ac7e17e18d85b8e1c100ccd58aa1034..adb1f4f9a879e44cf8cb4cafd22b92554f487712 100644 --- a/tensorflow/python/ops/distributions/gamma.py +++ b/tensorflow/python/ops/distributions/gamma.py @@ -193,12 +193,6 @@ class Gamma(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - - def _log_cdf(self, x): - return math_ops.log(self._cdf(x)) - def _cdf(self, x): x = self._maybe_assert_valid_sample(x) # Note that igamma returns the regularized incomplete gamma function, diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py index e7f120ea2da525e20a1ae42e6418cf2ac83686af..32e8a49c81bc4b23d8897639998dd33942b41a80 100644 --- a/tensorflow/python/ops/distributions/normal.py +++ b/tensorflow/python/ops/distributions/normal.py @@ -188,9 +188,6 @@ class Normal(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return special_math.log_ndtr(self._z(x)) diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py index 778fefb8c2991153b7e7a1f20df61680153dab2a..9d9e65b4e8d6d2e40bf9c263339f899439c842c3 100644 --- a/tensorflow/python/ops/distributions/student_t.py +++ b/tensorflow/python/ops/distributions/student_t.py @@ -248,9 +248,6 @@ class StudentT(distribution.Distribution): math_ops.lgamma(0.5 * self.df) - math_ops.lgamma(0.5 * (self.df + 1.))) - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _cdf(self, x): # Take Abs(scale) to make subsequent where work correctly. y = (x - self.loc) / math_ops.abs(self.scale) diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py index 3580af18f241d777c81340f1c565074914838029..ec623b55eb0067e16599c18c9c504635da863907 100644 --- a/tensorflow/python/ops/distributions/uniform.py +++ b/tensorflow/python/ops/distributions/uniform.py @@ -45,11 +45,12 @@ class Uniform(distribution.Distribution): Z = b - a ``` - where: - * `low = a`, - * `high = b`, - * `Z` is the normalizing constant, and, - * `I[predicate]` is the [indicator function]( + where + + - `low = a`, + - `high = b`, + - `Z` is the normalizing constant, and + - `I[predicate]` is the [indicator function]( https://en.wikipedia.org/wiki/Indicator_function) for `predicate`. The parameters `low` and `high` must be shaped in a way that supports @@ -164,9 +165,6 @@ class Uniform(distribution.Distribution): seed=seed) return self.low + self.range() * samples - def _log_prob(self, x): - return math_ops.log(self._prob(x)) - def _prob(self, x): broadcasted_x = x * array_ops.ones(self.batch_shape_tensor()) return array_ops.where( @@ -178,9 +176,6 @@ class Uniform(distribution.Distribution): array_ops.zeros_like(broadcasted_x), array_ops.ones_like(broadcasted_x) / self.range())) - def _log_cdf(self, x): - return math_ops.log(self.cdf(x)) - def _cdf(self, x): broadcast_shape = array_ops.broadcast_dynamic_shape( array_ops.shape(x), self.batch_shape_tensor()) diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 3826585f59c31133b12c365816729e090c9ab561..f0120f2957db12caf6a513fde9aa8c756aff8bad 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -35,34 +35,14 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export -def _gather(params, ids, name=None): - """Helper function for _embedding_lookup_and_transform. - - This function gathers embeddings from a single tensor. The gather deals with - resource variables specially. - - Args: - params: A `Tensor` of embeddings. - ids: A `Tensor` indexing the embeddings to be retrieved from `params`. - name: A name for the operation (optional). - - Returns: - A `Tensor` with the same type as `params`. - """ - if isinstance(params, resource_variable_ops.ResourceVariable): - return params.sparse_read(ids, name=name) - else: - return array_ops.gather(params, ids, name=name) - - def _clip(params, ids, max_norm): """Helper function for _embedding_lookup_and_transform. This function optionally clips embeddings to an l2-norm of max_norm. Args: - params: A `Tensor` of embeddings retrieved by `_gather`. - ids: The `ids` argument that was passed to `_gather`. + params: A `Tensor` of embeddings retrieved by `gather`. + ids: The `ids` argument that was passed to `gather`. max_norm: If provided, the embeddings are l2-normalized to the value of max_norm. @@ -148,7 +128,8 @@ def _embedding_lookup_and_transform(params, ids = ops.convert_to_tensor(ids, name="ids") if np == 1 and (not transform_fn or ids.get_shape().ndims == 1): with ops.colocate_with(params[0]): - result = _clip(_gather(params[0], ids, name=name), ids, max_norm) + result = _clip(array_ops.gather(params[0], ids, name=name), + ids, max_norm) if transform_fn: result = transform_fn(result) return result @@ -212,7 +193,7 @@ def _embedding_lookup_and_transform(params, for p in xrange(np): pids = gather_ids[p] with ops.colocate_with(params[p]): - result = _gather(params[p], pids) + result = array_ops.gather(params[p], pids) if transform_fn: # If transform_fn is provided, the clip_by_norm precedes # the transform and hence must be co-located. See below @@ -396,8 +377,8 @@ def embedding_lookup_sparse(params, with `combiner`="mean", then the output will be a 3x20 matrix where output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) - output[1, :] = params[0, :] * 1.0 - output[2, :] = params[1, :] * 3.0 + output[1, :] = (params[0, :] * 1.0) / 1.0 + output[2, :] = (params[1, :] * 3.0) / 3.0 Raises: TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index ac03d30fcd2e65f032937d9259bc8fff18626619..a840b1eddfc6922dc310490e8166efd73480c437 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -41,7 +41,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.gen_functional_ops import * # pylint: enable=wildcard-import # pylint: disable=unused-import -from tensorflow.python.ops.gen_functional_ops import _symbolic_gradient +from tensorflow.python.ops.gen_functional_ops import symbolic_gradient # pylint: enable=unused-import from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -90,7 +90,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, if not callable(fn): raise TypeError("fn must be callable.") - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() with ops.name_scope(name, "foldl", [elems]): # TODO(akshayka): Remove the in_graph_mode check once caching devices are # supported in Eager @@ -178,7 +178,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, if not callable(fn): raise TypeError("fn must be callable.") - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() with ops.name_scope(name, "foldr", [elems]): # TODO(akshayka): Remove the in_graph_mode check once caching devices are # supported in Eager @@ -343,7 +343,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, elems_flat = input_flatten(elems) - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() with ops.name_scope(name, "map", elems_flat): # TODO(akshayka): Remove the in_graph_mode check once caching devices are # supported in Eager @@ -364,8 +364,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, dtype = dtype or input_pack([elem.dtype for elem in elems_flat]) dtype_flat = output_flatten(dtype) - # Convert elems to tensor array. - n = array_ops.shape(elems_flat[0])[0] + # Convert elems to tensor array. n may be known statically. + n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0] # TensorArrays are always flat elems_ta = [ @@ -536,7 +536,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, elems_flat = input_flatten(elems) - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() with ops.name_scope(name, "scan", elems_flat): # TODO(akshayka): Remove the in_graph_mode check once caching devices are # supported in Eager @@ -555,7 +555,8 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, elems_flat = [ ops.convert_to_tensor(elem, name="elem") for elem in elems_flat] - n = array_ops.shape(elems_flat[0])[0] + # Convert elems to tensor array. n may be known statically. + n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0] # TensorArrays are always flat elems_ta = [ @@ -615,7 +616,8 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, _, _, r_a = control_flow_ops.while_loop( lambda i, _1, _2: i < n, compute, (i, a_flat, accs_ta), parallel_iterations=parallel_iterations, - back_prop=back_prop, swap_memory=swap_memory) + back_prop=back_prop, swap_memory=swap_memory, + maximum_iterations=n) results_flat = [r.stack() for r in r_a] diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index 921fd50aa9fdd1a1e493708f4bc8c66996e26e2c..2668e8f60cd2864fd59ffa3fb539380d34a34004 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import +from tensorflow.python.eager.backprop import GradientTape +from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.gradients_impl import AggregationMethod from tensorflow.python.ops.gradients_impl import gradients from tensorflow.python.ops.gradients_impl import hessians @@ -28,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ # TODO(drpng): find a good place to reference this. "AggregationMethod", + "GradientTape", + "custom_gradient", "gradients", # tf.gradients.gradients. "hessians", # tf.gradients.hessians ] diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 1418c0b10fb60601e7c3024891b89aadb53e6873..44473ec69c8ac6cf565f635621eebff7bc403225 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -86,17 +86,19 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False): % str(value)) # TODO(mrry): Consider adding static shape information to # IndexedSlices, to avoid using numpy here. - dense_shape_value = tensor_util.constant_value(value.dense_shape) - if dense_shape_value is not None: - num_elements = np.prod(dense_shape_value) - if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS: + if not context.executing_eagerly(): + dense_shape_value = tensor_util.constant_value(value.dense_shape) + if dense_shape_value is not None: + num_elements = np.prod(dense_shape_value) + if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS: + warnings.warn( + "Converting sparse IndexedSlices to a dense Tensor with %d " + "elements. This may consume a large amount of memory." % + num_elements) + else: warnings.warn( - "Converting sparse IndexedSlices to a dense Tensor with %d elements. " - "This may consume a large amount of memory." % num_elements) - else: - warnings.warn( - "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " - "This may consume a large amount of memory.") + "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " + "This may consume a large amount of memory.") return math_ops.unsorted_segment_sum( value.values, value.indices, value.dense_shape[0], name=name) @@ -354,7 +356,7 @@ def _SymGrad(op, out_grads): for k in op.node_def.attr: f.attr[k].CopyFrom(op.node_def.attr[k]) # pylint: disable=protected-access - in_grads = functional_ops._symbolic_gradient(input=f_in, Tout=f_types, f=f) + in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f) # pylint: enable=protected-access return in_grads @@ -478,9 +480,21 @@ def gradients(ys, RuntimeError: if called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError("tf.gradients not supported in EAGER mode. Use " - "functions in tf.contrib.eager.backprop instead.") + # Creating the gradient graph for control flow mutates Operations. _lock + # ensures a Session.run call cannot occur between creating and mutating new + # ops. + with ops.get_default_graph()._lock: # pylint: disable=protected-access + return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, + gate_gradients, aggregation_method, stop_gradients) + + +def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, + gate_gradients, aggregation_method, stop_gradients): + """Implementation of gradients().""" + if context.executing_eagerly(): + raise RuntimeError("tf.gradients not supported when eager execution " + "is enabled. Use tf.contrib.eager.GradientTape " + "instead.") ys = _AsList(ys) xs = _AsList(xs) stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index d39b934819177e3c15af95a0777ba96869c5e9cf..c94f1396b28e2124c6e5123cf711ac86abf174ab 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import from tensorflow.python.ops import functional_ops # pylint: disable=unused-import @@ -661,6 +662,7 @@ class HessianTest(test_util.TensorFlowTestCase): self.assertAllEqual((m, n, m, n), hess_actual.shape) self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n))) + @test_util.with_c_api class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): @@ -741,6 +743,59 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): "of unknown shape. This may consume a large amount of memory." in str(w[0].message)) + def testCustomGradientTrivial(self): + + @custom_gradient.custom_gradient + def MyIdentity(x): + + def Grad(dy): + return [3 * dy] + + return x, Grad + + with ops.Graph().as_default(): + x = constant(3.) + y = MyIdentity(MyIdentity(x)) + dy = gradients.gradients(y, x)[0] + with session.Session(): + self.assertEqual(9., dy.eval()) + + def testCustomGradient(self): + + @custom_gradient.custom_gradient + def MyMultiply(x1, x2): + result = x1 * x2 + + def Grad(dy): + # Switched the ordering here. + return [dy * x1, dy * x2] + + return result, Grad + + with ops.Graph().as_default(): + x1 = constant(3.) + x2 = constant(5.) + y = MyMultiply(x1, x2) + dy = gradients.gradients(y, [x1, x2]) + with session.Session() as sess: + self.assertAllEqual([3., 5.], sess.run(dy)) + + def testCustomGradientErrors(self): + + @custom_gradient.custom_gradient + def F(x): + + def Grad(_): + raise RuntimeError("x") + + return x, Grad + + with ops.Graph().as_default(): + x = constant(1.0) + y = F(x) + with self.assertRaises(RuntimeError): + gradients.gradients(y, x) + @test_util.with_c_api class OnlyRealGradientsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt deleted file mode 100644 index 9b8172bf2639cca0efb663ff4075b36d6f4f2245..0000000000000000000000000000000000000000 --- a/tensorflow/python/ops/hidden_ops.txt +++ /dev/null @@ -1,394 +0,0 @@ -# array_ops -BatchToSpace -BroadcastArgs -BroadcastGradientArgs -ConcatOffset -Concat -ConcatV2 -ConjugateTranspose -Const -DebugGradientIdentity -DebugGradientRefIdentity -EditDistance -ExpandDims -ListDiff -MirrorPad -MirrorPadGrad -OneHot -Pack -Pad -PadV2 -ParallelConcat -Placeholder -RefIdentity -Reverse -Snapshot -SpaceToBatch -Split -SplitV -Squeeze -Slice -TileGrad # Exported through array_grad instead of array_ops. -ZerosLike # TODO(josh11b): Use this instead of the Python version. -Unique -UniqueV2 -UniqueWithCounts -UniqueWithCountsV2 -Unpack - -# candidate_sampling_ops -AllCandidateSampler -ComputeAccidentalHits -FixedUnigramCandidateSampler -LearnedUnigramCandidateSampler -LogUniformCandidateSampler -ThreadUnsafeUnigramCandidateSampler -UniformCandidateSampler - -# checkpoint_ops -GenerateVocabRemapping -LoadAndRemapMatrix - - -# control_flow_ops -Switch -Merge -RefMerge -Exit -RefExit - -# ctc_ops -CTCLoss -CTCGreedyDecoder -CTCBeamSearchDecoder - -# data_flow_ops -Barrier -BarrierClose -BarrierIncompleteSize -BarrierInsertMany -BarrierReadySize -BarrierTakeMany -DeleteSessionTensor -FakeQueue -FIFOQueue -FIFOQueueV2 -GetSessionHandle -GetSessionHandleV2 -GetSessionTensor -HashTable -HashTableV2 -InitializeTable -InitializeTableV2 -InitializeTableFromTextFile -InitializeTableFromTextFileV2 -LookupTableExport -LookupTableExportV2 -LookupTableFind -LookupTableFindV2 -LookupTableImport -LookupTableImportV2 -LookupTableInsert -LookupTableInsertV2 -LookupTableSize -LookupTableSizeV2 -MutableDenseHashTable -MutableDenseHashTableV2 -MutableHashTable -MutableHashTableV2 -MutableHashTableOfTensors -MutableHashTableOfTensorsV2 -Mutex -MutexAcquire -MutexRelease -PaddingFIFOQueue -PaddingFIFOQueueV2 -PriorityQueue -PriorityQueueV2 -QueueClose -QueueCloseV2 -QueueDequeue -QueueDequeueV2 -QueueDequeueMany -QueueDequeueManyV2 -QueueDequeueUpTo -QueueDequeueUpToV2 -QueueEnqueue -QueueEnqueueV2 -QueueEnqueueMany -QueueEnqueueManyV2 -QueueSize -QueueSizeV2 -RandomShuffleQueue -RandomShuffleQueueV2 -Stack -StackClose -StackPop -StackPush -StackV2 -StackCloseV2 -StackPopV2 -StackPushV2 -TensorArray -TensorArrayClose -TensorArrayCloseV2 -TensorArrayConcat -TensorArrayConcatV2 -TensorArrayGather -TensorArrayGatherV2 -TensorArrayGrad -TensorArrayGradV2 -TensorArrayPack -TensorArrayPackV2 -TensorArrayRead -TensorArrayReadV2 -TensorArrayScatter -TensorArrayScatterV2 -TensorArraySize -TensorArraySizeV2 -TensorArraySplit -TensorArraySplitV2 -TensorArrayUnpack -TensorArrayUnpackV2 -TensorArrayV2 -TensorArrayWrite -TensorArrayWriteV2 -TensorArrayV3 -TensorArrayCloseV3 -TensorArrayConcatV3 -TensorArrayGatherV3 -TensorArrayGradV3 -TensorArrayReadV3 -TensorArrayPackV3 -TensorArrayScatterV3 -TensorArraySizeV3 -TensorArraySplitV3 -TensorArrayUnpackV3 -TensorArrayWriteV3 - -# functional_ops -SymbolicGradient - -# image_ops -AdjustContrastv2 -NonMaxSuppression -NonMaxSuppressionV2 -RandomCrop -ResizeBilinearGrad -ResizeBicubicGrad -ResizeNearestNeighborGrad -SampleDistortedBoundingBox -SampleDistortedBoundingBoxV2 -ScaleImageGrad - -# io_ops -FixedLengthRecordReader -IdentityReader -ReaderNumRecordsProduced -ReaderNumWorkUnitsCompleted -ReaderRead -ReaderReadUpTo -ReaderReset -ReaderRestoreState -ReaderSerializeState -ReaderWorkQueueLength -FixedLengthRecordReaderV2 -IdentityReaderV2 -ReaderNumRecordsProducedV2 -ReaderNumWorkUnitsCompletedV2 -ReaderReadV2 -ReaderReadUpToV2 -ReaderResetV2 -ReaderRestoreStateV2 -ReaderSerializeStateV2 -ReaderWorkQueueLengthV2 -Restore -RestoreSlice -Save -SaveSlices -ShardedFilename -ShardedFilespec -TextLineReader -TFRecordReader -WholeFileReader -TextLineReaderV2 -TFRecordReaderV2 -WholeFileReaderV2 -LMDBReader -DecodeCSV - -# linalg_ops -BatchCholesky -BatchCholeskyGrad -BatchMatrixDeterminant -BatchMatrixInverse -BatchMatrixSolve -BatchMatrixSolveLs -BatchMatrixTriangularSolve -BatchSelfAdjointEig -BatchSelfAdjointEigV2 -BatchSvd -LogMatrixDeterminant -MatrixExponential -MatrixLogarithm -MatrixSolveLs -SelfAdjointEig -SelfAdjointEigV2 -Svd - -# logging_ops -Assert -AudioSummary -AudioSummaryV2 -HistogramSummary -ImageSummary -MergeSummary -Print -ScalarSummary -TensorSummary -TensorSummaryV2 - -# math_ops -Abs -AccumulateNV2 -AddN -AddV2 -All -Any -BatchMatMul -BatchFFT -BatchFFT2D -BatchFFT3D -BatchIFFT -BatchIFFT2D -BatchIFFT3D -Bucketize -Complex -ComplexAbs -Conj -FloorDiv -FloorMod -HistogramFixedWidth -Max -Mean -Min -Mul -Neg -Pow -Prod -Range -RealDiv -Select -SparseMatMul -Sub -Sum -MatMul -Sigmoid -Tanh -SigmoidGrad -TanhGrad -InvGrad -ReciprocalGrad -SqrtGrad -RsqrtGrad -TruncateDiv -TruncateMod - -# nn_ops -AvgPoolGrad # "*Grad" accessible through nn_grad instead of nn_ops. -AvgPool3DGrad -BatchNormWithGlobalNormalization -BatchNormWithGlobalNormalizationGrad -FusedBatchNorm -FusedBatchNormV2 -SoftmaxCrossEntropyWithLogits -SparseSoftmaxCrossEntropyWithLogits -LRNGrad -MaxPoolGrad -MaxPoolGradWithArgmax -MaxPoolGradGrad -MaxPoolGradGradWithArgmax -MaxPool3DGrad -MaxPool3DGradGrad -ReluGrad -Relu6Grad -EluGrad -SeluGrad -SoftplusGrad -SoftsignGrad -TopK -TopKV2 -BiasAdd -BiasAddV1 -Relu6 -AvgPool -MaxPool -MaxPoolV2 -Softmax -LogSoftmax -FractionalAvgPoolGrad -FractionalMaxPoolGrad -InTopK -InTopKV2 - -# parsing_ops -ParseExample -ParseSingleSequenceExample - -# random_ops -RandomGamma -RandomPoisson -RandomUniform -RandomUniformInt -RandomShuffle -RandomStandardNormal -ParameterizedTruncatedNormal -TruncatedNormal - -# script_ops -PyFunc -PyFuncStateless -EagerPyFunc - -# sdca_ops - -# state_ops -Variable -VariableV2 -TemporaryVariable -DestroyTemporaryVariable - -# sparse_ops -AddSparseToTensorsMap -AddManySparseToTensorsMap -TakeManySparseFromTensorsMap -DeserializeManySparse -DeserializeSparse -SerializeManySparse -SerializeSparse -SparseAdd -SparseAddGrad -SparseConcat -SparseCross -SparseFillEmptyRows -SparseFillEmptyRowsGrad -SparseSplit -SparseSelectLastK -SparseReorder -SparseReshape -SparseToDense -SparseTensorDenseAdd -SparseTensorDenseMatMul - -# string_ops -StringSplit - -# user_ops -Fact - -# training_ops -# (None) - -# word2vec deprecated ops -NegTrain -Skipgram diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py index 6a975160b0698270dfc9ce9140e8b3ff633cdb9e..4a1ef54fb50013881aa832f83674ac66ecccd9bc 100644 --- a/tensorflow/python/ops/histogram_ops.py +++ b/tensorflow/python/ops/histogram_ops.py @@ -141,5 +141,7 @@ def histogram_fixed_width(values, """ with ops.name_scope(name, 'histogram_fixed_width', [values, value_range, nbins]) as name: - return gen_math_ops._histogram_fixed_width( # pylint: disable=protected-access + # pylint: disable=protected-access + return gen_math_ops._histogram_fixed_width( values, value_range, nbins, dtype=dtype, name=name) + # pylint: enable=protected-access diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py index 093843cd5bc0b7c2281a0c9ddf52d93ea3faede3..9f43e3f1466d900ae6d39f3b9ef48043421cb777 100644 --- a/tensorflow/python/ops/image_grad.py +++ b/tensorflow/python/ops/image_grad.py @@ -41,12 +41,10 @@ def _ResizeNearestNeighborGrad(op, grad): else: image_shape = array_ops.shape(image)[1:3] - # pylint: disable=protected-access - grads = gen_image_ops._resize_nearest_neighbor_grad( + grads = gen_image_ops.resize_nearest_neighbor_grad( grad, image_shape, align_corners=op.get_attr("align_corners")) - # pylint: enable=protected-access return [grads, None] @@ -61,10 +59,8 @@ def _ResizeBilinearGrad(op, grad): Returns: The gradients w.r.t. the input. """ - # pylint: disable=protected-access - grad0 = gen_image_ops._resize_bilinear_grad( + grad0 = gen_image_ops.resize_bilinear_grad( grad, op.inputs[0], align_corners=op.get_attr("align_corners")) - # pylint: enable=protected-access return [grad0, None] @@ -82,10 +78,8 @@ def _ResizeBicubicGrad(op, grad): allowed_types = [dtypes.float32, dtypes.float64] grad0 = None if op.inputs[0].dtype in allowed_types: - # pylint: disable=protected-access - grad0 = gen_image_ops._resize_bicubic_grad( + grad0 = gen_image_ops.resize_bicubic_grad( grad, op.inputs[0], align_corners=op.get_attr("align_corners")) - # pylint: enable=protected-access return [grad0, None] diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index ae52d32fea1c872e588c4122f5e73198e4dfe9ad..68be9ccdd642823e7a9c2294f209accd16f45be5 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -69,6 +69,11 @@ See the @{$python/image} guide. @@non_max_suppression @@sample_distorted_bounding_box @@total_variation +@@psnr +@@ssim +@@ssim_multiscale +@@image_gradients +@@sobel_edges """ from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 58c18c6696d64ccca4ebfaa07242d3c7789116e4..3369fe3c9b37ca05311c5548dbfa3228ba04ee80 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,6 +31,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables @@ -1113,10 +1117,8 @@ def adjust_contrast(images, contrast_factor): orig_dtype = images.dtype flt_images = convert_image_dtype(images, dtypes.float32) - # pylint: disable=protected-access - adjusted = gen_image_ops._adjust_contrastv2( + adjusted = gen_image_ops.adjust_contrastv2( flt_images, contrast_factor=contrast_factor, name=name) - # pylint: enable=protected-access return convert_image_dtype(adjusted, orig_dtype, saturate=True) @@ -1730,7 +1732,7 @@ def sample_distorted_bounding_box(image_size, Provide as input to `tf.image.draw_bounding_boxes`. """ with ops.name_scope(name, 'sample_distorted_bounding_box'): - return gen_image_ops._sample_distorted_bounding_box_v2( # pylint: disable=protected-access + return gen_image_ops.sample_distorted_bounding_box_v2( image_size, bounding_boxes, seed=seed, @@ -1784,10 +1786,8 @@ def non_max_suppression(boxes, """ with ops.name_scope(name, 'non_max_suppression'): iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold') - # pylint: disable=protected-access - return gen_image_ops._non_max_suppression_v2(boxes, scores, max_output_size, - iou_threshold) - # pylint: enable=protected-access + return gen_image_ops.non_max_suppression_v2(boxes, scores, max_output_size, + iou_threshold) _rgb_to_yiq_kernel = [[0.299, 0.59590059, @@ -1795,6 +1795,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059, [0.114, -0.32134392, 0.31119955]] +@tf_export('image.rgb_to_yiq') def rgb_to_yiq(images): """Converts one or more images from RGB to YIQ. @@ -1820,6 +1821,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], [0.6208248, -0.64720424, 1.70423049]] +@tf_export('image.yiq_to_rgb') def yiq_to_rgb(images): """Converts one or more images from YIQ to RGB. @@ -1847,6 +1849,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119, [0.114, 0.43601035, -0.10001026]] +@tf_export('image.rgb_to_yuv') def rgb_to_yuv(images): """Converts one or more images from RGB to YUV. @@ -1872,6 +1875,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], [1.13988303, -0.58062185, 0]] +@tf_export('image.yuv_to_rgb') def yuv_to_rgb(images): """Converts one or more images from YUV to RGB. @@ -1892,3 +1896,489 @@ def yuv_to_rgb(images): _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) + + +def _verify_compatible_image_shapes(img1, img2): + """Checks if two image tensors are compatible for applying SSIM or PSNR. + + This function checks if two sets of images have ranks at least 3, and if the + last three dimensions match. + + Args: + img1: Tensor containing the first image batch. + img2: Tensor containing the second image batch. + + Returns: + A tuple containing: the first tensor shape, the second tensor shape, and a + list of control_flow_ops.Assert() ops implementing the checks. + + Raises: + ValueError: When static shape check fails. + """ + shape1 = img1.get_shape().with_rank_at_least(3) + shape2 = img2.get_shape().with_rank_at_least(3) + shape1[-3:].assert_is_compatible_with(shape2[-3:]) + + if shape1.ndims is not None and shape2.ndims is not None: + for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])): + if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)): + raise ValueError( + 'Two images are not compatible: %s and %s' % (shape1, shape2)) + + # Now assign shape tensors. + shape1, shape2 = array_ops.shape_n([img1, img2]) + + # TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable. + checks = [] + checks.append(control_flow_ops.Assert( + math_ops.greater_equal(array_ops.size(shape1), 3), + [shape1, shape2], summarize=10)) + checks.append(control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(shape1[-3:], shape2[-3:])), + [shape1, shape2], summarize=10)) + return shape1, shape2, checks + + +@tf_export('image.psnr') +def psnr(a, b, max_val, name=None): + """Returns the Peak Signal-to-Noise Ratio between a and b. + + This is intended to be used on signals (or images). Produces a PSNR value for + each image in batch. + + The last three dimensions of input are expected to be [height, width, depth]. + + Example: + + ```python + # Read images from file. + im1 = tf.decode_png('path/to/im1.png') + im2 = tf.decode_png('path/to/im2.png') + # Compute PSNR over tf.uint8 Tensors. + psnr1 = tf.image.psnr(im1, im2, max_val=255) + + # Compute PSNR over tf.float32 Tensors. + im1 = tf.image.convert_image_dtype(im1, tf.float32) + im2 = tf.image.convert_image_dtype(im2, tf.float32) + psnr2 = tf.image.psnr(im1, im2, max_val=1.0) + # psnr1 and psnr2 both have type tf.float32 and are almost equal. + ``` + + Arguments: + a: First set of images. + b: Second set of images. + max_val: The dynamic range of the images (i.e., the difference between the + maximum the and minimum allowed values). + name: Namespace to embed the computation in. + + Returns: + The scalar PSNR between a and b. The returned tensor has type `tf.float32` + and shape [batch_size, 1]. + """ + with ops.name_scope(name, 'PSNR', [a, b]): + # Need to convert the images to float32. Scale max_val accordingly so that + # PSNR is computed correctly. + max_val = math_ops.cast(max_val, a.dtype) + max_val = convert_image_dtype(max_val, dtypes.float32) + a = convert_image_dtype(a, dtypes.float32) + b = convert_image_dtype(b, dtypes.float32) + mse = math_ops.reduce_mean(math_ops.squared_difference(a, b), [-3, -2, -1]) + psnr_val = math_ops.subtract( + 20 * math_ops.log(max_val) / math_ops.log(10.0), + np.float32(10 / np.log(10)) * math_ops.log(mse), + name='psnr') + + _, _, checks = _verify_compatible_image_shapes(a, b) + with ops.control_dependencies(checks): + return array_ops.identity(psnr_val) + +_SSIM_K1 = 0.01 +_SSIM_K2 = 0.03 + + +def _ssim_helper(x, y, reducer, max_val, compensation=1.0): + r"""Helper function for computing SSIM. + + SSIM estimates covariances with weighted sums. The default parameters + use a biased estimate of the covariance: + Suppose `reducer` is a weighted sum, then the mean estimators are + \mu_x = \sum_i w_i x_i, + \mu_y = \sum_i w_i y_i, + where w_i's are the weighted-sum weights, and covariance estimator is + cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) + with assumption \sum_i w_i = 1. This covariance estimator is biased, since + E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y). + For SSIM measure with unbiased covariance estimators, pass as `compensation` + argument (1 - \sum_i w_i ^ 2). + + Arguments: + x: First set of images. + y: Second set of images. + reducer: Function that computes 'local' averages from set of images. + For non-covolutional version, this is usually tf.reduce_mean(x, [1, 2]), + and for convolutional version, this is usually tf.nn.avg_pool or + tf.nn.conv2d with weighted-sum kernel. + max_val: The dynamic range (i.e., the difference between the maximum + possible allowed value and the minimum allowed value). + compensation: Compensation factor. See above. + + Returns: + A pair containing the luminance measure, and the contrast-structure measure. + """ + c1 = (_SSIM_K1 * max_val) ** 2 + c2 = (_SSIM_K2 * max_val) ** 2 + + # SSIM luminance measure is + # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1). + mean0 = reducer(x) + mean1 = reducer(y) + num0 = mean0 * mean1 * 2.0 + den0 = math_ops.square(mean0) + math_ops.square(mean1) + luminance = (num0 + c1) / (den0 + c1) + + # SSIM contrast-structure measure is + # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2). + # Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then + # cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) + # = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j). + num1 = reducer(x * y) * 2.0 + den1 = reducer(math_ops.square(x) + math_ops.square(y)) + c2 *= compensation + cs = (num1 - num0 + c2) / (den1 - den0 + c2) + + # SSIM score is the product of the luminance and contrast-structure measures. + return luminance, cs + + +def _fspecial_gauss(size, sigma): + """Function to mimic the 'fspecial' gaussian MATLAB function.""" + size = ops.convert_to_tensor(size, dtypes.int32) + sigma = ops.convert_to_tensor(sigma) + + coords = math_ops.cast(math_ops.range(size), sigma.dtype) + coords -= math_ops.cast(size - 1, sigma.dtype) / 2.0 + + g = math_ops.square(coords) + g *= -0.5 / math_ops.square(sigma) + + g = array_ops.reshape(g, shape=[1, -1]) + array_ops.reshape(g, shape=[-1, 1]) + g = array_ops.reshape(g, shape=[1, -1]) # For tf.nn.softmax(). + g = nn_ops.softmax(g) + return array_ops.reshape(g, shape=[size, size, 1, 1]) + + +def _ssim_per_channel(img1, img2, max_val=1.0): + """Computes SSIM index between img1 and img2 per color channel. + + This function matches the standard SSIM implementation from: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image + quality assessment: from error visibility to structural similarity. IEEE + transactions on image processing. + + Details: + - 11x11 Gaussian filter of width 1.5 is used. + - k1 = 0.01, k2 = 0.03 as in the original paper. + + Args: + img1: First image batch. + img2: Second image batch. + max_val: The dynamic range of the images (i.e., the difference between the + maximum the and minimum allowed values). + + Returns: + A pair of tensors containing and channel-wise SSIM and contrast-structure + values. The shape is [..., channels]. + """ + filter_size = constant_op.constant(11, dtype=dtypes.int32) + filter_sigma = constant_op.constant(1.5, dtype=img1.dtype) + + shape1, shape2 = array_ops.shape_n([img1, img2]) + checks = [ + control_flow_ops.Assert(math_ops.reduce_all(math_ops.greater_equal( + shape1[-3:-1], filter_size)), [shape1, filter_size], summarize=8), + control_flow_ops.Assert(math_ops.reduce_all(math_ops.greater_equal( + shape2[-3:-1], filter_size)), [shape2, filter_size], summarize=8)] + + # Enforce the check to run before computation. + with ops.control_dependencies(checks): + img1 = array_ops.identity(img1) + + # TODO(sjhwang): Try to cache kernels and compensation factor. + kernel = _fspecial_gauss(filter_size, filter_sigma) + kernel = array_ops.tile(kernel, multiples=[1, 1, shape1[-1], 1]) + + # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`, + # but to match MATLAB implementation of MS-SSIM, we use 1.0 instead. + compensation = 1.0 + + # TODO(sjhwang): Try FFT. + # TODO(sjhwang): Gaussian kernel is separable in space. Consider applying + # 1-by-n and n-by-1 Gaussain filters instead of an n-by-n filter. + def reducer(x): + shape = array_ops.shape(x) + x = array_ops.reshape(x, shape=array_ops.concat([[-1], shape[-3:]], 0)) + y = nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID') + return array_ops.reshape(y, array_ops.concat([shape[:-3], + array_ops.shape(y)[1:]], 0)) + + luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation) + + # Average over the second and the third from the last: height, width. + axes = constant_op.constant([-3, -2], dtype=dtypes.int32) + ssim_val = math_ops.reduce_mean(luminance * cs, axes) + cs = math_ops.reduce_mean(cs, axes) + return ssim_val, cs + + +@tf_export('image.ssim') +def ssim(img1, img2, max_val): + """Computes SSIM index between img1 and img2. + + This function is based on the standard SSIM implementation from: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image + quality assessment: from error visibility to structural similarity. IEEE + transactions on image processing. + + Note: The true SSIM is only defined on grayscale. This function does not + perform any colorspace transform. (If input is already YUV, then it will + compute YUV SSIM average.) + + Details: + - 11x11 Gaussian filter of width 1.5 is used. + - k1 = 0.01, k2 = 0.03 as in the original paper. + + The image sizes must be at least 11x11 because of the filter size. + + Example: + + ```python + # Read images from file. + im1 = tf.decode_png('path/to/im1.png') + im2 = tf.decode_png('path/to/im2.png') + # Compute SSIM over tf.uint8 Tensors. + ssim1 = tf.image.ssim(im1, im2, max_val=255) + + # Compute SSIM over tf.float32 Tensors. + im1 = tf.image.convert_image_dtype(im1, tf.float32) + im2 = tf.image.convert_image_dtype(im2, tf.float32) + ssim2 = tf.image.ssim(im1, im2, max_val=1.0) + # ssim1 and ssim2 both have type tf.float32 and are almost equal. + ``` + + Args: + img1: First image batch. + img2: Second image batch. + max_val: The dynamic range of the images (i.e., the difference between the + maximum the and minimum allowed values). + + Returns: + A tensor containing an SSIM value for each image in batch. Returned SSIM + values are in range (-1, 1], when pixel values are non-negative. Returns + a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]). + """ + _, _, checks = _verify_compatible_image_shapes(img1, img2) + with ops.control_dependencies(checks): + img1 = array_ops.identity(img1) + + # Need to convert the images to float32. Scale max_val accordingly so that + # SSIM is computed correctly. + max_val = math_ops.cast(max_val, img1.dtype) + max_val = convert_image_dtype(max_val, dtypes.float32) + img1 = convert_image_dtype(img1, dtypes.float32) + img2 = convert_image_dtype(img2, dtypes.float32) + ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val) + # Compute average over color channels. + return math_ops.reduce_mean(ssim_per_channel, [-1]) + + +# Default values obtained by Wang et al. +_MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) + + +@tf_export('image.ssim_multiscale') +def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS): + """Computes the MS-SSIM between img1 and img2. + + This function assumes that `img1` and `img2` are image batches, i.e. the last + three dimensions are [height, width, channels]. + + Note: The true SSIM is only defined on grayscale. This function does not + perform any colorspace transform. (If input is already YUV, then it will + compute YUV SSIM average.) + + Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale + structural similarity for image quality assessment." Signals, Systems and + Computers, 2004. + + Arguments: + img1: First image batch. + img2: Second image batch. Must have the same rank as img1. + max_val: The dynamic range of the images (i.e., the difference between the + maximum the and minimum allowed values). + power_factors: Iterable of weights for each of the scales. The number of + scales used is the length of the list. Index 0 is the unscaled + resolution's weight and each increasing scale corresponds to the image + being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363, + 0.1333), which are the values obtained in the original paper. + + Returns: + A tensor containing an MS-SSIM value for each image in batch. The values + are in range [0, 1]. Returns a tensor with shape: + broadcast(img1.shape[:-3], img2.shape[:-3]). + """ + # Shape checking. + shape1 = img1.get_shape().with_rank_at_least(3) + shape2 = img2.get_shape().with_rank_at_least(3) + shape1[-3:].merge_with(shape2[-3:]) + + with ops.name_scope(None, 'MS-SSIM', [img1, img2]): + shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2) + with ops.control_dependencies(checks): + img1 = array_ops.identity(img1) + + # Need to convert the images to float32. Scale max_val accordingly so that + # SSIM is computed correctly. + max_val = math_ops.cast(max_val, img1.dtype) + max_val = convert_image_dtype(max_val, dtypes.float32) + img1 = convert_image_dtype(img1, dtypes.float32) + img2 = convert_image_dtype(img2, dtypes.float32) + + imgs = [img1, img2] + shapes = [shape1, shape2] + + # img1 and img2 are assumed to be a (multi-dimensional) batch of + # 3-dimensional images (height, width, channels). `heads` contain the batch + # dimensions, and `tails` contain the image dimensions. + heads = [s[:-3] for s in shapes] + tails = [s[-3:] for s in shapes] + + divisor = [1, 2, 2, 1] + divisor_tensor = constant_op.constant(divisor[1:], dtype=dtypes.int32) + + def do_pad(images, remainder): + padding = array_ops.expand_dims(remainder, -1) + padding = array_ops.pad(padding, [[1, 0], [1, 0]]) + return [array_ops.pad(x, padding, mode='SYMMETRIC') for x in images] + + mcs = [] + for k in range(len(power_factors)): + with ops.name_scope(None, 'Scale%d' % k, imgs): + if k > 0: + # Avg pool takes rank 4 tensors. Flatten leading dimensions. + flat_imgs = [ + array_ops.reshape(x, array_ops.concat([[-1], t], 0)) + for x, t in zip(imgs, tails) + ] + + remainder = tails[0] % divisor_tensor + need_padding = math_ops.reduce_any(math_ops.not_equal(remainder, 0)) + # pylint: disable=cell-var-from-loop + padded = control_flow_ops.cond(need_padding, + lambda: do_pad(flat_imgs, remainder), + lambda: flat_imgs) + # pylint: enable=cell-var-from-loop + + downscaled = [nn_ops.avg_pool(x, ksize=divisor, strides=divisor, + padding='VALID') + for x in padded] + tails = [x[1:] for x in array_ops.shape_n(downscaled)] + imgs = [ + array_ops.reshape(x, array_ops.concat([h, t], 0)) + for x, h, t in zip(downscaled, heads, tails) + ] + + # Overwrite previous ssim value since we only need the last one. + ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val) + mcs.append(nn_ops.relu(cs)) + + # Remove the cs score for the last scale. In the MS-SSIM calculation, + # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p). + mcs.pop() # Remove the cs score for the last scale. + mcs_and_ssim = array_ops.stack(mcs + [nn_ops.relu(ssim_per_channel)], + axis=-1) + # Take weighted geometric mean across the scale axis. + ms_ssim = math_ops.reduce_prod(math_ops.pow(mcs_and_ssim, power_factors), + [-1]) + + return math_ops.reduce_mean(ms_ssim, [-1]) # Avg over color channels. + + +@tf_export('image.image_gradients') +def image_gradients(image): + """Returns image gradients (dy, dx) for each color channel. + + Both output tensors have the same shape as the input: [batch_size, h, w, + d]. The gradient values are organized so that [I(x+1, y) - I(x, y)] is in + location (x, y). That means that dy will always have zeros in the last row, + and dx will always have zeros in the last column. + + Arguments: + image: Tensor with shape [batch_size, h, w, d]. + + Returns: + Pair of tensors (dy, dx) holding the vertical and horizontal image + gradients (1-step finite difference). + + Raises: + ValueError: If `image` is not a 4D tensor. + """ + if image.get_shape().ndims != 4: + raise ValueError('image_gradients expects a 4D tensor ' + '[batch_size, h, w, d], not %s.', image.get_shape()) + image_shape = array_ops.shape(image) + batch_size, height, width, depth = array_ops.unstack(image_shape) + dy = image[:, 1:, :, :] - image[:, :-1, :, :] + dx = image[:, :, 1:, :] - image[:, :, :-1, :] + + # Return tensors with same size as original image by concatenating + # zeros. Place the gradient [I(x+1,y) - I(x,y)] on the base pixel (x, y). + shape = array_ops.stack([batch_size, 1, width, depth]) + dy = array_ops.concat([dy, array_ops.zeros(shape, image.dtype)], 1) + dy = array_ops.reshape(dy, image_shape) + + shape = array_ops.stack([batch_size, height, 1, depth]) + dx = array_ops.concat([dx, array_ops.zeros(shape, image.dtype)], 2) + dx = array_ops.reshape(dx, image_shape) + + return dy, dx + + +@tf_export('image.sobel_edges') +def sobel_edges(image): + """Returns a tensor holding Sobel edge maps. + + Arguments: + image: Image tensor with shape [batch_size, h, w, d] and type float32 or + float64. The image(s) must be 2x2 or larger. + + Returns: + Tensor holding edge maps for each channel. Returns a tensor with shape + [batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]], + [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Sobel filter. + """ + # Define vertical and horizontal Sobel filters. + static_image_shape = image.get_shape() + image_shape = array_ops.shape(image) + kernels = [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]] + num_kernels = len(kernels) + kernels = np.transpose(np.asarray(kernels), (1, 2, 0)) + kernels = np.expand_dims(kernels, -2) + kernels_tf = constant_op.constant(kernels, dtype=image.dtype) + + kernels_tf = array_ops.tile(kernels_tf, [1, 1, image_shape[-1], 1], + name='sobel_filters') + + # Use depth-wise convolution to calculate edge maps per channel. + pad_sizes = [[0, 0], [1, 1], [1, 1], [0, 0]] + padded = array_ops.pad(image, pad_sizes, mode='REFLECT') + + # Output tensor has shape [batch_size, h, w, d * num_kernels]. + strides = [1, 1, 1, 1] + output = nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID') + + # Reshape to [batch_size, h, w, d, num_kernels]. + shape = array_ops.concat([image_shape, [num_kernels]], 0) + output = array_ops.reshape(output, shape=shape) + output.set_shape(static_image_shape.concatenate([num_kernels])) + return output diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index b8c4b27c162acdd86d88da641ff8afffaa5a9e6a..c437c12c2744792eaee197bf7d2a5f2b75d280bf 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import colorsys import functools +import itertools import math import os import time @@ -37,7 +38,9 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_image_ops +from tensorflow.python.ops import gradients from tensorflow.python.ops import image_ops +from tensorflow.python.ops import image_ops_impl from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -3328,5 +3331,420 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) +class VerifyCompatibleImageShapesTest(test_util.TensorFlowTestCase): + """Tests utility function used by ssim() and psnr().""" + + def testWrongDims(self): + img = array_ops.placeholder(dtype=dtypes.float32) + img_np = np.array((2, 2)) + + with self.test_session(use_gpu=True) as sess: + _, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(checks, {img: img_np}) + + def testShapeMismatch(self): + img1 = array_ops.placeholder(dtype=dtypes.float32) + img2 = array_ops.placeholder(dtype=dtypes.float32) + + img1_np = np.array([1, 2, 2, 1]) + img2_np = np.array([1, 3, 3, 1]) + + with self.test_session(use_gpu=True) as sess: + _, _, checks = image_ops_impl._verify_compatible_image_shapes(img1, img2) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(checks, {img1: img1_np, img2: img2_np}) + + +class PSNRTest(test_util.TensorFlowTestCase): + """Tests for PSNR.""" + + def _LoadTestImage(self, sess, filename): + content = io_ops.read_file(os.path.join( + "tensorflow/core/lib/psnr/testdata", filename)) + im = image_ops.decode_jpeg(content, dct_method="INTEGER_ACCURATE") + im = image_ops.convert_image_dtype(im, dtypes.float32) + im, = sess.run([im]) + return np.expand_dims(im, axis=0) + + def _LoadTestImages(self): + with self.test_session(use_gpu=True) as sess: + q20 = self._LoadTestImage(sess, "cat_q20.jpg") + q72 = self._LoadTestImage(sess, "cat_q72.jpg") + q95 = self._LoadTestImage(sess, "cat_q95.jpg") + return q20, q72, q95 + + def _PSNR_NumPy(self, orig, target, max_value): + """Numpy implementation of PSNR.""" + mse = ((orig - target) ** 2).mean(axis=(-3, -2, -1)) + return 20 * np.log10(max_value) - 10 * np.log10(mse) + + def _RandomImage(self, shape, max_val): + """Returns an image or image batch with given shape.""" + return np.random.rand(*shape).astype(np.float32) * max_val + + def testPSNRSingleImage(self): + image1 = self._RandomImage((8, 8, 1), 1) + image2 = self._RandomImage((8, 8, 1), 1) + psnr = self._PSNR_NumPy(image1, image2, 1) + + with self.test_session(use_gpu=True): + tf_image1 = constant_op.constant(image1, shape=image1.shape, + dtype=dtypes.float32) + tf_image2 = constant_op.constant(image2, shape=image2.shape, + dtype=dtypes.float32) + tf_psnr = image_ops.psnr(tf_image1, tf_image2, 1.0, "psnr").eval() + self.assertAllClose(psnr, tf_psnr, atol=0.001) + + def testPSNRMultiImage(self): + image1 = self._RandomImage((10, 8, 8, 1), 1) + image2 = self._RandomImage((10, 8, 8, 1), 1) + psnr = self._PSNR_NumPy(image1, image2, 1) + + with self.test_session(use_gpu=True): + tf_image1 = constant_op.constant(image1, shape=image1.shape, + dtype=dtypes.float32) + tf_image2 = constant_op.constant(image2, shape=image2.shape, + dtype=dtypes.float32) + tf_psnr = image_ops.psnr(tf_image1, tf_image2, 1, "psnr").eval() + self.assertAllClose(psnr, tf_psnr, atol=0.001) + + def testGoldenPSNR(self): + q20, q72, q95 = self._LoadTestImages() + + # Verify NumPy implementation first. + # Golden values are generated using GNU Octave's psnr() function. + psnr1 = self._PSNR_NumPy(q20, q72, 1) + self.assertNear(30.321, psnr1, 0.001, msg="q20.dtype=" + str(q20.dtype)) + psnr2 = self._PSNR_NumPy(q20, q95, 1) + self.assertNear(29.994, psnr2, 0.001) + psnr3 = self._PSNR_NumPy(q72, q95, 1) + self.assertNear(35.302, psnr3, 0.001) + + # Test TensorFlow implementation. + with self.test_session(use_gpu=True): + tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32) + tf_q72 = constant_op.constant(q72, shape=q72.shape, dtype=dtypes.float32) + tf_q95 = constant_op.constant(q95, shape=q95.shape, dtype=dtypes.float32) + tf_psnr1 = image_ops.psnr(tf_q20, tf_q72, 1, "psnr1").eval() + tf_psnr2 = image_ops.psnr(tf_q20, tf_q95, 1, "psnr2").eval() + tf_psnr3 = image_ops.psnr(tf_q72, tf_q95, 1, "psnr3").eval() + self.assertAllClose(psnr1, tf_psnr1, atol=0.001) + self.assertAllClose(psnr2, tf_psnr2, atol=0.001) + self.assertAllClose(psnr3, tf_psnr3, atol=0.001) + + def testInfinity(self): + q20, _, _ = self._LoadTestImages() + psnr = self._PSNR_NumPy(q20, q20, 1) + with self.test_session(use_gpu=True): + tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32) + tf_psnr = image_ops.psnr(tf_q20, tf_q20, 1, "psnr").eval() + self.assertAllClose(psnr, tf_psnr, atol=0.001) + + def testInt(self): + img1 = self._RandomImage((10, 8, 8, 1), 255) + img2 = self._RandomImage((10, 8, 8, 1), 255) + img1 = constant_op.constant(img1, dtypes.uint8) + img2 = constant_op.constant(img2, dtypes.uint8) + psnr_uint8 = image_ops.psnr(img1, img2, 255) + img1 = image_ops.convert_image_dtype(img1, dtypes.float32) + img2 = image_ops.convert_image_dtype(img2, dtypes.float32) + psnr_float32 = image_ops.psnr(img1, img2, 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(psnr_uint8.eval(), psnr_float32.eval(), atol=0.001) + + +class SSIMTest(test_util.TensorFlowTestCase): + """Tests for SSIM.""" + + _filenames = ["checkerboard1.png", + "checkerboard2.png", + "checkerboard3.png",] + + _ssim = np.asarray([[1.000000, 0.230880, 0.231153], + [0.230880, 1.000000, 0.996828], + [0.231153, 0.996828, 1.000000]]) + + def _LoadTestImage(self, sess, filename): + content = io_ops.read_file(os.path.join( + "tensorflow/core/lib/ssim/testdata", filename)) + im = image_ops.decode_png(content) + im = image_ops.convert_image_dtype(im, dtypes.float32) + im, = sess.run([im]) + return np.expand_dims(im, axis=0) + + def _LoadTestImages(self): + with self.test_session(use_gpu=True) as sess: + return [self._LoadTestImage(sess, f) for f in self._filenames] + + def _RandomImage(self, shape, max_val): + """Returns an image or image batch with given shape.""" + return np.random.rand(*shape).astype(np.float32) * max_val + + def testAgainstMatlab(self): + """Tests against values produced by Matlab.""" + img = self._LoadTestImages() + expected = self._ssim[np.triu_indices(3)] + + ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)] + ssim = image_ops.ssim(*ph, max_val=1.0) + with self.test_session(use_gpu=True): + scores = [ssim.eval(dict(zip(ph, t))) + for t in itertools.combinations_with_replacement(img, 2)] + self.assertAllClose(expected, np.squeeze(scores), atol=1e-4) + + def testBatch(self): + img = self._LoadTestImages() + expected = self._ssim[np.triu_indices(3, k=1)] + + img1, img2 = zip(*itertools.combinations(img, 2)) + img1 = np.concatenate(img1) + img2 = np.concatenate(img2) + + ssim = image_ops.ssim(constant_op.constant(img1), + constant_op.constant(img2), 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(expected, ssim.eval(), atol=1e-4) + + def testBroadcast(self): + img = self._LoadTestImages()[:2] + expected = self._ssim[:2, :2] + + img = constant_op.constant(np.concatenate(img)) + img1 = array_ops.expand_dims(img, axis=0) # batch dims: 1, 2. + img2 = array_ops.expand_dims(img, axis=1) # batch dims: 2, 1. + + ssim = image_ops.ssim(img1, img2, 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(expected, ssim.eval(), atol=1e-4) + + def testNegative(self): + """Tests against negative SSIM index.""" + step = np.expand_dims(np.arange(0, 256, 16, dtype=np.uint8), axis=0) + img1 = np.tile(step, (16, 1)) + img2 = np.fliplr(img1) + + img1 = img1.reshape((1, 16, 16, 1)) + img2 = img2.reshape((1, 16, 16, 1)) + + ssim = image_ops.ssim(constant_op.constant(img1), + constant_op.constant(img2), 255) + with self.test_session(use_gpu=True): + self.assertLess(ssim.eval(), 0) + + def testInt(self): + img1 = self._RandomImage((1, 16, 16, 3), 255) + img2 = self._RandomImage((1, 16, 16, 3), 255) + img1 = constant_op.constant(img1, dtypes.uint8) + img2 = constant_op.constant(img2, dtypes.uint8) + ssim_uint8 = image_ops.ssim(img1, img2, 255) + img1 = image_ops.convert_image_dtype(img1, dtypes.float32) + img2 = image_ops.convert_image_dtype(img2, dtypes.float32) + ssim_float32 = image_ops.ssim(img1, img2, 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(ssim_uint8.eval(), ssim_float32.eval(), atol=0.001) + + +class MultiscaleSSIMTest(test_util.TensorFlowTestCase): + """Tests for MS-SSIM.""" + + _filenames = ["checkerboard1.png", + "checkerboard2.png", + "checkerboard3.png",] + + _msssim = np.asarray([[1.000000, 0.091016, 0.091025], + [0.091016, 1.000000, 0.999567], + [0.091025, 0.999567, 1.000000]]) + + def _LoadTestImage(self, sess, filename): + content = io_ops.read_file(os.path.join( + "tensorflow/core/lib/ssim/testdata", filename)) + im = image_ops.decode_png(content) + im = image_ops.convert_image_dtype(im, dtypes.float32) + im, = sess.run([im]) + return np.expand_dims(im, axis=0) + + def _LoadTestImages(self): + with self.test_session(use_gpu=True) as sess: + return [self._LoadTestImage(sess, f) for f in self._filenames] + + def _RandomImage(self, shape, max_val): + """Returns an image or image batch with given shape.""" + return np.random.rand(*shape).astype(np.float32) * max_val + + def testAgainstMatlab(self): + """Tests against MS-SSIM computed with Matlab implementation. + + For color images, MS-SSIM scores are averaged over color channels. + """ + img = self._LoadTestImages() + expected = self._msssim[np.triu_indices(3)] + + ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)] + msssim = image_ops.ssim_multiscale(*ph, max_val=1.0) + with self.test_session(use_gpu=True): + scores = [msssim.eval(dict(zip(ph, t))) + for t in itertools.combinations_with_replacement(img, 2)] + + self.assertAllClose(expected, np.squeeze(scores), atol=1e-4) + + def testUnweightedIsDifferentiable(self): + img = self._LoadTestImages() + ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)] + scalar = constant_op.constant(1.0, dtype=dtypes.float32) + scaled_ph = [x * scalar for x in ph] + msssim = image_ops.ssim_multiscale(*scaled_ph, max_val=1.0, + power_factors=(1, 1, 1, 1, 1)) + grads = gradients.gradients(msssim, scalar) + with self.test_session(use_gpu=True) as sess: + np_grads = sess.run(grads, feed_dict={ph[0]: img[0], ph[1]: img[1]}) + self.assertTrue(np.isfinite(np_grads).all()) + + def testBatch(self): + """Tests MS-SSIM computed in batch.""" + img = self._LoadTestImages() + expected = self._msssim[np.triu_indices(3, k=1)] + + img1, img2 = zip(*itertools.combinations(img, 2)) + img1 = np.concatenate(img1) + img2 = np.concatenate(img2) + + msssim = image_ops.ssim_multiscale(constant_op.constant(img1), + constant_op.constant(img2), 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(expected, msssim.eval(), 1e-4) + + def testBroadcast(self): + """Tests MS-SSIM broadcasting.""" + img = self._LoadTestImages()[:2] + expected = self._msssim[:2, :2] + + img = constant_op.constant(np.concatenate(img)) + img1 = array_ops.expand_dims(img, axis=0) # batch dims: 1, 2. + img2 = array_ops.expand_dims(img, axis=1) # batch dims: 2, 1. + + score_tensor = image_ops.ssim_multiscale(img1, img2, 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(expected, score_tensor.eval(), 1e-4) + + def testRange(self): + """Tests against low MS-SSIM score. + + MS-SSIM is a geometric mean of SSIM and CS scores of various scales. + If any of the value is negative so that the geometric mean is not + well-defined, then treat the MS-SSIM score as zero. + """ + with self.test_session(use_gpu=True) as sess: + img1 = self._LoadTestImage(sess, "checkerboard1.png") + img2 = self._LoadTestImage(sess, "checkerboard3.png") + images = [img1, img2, np.zeros_like(img1), + np.full_like(img1, fill_value=255)] + + images = [ops.convert_to_tensor(x, dtype=dtypes.float32) for x in images] + msssim_ops = [image_ops.ssim_multiscale(x, y, 1.0) + for x, y in itertools.combinations(images, 2)] + msssim = sess.run(msssim_ops) + msssim = np.squeeze(msssim) + + self.assertTrue(np.all(msssim >= 0.0)) + self.assertTrue(np.all(msssim <= 1.0)) + + def testInt(self): + img1 = self._RandomImage((1, 180, 240, 3), 255) + img2 = self._RandomImage((1, 180, 240, 3), 255) + img1 = constant_op.constant(img1, dtypes.uint8) + img2 = constant_op.constant(img2, dtypes.uint8) + ssim_uint8 = image_ops.ssim_multiscale(img1, img2, 255) + img1 = image_ops.convert_image_dtype(img1, dtypes.float32) + img2 = image_ops.convert_image_dtype(img2, dtypes.float32) + ssim_float32 = image_ops.ssim_multiscale(img1, img2, 1.0) + with self.test_session(use_gpu=True): + self.assertAllClose(ssim_uint8.eval(), ssim_float32.eval(), atol=0.001) + + +class ImageGradientsTest(test_util.TensorFlowTestCase): + + def testImageGradients(self): + shape = [1, 2, 4, 1] + img = constant_op.constant([[1, 3, 4, 2], [8, 7, 5, 6]]) + img = array_ops.reshape(img, shape) + + expected_dy = np.reshape([[7, 4, 1, 4], [0, 0, 0, 0]], shape) + expected_dx = np.reshape([[2, 1, -2, 0], [-1, -2, 1, 0]], shape) + + dy, dx = image_ops.image_gradients(img) + with self.test_session(): + actual_dy = dy.eval() + actual_dx = dx.eval() + self.assertAllClose(expected_dy, actual_dy) + self.assertAllClose(expected_dx, actual_dx) + + def testImageGradientsMultiChannelBatch(self): + batch = [[[[1, 2], [2, 5], [3, 3]], + [[8, 4], [5, 1], [9, 8]]], + [[[5, 3], [7, 9], [1, 6]], + [[1, 2], [6, 3], [6, 3]]]] + + expected_dy = [[[[7, 2], [3, -4], [6, 5]], + [[0, 0], [0, 0], [0, 0]]], + [[[-4, -1], [-1, -6], [5, -3]], + [[0, 0], [0, 0], [0, 0]]]] + + expected_dx = [[[[1, 3], [1, -2], [0, 0]], + [[-3, -3], [4, 7], [0, 0]]], + [[[2, 6], [-6, -3], [0, 0]], + [[5, 1], [0, 0], [0, 0]]]] + + batch = constant_op.constant(batch) + assert batch.get_shape().as_list() == [2, 2, 3, 2] + dy, dx = image_ops.image_gradients(batch) + with self.test_session(use_gpu=True): + actual_dy = dy.eval() + actual_dx = dx.eval() + self.assertAllClose(expected_dy, actual_dy) + self.assertAllClose(expected_dx, actual_dx) + + def testImageGradientsBadShape(self): + # [2 x 4] image but missing batch and depth dimensions. + img = constant_op.constant([[1, 3, 4, 2], [8, 7, 5, 6]]) + with self.assertRaises(ValueError): + image_ops.image_gradients(img) + + +class SobelEdgesTest(test_util.TensorFlowTestCase): + + def testSobelEdges1x2x3x1(self): + img = constant_op.constant([[1, 3, 6], [4, 1, 5]], + dtype=dtypes.float32, shape=[1, 2, 3, 1]) + expected = np.reshape([[[0, 0], [0, 12], [0, 0]], + [[0, 0], [0, 12], [0, 0]]], [1, 2, 3, 1, 2]) + sobel = image_ops.sobel_edges(img) + with self.test_session(use_gpu=True): + actual_sobel = sobel.eval() + self.assertAllClose(expected, actual_sobel) + + def testSobelEdges5x3x4x2(self): + batch_size = 5 + plane = np.reshape([[1, 3, 6, 2], [4, 1, 5, 7], [2, 5, 1, 4]], + [1, 3, 4, 1]) + two_channel = np.concatenate([plane, plane], axis=3) + batch = np.concatenate([two_channel] * batch_size, axis=0) + img = constant_op.constant(batch, dtype=dtypes.float32, + shape=[batch_size, 3, 4, 2]) + + expected_plane = np.reshape([[[0, 0], [0, 12], [0, 10], [0, 0]], + [[6, 0], [0, 6], [-6, 10], [-6, 0]], + [[0, 0], [0, 0], [0, 10], [0, 0]]], + [1, 3, 4, 1, 2]) + expected_two_channel = np.concatenate( + [expected_plane, expected_plane], axis=3) + expected_batch = np.concatenate([expected_two_channel] * batch_size, axis=0) + + sobel = image_ops.sobel_edges(img) + with self.test_session(use_gpu=True): + actual_sobel = sobel.eval() + self.assertAllClose(expected_batch, actual_sobel) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index c7502d0fda5c38079362d30877a917e3965e6ca0..40ab22951b1aa04a61e09aac155b6449ae358d7b 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -542,6 +542,62 @@ class Orthogonal(Initializer): return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} +class ConvolutionDeltaOrthogonal(Initializer): + """Initializer that generates a delta orthogonal kernel for ConvNets. + + The shape of the tensor must have length 3, 4 or 5. The number of input + filters must not exceed the number of output filters. The center pixels of the + tensor form an orthogonal matrix. Other pixels are set to be zero. + + Args: + gain: multiplicative factor to apply to the orthogonal matrix. Default is 1. + The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after + applying this convolution. + dtype: The type of the output. + seed: A Python integer. Used to create random seeds. See + @{tf.set_random_seed} + for behavior. + """ + + def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): + self.gain = gain + self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) + self.seed = seed + + def __call__(self, shape, dtype=None, partition_info=None): + if dtype is None: + dtype = self.dtype + # Check the shape + if len(shape) < 3 or len(shape) > 5: + raise ValueError("The tensor to initialize must be at least " + "three-dimensional and at most five-dimensional") + + if shape[-2] > shape[-1]: + raise ValueError("In_filters cannot be greater than out_filters.") + + # Generate a random matrix + a = random_ops.random_normal([shape[-1], shape[-1]], + dtype=dtype, seed=self.seed) + # Compute the qr factorization + q, _ = linalg_ops.qr(a, full_matrices=False) + q = q[:shape[-2], :] + q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype)) + if len(shape) == 3: + weight = array_ops.scatter_nd([[(shape[0]-1)//2]], + array_ops.expand_dims(q, 0), shape) + elif len(shape) == 4: + weight = array_ops.scatter_nd([[(shape[0]-1)//2, (shape[1]-1)//2]], + array_ops.expand_dims(q, 0), shape) + else: + weight = array_ops.scatter_nd([[(shape[0]-1)//2, (shape[1]-1)//2, + (shape[2]-1)//2]], + array_ops.expand_dims(q, 0), shape) + return weight + + def get_config(self): + return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} + + @tf_export("keras.initializers.Identity", "initializers.identity") class Identity(Initializer): """Initializer that generates the identity matrix. @@ -586,7 +642,7 @@ uniform_unit_scaling_initializer = UniformUnitScaling variance_scaling_initializer = VarianceScaling orthogonal_initializer = Orthogonal identity_initializer = Identity - +convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal # pylint: enable=invalid-name diff --git a/tensorflow/python/ops/initializers_ns.py b/tensorflow/python/ops/initializers_ns.py index c21079f2971a4bdd76b4be1a803055c12b243903..e7996efe93eb2f33306a52ded91c273009192789 100644 --- a/tensorflow/python/ops/initializers_ns.py +++ b/tensorflow/python/ops/initializers_ns.py @@ -39,5 +39,8 @@ global_variables = _variables.global_variables_initializer local_variables = _variables.local_variables_initializer # Seal API. +del absolute_import +del division +del print_function del init_ops del _variables diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 5e70b3186f382a0c795b1795b2db27bb2058ee41..f6a25610c5a2ee8b76d06e286365cb957ab643cd 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -111,10 +111,10 @@ def _save(filename, tensor_names, tensors, tensor_slices=None, name="save"): An Operation that saves the tensors. """ if tensor_slices is None: - return gen_io_ops._save(filename, tensor_names, tensors, name=name) + return gen_io_ops.save(filename, tensor_names, tensors, name=name) else: - return gen_io_ops._save_slices(filename, tensor_names, tensor_slices, - tensors, name=name) + return gen_io_ops.save_slices(filename, tensor_names, tensor_slices, + tensors, name=name) def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, @@ -136,7 +136,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, A tensor of type "tensor_type". """ base_type = dtypes.as_dtype(tensor_type).base_dtype - return gen_io_ops._restore_slice( + return gen_io_ops.restore_slice( file_pattern, tensor_name, shape_and_slice, base_type, preferred_shard, name=name) @@ -173,7 +173,7 @@ class ReaderBase(object): Raises: RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "Readers are not supported when eager execution is enabled. " "Instead, please use tf.data to get data into your model.") @@ -208,12 +208,12 @@ class ReaderBase(object): else: queue_ref = queue.queue_ref if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_read_v2(self._reader_ref, queue_ref, name=name) + return gen_io_ops.reader_read_v2(self._reader_ref, queue_ref, name=name) else: # For compatibility with pre-resource queues, create a ref(string) tensor # which can be looked up as the same queue by a resource manager. - old_queue_op = gen_data_flow_ops._fake_queue(queue_ref) - return gen_io_ops._reader_read(self._reader_ref, old_queue_op, name=name) + old_queue_op = gen_data_flow_ops.fake_queue(queue_ref) + return gen_io_ops.reader_read(self._reader_ref, old_queue_op, name=name) def read_up_to(self, queue, num_records, # pylint: disable=invalid-name name=None): @@ -240,18 +240,18 @@ class ReaderBase(object): else: queue_ref = queue.queue_ref if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_read_up_to_v2(self._reader_ref, - queue_ref, - num_records, - name=name) + return gen_io_ops.reader_read_up_to_v2(self._reader_ref, + queue_ref, + num_records, + name=name) else: # For compatibility with pre-resource queues, create a ref(string) tensor # which can be looked up as the same queue by a resource manager. - old_queue_op = gen_data_flow_ops._fake_queue(queue_ref) - return gen_io_ops._reader_read_up_to(self._reader_ref, - old_queue_op, - num_records, - name=name) + old_queue_op = gen_data_flow_ops.fake_queue(queue_ref) + return gen_io_ops.reader_read_up_to(self._reader_ref, + old_queue_op, + num_records, + name=name) def num_records_produced(self, name=None): """Returns the number of records this reader has produced. @@ -267,11 +267,11 @@ class ReaderBase(object): """ if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_num_records_produced_v2(self._reader_ref, - name=name) + return gen_io_ops.reader_num_records_produced_v2(self._reader_ref, + name=name) else: - return gen_io_ops._reader_num_records_produced(self._reader_ref, - name=name) + return gen_io_ops.reader_num_records_produced(self._reader_ref, + name=name) def num_work_units_completed(self, name=None): """Returns the number of work units this reader has finished processing. @@ -283,11 +283,11 @@ class ReaderBase(object): An int64 Tensor. """ if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_num_work_units_completed_v2(self._reader_ref, - name=name) + return gen_io_ops.reader_num_work_units_completed_v2(self._reader_ref, + name=name) else: - return gen_io_ops._reader_num_work_units_completed(self._reader_ref, - name=name) + return gen_io_ops.reader_num_work_units_completed(self._reader_ref, + name=name) def serialize_state(self, name=None): """Produce a string tensor that encodes the state of a reader. @@ -302,9 +302,9 @@ class ReaderBase(object): A string Tensor. """ if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_serialize_state_v2(self._reader_ref, name=name) + return gen_io_ops.reader_serialize_state_v2(self._reader_ref, name=name) else: - return gen_io_ops._reader_serialize_state(self._reader_ref, name=name) + return gen_io_ops.reader_serialize_state(self._reader_ref, name=name) def restore_state(self, state, name=None): """Restore a reader to a previously saved state. @@ -321,11 +321,10 @@ class ReaderBase(object): The created Operation. """ if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_restore_state_v2( + return gen_io_ops.reader_restore_state_v2( self._reader_ref, state, name=name) else: - return gen_io_ops._reader_restore_state( - self._reader_ref, state, name=name) + return gen_io_ops.reader_restore_state(self._reader_ref, state, name=name) @property def supports_serialize(self): @@ -342,9 +341,9 @@ class ReaderBase(object): The created Operation. """ if self._reader_ref.dtype == dtypes.resource: - return gen_io_ops._reader_reset_v2(self._reader_ref, name=name) + return gen_io_ops.reader_reset_v2(self._reader_ref, name=name) else: - return gen_io_ops._reader_reset(self._reader_ref, name=name) + return gen_io_ops.reader_reset(self._reader_ref, name=name) ops.NotDifferentiable("ReaderRead") @@ -377,7 +376,7 @@ class WholeFileReader(ReaderBase): Args: name: A name for the operation (optional). """ - rr = gen_io_ops._whole_file_reader_v2(name=name) + rr = gen_io_ops.whole_file_reader_v2(name=name) super(WholeFileReader, self).__init__(rr, supports_serialize=True) @@ -406,8 +405,8 @@ class TextLineReader(ReaderBase): to skip from the beginning of every file. name: A name for the operation (optional). """ - rr = gen_io_ops._text_line_reader_v2(skip_header_lines=skip_header_lines, - name=name) + rr = gen_io_ops.text_line_reader_v2(skip_header_lines=skip_header_lines, + name=name) super(TextLineReader, self).__init__(rr) @@ -444,7 +443,7 @@ class FixedLengthRecordReader(ReaderBase): name: A name for the operation (optional). encoding: The type of encoding for the file. Defaults to none. """ - rr = gen_io_ops._fixed_length_record_reader_v2( + rr = gen_io_ops.fixed_length_record_reader_v2( record_bytes=record_bytes, header_bytes=header_bytes, footer_bytes=footer_bytes, @@ -480,7 +479,7 @@ class TFRecordReader(ReaderBase): compression_type = python_io.TFRecordOptions.get_compression_type_string( options) - rr = gen_io_ops._tf_record_reader_v2( + rr = gen_io_ops.tf_record_reader_v2( name=name, compression_type=compression_type) super(TFRecordReader, self).__init__(rr) @@ -506,7 +505,7 @@ class LMDBReader(ReaderBase): name: A name for the operation (optional). options: A LMDBRecordOptions object (optional). """ - rr = gen_io_ops._lmdb_reader(name=name) + rr = gen_io_ops.lmdb_reader(name=name) super(LMDBReader, self).__init__(rr) @@ -534,7 +533,7 @@ class IdentityReader(ReaderBase): Args: name: A name for the operation (optional). """ - rr = gen_io_ops._identity_reader_v2(name=name) + rr = gen_io_ops.identity_reader_v2(name=name) super(IdentityReader, self).__init__(rr, supports_serialize=True) diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index d5bd916f80d8a03e5423c43d1ca039bc4dceff5e..8343c62816c6aeadc77dae701ae9917a86e68954 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -31,18 +31,19 @@ band_part = array_ops.matrix_band_part cholesky = linalg_ops.cholesky cholesky_solve = linalg_ops.cholesky_solve det = linalg_ops.matrix_determinant -# pylint: disable=protected-access -slogdet = gen_linalg_ops._log_matrix_determinant -# pylint: disable=protected-access +slogdet = gen_linalg_ops.log_matrix_determinant +tf_export('linalg.slogdet')(slogdet) diag = array_ops.matrix_diag diag_part = array_ops.matrix_diag_part eigh = linalg_ops.self_adjoint_eig eigvalsh = linalg_ops.self_adjoint_eigvals einsum = special_math_ops.einsum -expm = gen_linalg_ops._matrix_exponential +expm = gen_linalg_ops.matrix_exponential +tf_export('linalg.expm')(expm) eye = linalg_ops.eye inv = linalg_ops.matrix_inverse -logm = gen_linalg_ops._matrix_logarithm +logm = gen_linalg_ops.matrix_logarithm +tf_export('linalg.logm')(logm) lstsq = linalg_ops.matrix_solve_ls norm = linalg_ops.norm qr = linalg_ops.qr diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 957a7959181efe3bbc319e62582053329b763dc3..c7513d5b40c5a4bb11501c90e08a9dc3a38c2e09 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -204,16 +204,6 @@ class LinearOperator(object): self._is_positive_definite = is_positive_definite self._name = name or type(self).__name__ - # We will cache some tensors to avoid repeatedly adding shape - # manipulation ops to the graph. - # Naming convention: - # self._cached_X_tensor is the cached version of self._X_tensor. - self._cached_shape_tensor = None - self._cached_batch_shape_tensor = None - self._cached_domain_dimension_tensor = None - self._cached_range_dimension_tensor = None - self._cached_tensor_rank_tensor = None - @contextlib.contextmanager def _name_scope(self, name=None, values=None): """Helper function to standardize op scope.""" @@ -299,15 +289,11 @@ class LinearOperator(object): `int32` `Tensor` """ with self._name_scope(name): - # Be clean by avoiding adding shape Ops to the graph too many times. - if self._cached_shape_tensor is None: - # Prefer to use statically defined shape if available. - if self.shape.is_fully_defined(): - self._cached_shape_tensor = linear_operator_util.shape_tensor( - self.shape.as_list()) - else: - self._cached_shape_tensor = self._shape_tensor() - return self._cached_shape_tensor + # Prefer to use statically defined shape if available. + if self.shape.is_fully_defined(): + return linear_operator_util.shape_tensor(self.shape.as_list()) + else: + return self._shape_tensor() @property def batch_shape(self): @@ -338,14 +324,12 @@ class LinearOperator(object): """ # Derived classes get this "for free" once .shape() is implemented. with self._name_scope(name): - if self._cached_batch_shape_tensor is None: - # Prefer to use statically defined shape if available. - if self.batch_shape.is_fully_defined(): - self._cached_batch_shape_tensor = linear_operator_util.shape_tensor( - self.batch_shape.as_list(), name="batch_shape") - else: - self._cached_batch_shape_tensor = self.shape_tensor()[:-2] - return self._cached_batch_shape_tensor + # Prefer to use statically defined shape if available. + if self.batch_shape.is_fully_defined(): + return linear_operator_util.shape_tensor( + self.batch_shape.as_list(), name="batch_shape") + else: + return self.shape_tensor()[:-2] @property def tensor_rank(self, name="tensor_rank"): @@ -378,14 +362,11 @@ class LinearOperator(object): """ # Derived classes get this "for free" once .shape() is implemented. with self._name_scope(name): - if self._cached_tensor_rank_tensor is None: - # Prefer to use statically defined shape if available. - if self.tensor_rank is not None: - self._cached_tensor_rank_tensor = ops.convert_to_tensor( - self.tensor_rank) - else: - self._cached_tensor_rank_tensor = array_ops.size(self.shape_tensor()) - return self._cached_tensor_rank_tensor + # Prefer to use statically defined shape if available. + if self.tensor_rank is not None: + return ops.convert_to_tensor(self.tensor_rank) + else: + return array_ops.size(self.shape_tensor()) @property def domain_dimension(self): @@ -416,14 +397,11 @@ class LinearOperator(object): """ # Derived classes get this "for free" once .shape() is implemented. with self._name_scope(name): - if self._cached_domain_dimension_tensor is None: - # Prefer to use statically defined shape if available. - if self.domain_dimension.value is not None: - self._cached_domain_dimension_tensor = ops.convert_to_tensor( - self.domain_dimension.value) - else: - self._cached_domain_dimension_tensor = self.shape_tensor()[-1] - return self._cached_domain_dimension_tensor + # Prefer to use statically defined shape if available. + if self.domain_dimension.value is not None: + return ops.convert_to_tensor(self.domain_dimension.value) + else: + return self.shape_tensor()[-1] @property def range_dimension(self): @@ -454,14 +432,11 @@ class LinearOperator(object): """ # Derived classes get this "for free" once .shape() is implemented. with self._name_scope(name): - if self._cached_range_dimension_tensor is None: - # Prefer to use statically defined shape if available. - if self.range_dimension.value is not None: - self._cached_range_dimension_tensor = ops.convert_to_tensor( - self.range_dimension.value) - else: - self._cached_range_dimension_tensor = self.shape_tensor()[-2] - return self._cached_range_dimension_tensor + # Prefer to use statically defined shape if available. + if self.range_dimension.value is not None: + return ops.convert_to_tensor(self.range_dimension.value) + else: + return self.shape_tensor()[-2] def _assert_non_singular(self): """Private default implementation of _assert_non_singular.""" @@ -471,8 +446,7 @@ class LinearOperator(object): if self._can_use_cholesky(): return self.assert_positive_definite() else: - singular_values = linalg_ops.svd( - self._get_cached_dense_matrix(), compute_uv=False) + singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False) # TODO(langmore) Add .eig and .cond as methods. cond = (math_ops.reduce_max(singular_values, axis=-1) / math_ops.reduce_min(singular_values, axis=-1)) @@ -524,7 +498,7 @@ class LinearOperator(object): # and sufficient. if self.is_self_adjoint: return check_ops.assert_positive( - array_ops.matrix_diag_part(self._get_cached_chol()), + array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())), message="Matrix was not positive definite.") # We have no generic check for positive definite. raise NotImplementedError("assert_positive_definite is not implemented.") @@ -547,7 +521,7 @@ class LinearOperator(object): return self._assert_positive_definite() def _assert_self_adjoint(self): - dense = self._get_cached_dense_matrix() + dense = self.to_dense() logging.warn( "Using (possibly slow) default implementation of assert_self_adjoint." " Requires conversion to a dense matrix.") @@ -692,7 +666,7 @@ class LinearOperator(object): "Using (possibly slow) default implementation of determinant." " Requires conversion to a dense matrix and O(N^3) operations.") if self._can_use_cholesky(): - diag = array_ops.matrix_diag_part(self._get_cached_chol()) + diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())) return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1]) _, log_abs_det = linalg.slogdet(self._matrix) return log_abs_det @@ -726,9 +700,9 @@ class LinearOperator(object): " Requires conversion to a dense matrix and O(N^3) operations.") rhs = linalg.adjoint(rhs) if adjoint_arg else rhs if self._can_use_cholesky(): - return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs) - return linalg_ops.matrix_solve( - self._get_cached_dense_matrix(), rhs, adjoint=adjoint) + return linalg_ops.cholesky_solve( + linalg_ops.cholesky(self.to_dense()), rhs) + return linalg_ops.matrix_solve(self.to_dense(), rhs, adjoint=adjoint) def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. @@ -866,7 +840,7 @@ class LinearOperator(object): def _diag_part(self): """Generic and often inefficient implementation. Override often.""" - return array_ops.matrix_diag_part(self._get_cached_dense_matrix()) + return array_ops.matrix_diag_part(self.to_dense()) def diag_part(self, name="diag_part"): """Efficiently get the [batch] diagonal part of this operator. @@ -915,7 +889,7 @@ class LinearOperator(object): def _add_to_tensor(self, x): # Override if a more efficient implementation is available. - return self._get_cached_dense_matrix() + x + return self.to_dense() + x def add_to_tensor(self, x, name="add_to_tensor"): """Add matrix represented by this operator to `x`. Equivalent to `A + x`. @@ -936,13 +910,3 @@ class LinearOperator(object): # TODO(langmore) Add complex types when tf.cholesky can use them. return (not self.dtype.is_complex and self.is_self_adjoint and self.is_positive_definite) - - def _get_cached_dense_matrix(self): - if not hasattr(self, "_cached_dense_matrix"): - self._cached_dense_matrix = self.to_dense() - return self._cached_dense_matrix - - def _get_cached_chol(self): - if not hasattr(self, "_cached_chol"): - self._cached_chol = linalg_ops.cholesky(self._get_cached_dense_matrix()) - return self._cached_chol diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 2c11f90e6d9de280e6020edfaa4d8ef237126705..ce1a112ad584a14298be6e471578858ef31573d5 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -35,6 +35,18 @@ from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.platform import test +class OperatorBuildInfo(object): + """Object encoding expected shape for a test. + + Encodes the expected shape of a matrix for a test. Also + allows additional metadata for the test harness. + """ + + def __init__(self, shape, **kwargs): + self.shape = shape + self.__dict__.update(kwargs) + + @six.add_metaclass(abc.ABCMeta) # pylint: disable=no-init class LinearOperatorDerivedClassTest(test.TestCase): """Tests for derived classes. @@ -84,19 +96,20 @@ class LinearOperatorDerivedClassTest(test.TestCase): return [False, True] @abc.abstractproperty - def _shapes_to_test(self): - """Returns list of tuples, each is one shape that will be tested.""" - raise NotImplementedError("shapes_to_test has not been implemented.") + def _operator_build_infos(self): + """Returns list of OperatorBuildInfo, encapsulating the shape to test.""" + raise NotImplementedError("operator_build_infos has not been implemented.") @abc.abstractmethod - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): """Build a batch matrix and an Operator that should have similar behavior. Every operator acts like a (batch) matrix. This method returns both together, and is used by tests. Args: - shape: List-like of Python integers giving full shape of operator. + build_info: `OperatorBuildInfo`, encoding shape information about the + operator. dtype: Numpy dtype. Data type of returned array/operator. use_placeholder: Python bool. If True, initialize the operator with a placeholder of undefined shape and correct dtype. @@ -164,30 +177,30 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_to_dense(self): self._skip_if_tests_to_skip_contains("to_dense") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) op_dense = operator.to_dense() if not use_placeholder: - self.assertAllEqual(shape, op_dense.get_shape()) + self.assertAllEqual(build_info.shape, op_dense.get_shape()) op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict) self.assertAC(op_dense_v, mat_v) def test_det(self): self._skip_if_tests_to_skip_contains("det") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) op_det = operator.determinant() if not use_placeholder: - self.assertAllEqual(shape[:-2], op_det.get_shape()) + self.assertAllEqual(build_info.shape[:-2], op_det.get_shape()) op_det_v, mat_det_v = sess.run( [op_det, linalg_ops.matrix_determinant(mat)], feed_dict=feed_dict) @@ -196,16 +209,17 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_log_abs_det(self): self._skip_if_tests_to_skip_contains("log_abs_det") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) op_log_abs_det = operator.log_abs_determinant() _, mat_log_abs_det = linalg.slogdet(mat) if not use_placeholder: - self.assertAllEqual(shape[:-2], op_log_abs_det.get_shape()) + self.assertAllEqual( + build_info.shape[:-2], op_log_abs_det.get_shape()) op_log_abs_det_v, mat_log_abs_det_v = sess.run( [op_log_abs_det, mat_log_abs_det], feed_dict=feed_dict) self.assertAC(op_log_abs_det_v, mat_log_abs_det_v) @@ -213,14 +227,14 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_matmul(self): self._skip_if_tests_to_skip_contains("matmul") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: for adjoint in self._adjoint_options: for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) x = self._make_x(operator, adjoint=adjoint) # If adjoint_arg, compute A X^H^H = A X. if adjoint_arg: @@ -241,14 +255,14 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_solve(self): self._skip_if_tests_to_skip_contains("solve") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: for adjoint in self._adjoint_options: for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) rhs = self._make_rhs(operator, adjoint=adjoint) # If adjoint_arg, solve A X = (rhs^H)^H = rhs. if adjoint_arg: @@ -270,12 +284,12 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_trace(self): self._skip_if_tests_to_skip_contains("trace") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) op_trace = operator.trace() mat_trace = math_ops.trace(mat) if not use_placeholder: @@ -287,16 +301,16 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_add_to_tensor(self): self._skip_if_tests_to_skip_contains("add_to_tensor") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) op_plus_2mat = operator.add_to_tensor(2 * mat) if not use_placeholder: - self.assertAllEqual(shape, op_plus_2mat.get_shape()) + self.assertAllEqual(build_info.shape, op_plus_2mat.get_shape()) op_plus_2mat_v, mat_v = sess.run( [op_plus_2mat, mat], feed_dict=feed_dict) @@ -306,12 +320,12 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_diag_part(self): self._skip_if_tests_to_skip_contains("diag_part") for use_placeholder in self._use_placeholder_options: - for shape in self._shapes_to_test: + for build_info in self._operator_build_infos: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=use_placeholder) + build_info, dtype, use_placeholder=use_placeholder) op_diag_part = operator.diag_part() mat_diag_part = array_ops.matrix_diag_part(mat) @@ -334,9 +348,15 @@ class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest): """ @property - def _shapes_to_test(self): + def _operator_build_infos(self): + build_info = OperatorBuildInfo # non-batch operators (n, n) and batch operators. - return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)] + return [ + build_info((0, 0)), + build_info((1, 1)), + build_info((1, 3, 3)), + build_info((3, 4, 4)), + build_info((2, 1, 4, 4))] def _make_rhs(self, operator, adjoint): # This operator is square, so rhs and x will have same shape. @@ -387,9 +407,15 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest): return ["solve", "det", "log_abs_det"] @property - def _shapes_to_test(self): + def _operator_build_infos(self): + build_info = OperatorBuildInfo # non-batch operators (n, n) and batch operators. - return [(2, 1), (1, 2), (1, 3, 2), (3, 3, 4), (2, 1, 2, 4)] + return [ + build_info((2, 1)), + build_info((1, 2)), + build_info((1, 3, 2)), + build_info((3, 3, 4)), + build_info((2, 1, 2, 4))] def _make_rhs(self, operator, adjoint): # TODO(langmore) Add once we're testing solve_ls. diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 9803eed6aefe072cbe0841dff2de3f640a440dd5..170861b43fd980ab0e107fc0b2e3d6f02339ed34 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -248,7 +248,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): and l2_regularizer != 0 due to poor accuracy. """ - # pylint: disable=protected-access,long-lambda + # pylint: disable=long-lambda def _use_composite_impl(fast, tensor_shape): """Determines whether to use the composite or specialized CPU kernel. @@ -323,9 +323,8 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): if _use_composite_impl(fast, tensor_shape): return _composite_impl(matrix, rhs, l2_regularizer) else: - return gen_linalg_ops._matrix_solve_ls( + return gen_linalg_ops.matrix_solve_ls( matrix, rhs, l2_regularizer, fast=fast, name=name) - # pylint: enable=protected-access @tf_export('self_adjoint_eig', 'linalg.eigh') @@ -342,12 +341,11 @@ def self_adjoint_eig(tensor, name=None): name: string, optional name of the operation. Returns: - e: Eigenvalues. Shape is `[..., N]`. + e: Eigenvalues. Shape is `[..., N]`. Sorted in non-decreasing order. v: Eigenvectors. Shape is `[..., N, N]`. The columns of the inner most matrices contain eigenvectors of the corresponding matrices in `tensor` """ - # pylint: disable=protected-access - e, v = gen_linalg_ops._self_adjoint_eig_v2(tensor, compute_v=True, name=name) + e, v = gen_linalg_ops.self_adjoint_eig_v2(tensor, compute_v=True, name=name) return e, v @@ -369,8 +367,7 @@ def self_adjoint_eigvals(tensor, name=None): e: Eigenvalues. Shape is `[..., N]`. The vector `e[..., :]` contains the `N` eigenvalues of `tensor[..., :, :]`. """ - # pylint: disable=protected-access - e, _ = gen_linalg_ops._self_adjoint_eig_v2(tensor, compute_v=False, name=name) + e, _ = gen_linalg_ops.self_adjoint_eig_v2(tensor, compute_v=False, name=name) return e @@ -432,13 +429,11 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None): u, s, v_adj = np.linalg.svd(a, full_matrices=False) np_a_approx = np.dot(u, np.dot(np.diag(s), v_adj)) # tf_a_approx and np_a_approx should be numerically close. - ```` + ``` @end_compatibility """ - # pylint: disable=protected-access - s, u, v = gen_linalg_ops._svd( + s, u, v = gen_linalg_ops.svd( tensor, compute_uv=compute_uv, full_matrices=full_matrices, name=name) - # pylint: enable=protected-access if compute_uv: return math_ops.real(s), u, v else: diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index 3757109c956dfedc64ac4cda4ad13a4cfa601418..222b8ebc9da6b076f012f8febbd50cc3c4c86c08 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -109,7 +109,7 @@ def histogram_summary(tag, values, collections=None, name=None): buffer. """ with ops.name_scope(name, "HistogramSummary", [tag, values]) as scope: - val = gen_logging_ops._histogram_summary( + val = gen_logging_ops.histogram_summary( tag=tag, values=values, name=scope) _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) return val @@ -170,7 +170,7 @@ def image_summary(tag, tensor, max_images=3, collections=None, name=None): buffer. """ with ops.name_scope(name, "ImageSummary", [tag, tensor]) as scope: - val = gen_logging_ops._image_summary( + val = gen_logging_ops.image_summary( tag=tag, tensor=tensor, max_images=max_images, name=scope) _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) return val @@ -226,11 +226,12 @@ def audio_summary(tag, with ops.name_scope(name, "AudioSummary", [tag, tensor]) as scope: sample_rate = ops.convert_to_tensor(sample_rate, dtype=dtypes.float32, name="sample_rate") - val = gen_logging_ops._audio_summary_v2(tag=tag, - tensor=tensor, - max_outputs=max_outputs, - sample_rate=sample_rate, - name=scope) + val = gen_logging_ops.audio_summary_v2( + tag=tag, + tensor=tensor, + max_outputs=max_outputs, + sample_rate=sample_rate, + name=scope) _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) return val @@ -263,7 +264,7 @@ def merge_summary(inputs, collections=None, name=None): buffer resulting from the merging. """ with ops.name_scope(name, "MergeSummary", inputs): - val = gen_logging_ops._merge_summary(inputs=inputs, name=name) + val = gen_logging_ops.merge_summary(inputs=inputs, name=name) _Collect(val, collections, []) return val @@ -345,7 +346,7 @@ def scalar_summary(tags, values, collections=None, name=None): buffer. """ with ops.name_scope(name, "ScalarSummary", [tags, values]) as scope: - val = gen_logging_ops._scalar_summary(tags=tags, values=values, name=scope) + val = gen_logging_ops.scalar_summary(tags=tags, values=values, name=scope) _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) return val diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index f539a7bb68da57e31746bc80fb25339a03a4fafe..6f043f60e677eac560004619464905cd616256b2 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -157,10 +157,10 @@ class InitializableLookupTableBase(LookupInterface): default_value: The value to use if a key is missing in the table. initializer: The table initializer to use. """ - if context.in_graph_mode(): - name = table_ref.op.name.split("/")[-1] - else: + if context.executing_eagerly(): name = context.context().scope_name + else: + name = table_ref.op.name.split("/")[-1] super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, initializer.value_dtype, name) @@ -196,9 +196,7 @@ class InitializableLookupTableBase(LookupInterface): """ with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as scope: - # pylint: disable=protected-access - return gen_lookup_ops._lookup_table_size_v2(self._table_ref, name=scope) - # pylint: enable=protected-access + return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=scope) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -227,10 +225,8 @@ class InitializableLookupTableBase(LookupInterface): with ops.name_scope(name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: - # pylint: disable=protected-access - values = gen_lookup_ops._lookup_table_find_v2( + values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, key_tensor, self._default_value, name=scope) - # pylint: enable=protected-access values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): @@ -274,13 +270,11 @@ class HashTable(InitializableLookupTableBase): """ with ops.name_scope(name, "hash_table", (initializer, default_value)) as scope: - # pylint: disable=protected-access - table_ref = gen_lookup_ops._hash_table_v2( + table_ref = gen_lookup_ops.hash_table_v2( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, name=scope) - # pylint: enable=protected-access super(HashTable, self).__init__(table_ref, default_value, initializer) @@ -352,10 +346,8 @@ class KeyValueTensorInitializer(TableInitializerBase): with ops.name_scope( self._name, values=(table.table_ref, self._keys, self._values)) as scope: - # pylint: disable=protected-access - init_op = gen_lookup_ops._initialize_table_v2( + init_op = gen_lookup_ops.initialize_table_v2( table.table_ref, self._keys, self._values, name=scope) - # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op @@ -518,8 +510,7 @@ class TextFileInitializer(TableInitializerBase): (table.table_ref,)) as scope: filename = ops.convert_to_tensor( self._filename, dtypes.string, name="asset_filepath") - # pylint: disable=protected-access - init_op = gen_lookup_ops._initialize_table_from_text_file_v2( + init_op = gen_lookup_ops.initialize_table_from_text_file_v2( table.table_ref, filename, self._key_index, @@ -527,11 +518,10 @@ class TextFileInitializer(TableInitializerBase): -1 if self._vocab_size is None else self._vocab_size, self._delimiter, name=scope) - # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) # If the filename tensor is anything other than a string constant (e.g., if # it is a placeholder) then it does not make sense to track it as an asset. - if context.in_graph_mode() and constant_op.is_constant(filename): + if not context.executing_eagerly() and constant_op.is_constant(filename): ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) return init_op diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 7386976e93fbb82f38550f50429af878fadda813..34ca1adc3e13dc67560fb21d70c16cd42dc40552 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -89,14 +89,6 @@ def _safe_div(numerator, denominator, name="value"): Returns: The element-wise value of the numerator divided by the denominator. """ - if isinstance(denominator, float): - if math_ops.equal(denominator, 0.0): - return ops.convert_to_tensor(0.0, dtype=numerator.dtype) - return math_ops.div(numerator, denominator) - if context.in_eager_mode() and denominator._rank() == 0: # pylint: disable=protected-access - if math_ops.equal(denominator, 0.0): - return ops.convert_to_tensor(0.0, dtype=numerator.dtype) - return math_ops.div(numerator, denominator) return array_ops.where( math_ops.greater(denominator, 0), math_ops.div(numerator, array_ops.where( @@ -144,7 +136,7 @@ def _num_present(losses, weights, per_batch=False): `[batch_size]`. Otherwise, a single scalar tensor is returned. """ if ((isinstance(weights, float) and weights != 0.0) or - (context.in_eager_mode() and weights._rank() == 0 # pylint: disable=protected-access + (context.executing_eagerly() and weights._rank() == 0 # pylint: disable=protected-access and not math_ops.equal(weights, 0.0))): return _num_elements(losses) with ops.name_scope(None, "num_present", (losses, weights)) as scope: @@ -202,6 +194,11 @@ def compute_weighted_loss( """ Reduction.validate(reduction) with ops.name_scope(scope, "weighted_loss", (losses, weights)): + # Save the `reduction` argument for loss normalization when distributing + # to multiple towers. + # TODO(josh11b): Associate it with the returned op for more precision. + ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access + with ops.control_dependencies(( weights_broadcast_ops.assert_broadcastable(weights, losses),)): losses = ops.convert_to_tensor(losses) diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py index 91e15b47b9400f29425af2f186c7c44ee6a5a622..6d335cdc212f368e7667a030791c7b634113a9c6 100644 --- a/tensorflow/python/ops/manip_ops.py +++ b/tensorflow/python/ops/manip_ops.py @@ -23,9 +23,11 @@ from __future__ import print_function from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access +@tf_export('manip.roll') def roll(input, shift, axis): # pylint: disable=redefined-builtin return _gen_manip_ops.roll(input, shift, axis) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 9e7f37d80fdd71e84516ab450d145d79519ae47a..02e07dc7b1f5fe6a671da967f6d07cef123d3d1e 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -35,6 +35,18 @@ def _safe_shape_div(x, y): return x // math_ops.maximum(y, 1) +@ops.RegisterGradient("ArgMax") +def _ArgMaxGrad(op, grad): + del op, grad + return [None, None] + + +@ops.RegisterGradient("ArgMin") +def _ArgMinGrad(op, grad): + del op, grad + return [None, None] + + @ops.RegisterGradient("Sum") def _SumGrad(op, grad): """Gradient for Sum.""" @@ -46,10 +58,18 @@ def _SumGrad(op, grad): if axes is not None: rank = len(input_0_shape) if np.array_equal(axes, np.arange(rank)): # Reduce all dims. - grad = array_ops.reshape(grad, [1] * rank) + if context.executing_eagerly(): + ctx = context.context() + new_shape = ctx.ones_rank_cache().get(rank) + if new_shape is None: + new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32) + ctx.ones_rank_cache().put(rank, new_shape) + else: + new_shape = [1] * rank + grad = array_ops.reshape(grad, new_shape) # If shape is not fully defined (but rank is), we use Shape. if None not in input_0_shape: - input_shape = input_0_shape + input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32) else: input_shape = array_ops.shape(op.inputs[0]) return [array_ops.tile(grad, input_shape), None] @@ -382,16 +402,14 @@ def _NegGrad(_, grad): def _InvGrad(op, grad): """Returns -grad * (1 / x^2).""" y = op.outputs[0] # y = 1 / x - # pylint: disable=protected-access - return gen_math_ops._reciprocal_grad(y, grad) + return gen_math_ops.reciprocal_grad(y, grad) @ops.RegisterGradient("Reciprocal") def _ReciprocalGrad(op, grad): """Returns -grad * (1 / x^2).""" y = op.outputs[0] # y = 1 / x - # pylint: disable=protected-access - return gen_math_ops._reciprocal_grad(y, grad) + return gen_math_ops.reciprocal_grad(y, grad) @ops.RegisterGradient("InvGrad") @@ -401,8 +419,7 @@ def _InvGradGrad(op, grad): with ops.control_dependencies([grad]): ca = math_ops.conj(op.inputs[0]) cg = math_ops.conj(grad) - # pylint: disable=protected-access - return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad) + return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) @ops.RegisterGradient("ReciprocalGrad") @@ -412,8 +429,7 @@ def _ReciprocalGradGrad(op, grad): with ops.control_dependencies([grad]): ca = math_ops.conj(op.inputs[0]) cg = math_ops.conj(grad) - # pylint: disable=protected-access - return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad) + return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) @ops.RegisterGradient("Square") @@ -422,15 +438,14 @@ def _SquareGrad(op, grad): # Added control dependencies to prevent 2*x from being computed too early. with ops.control_dependencies([grad]): x = math_ops.conj(x) - return math_ops.multiply(grad, math_ops.multiply(x, 2.0)) + y = constant_op.constant(2.0, dtype=x.dtype) + return math_ops.multiply(grad, math_ops.multiply(x, y)) @ops.RegisterGradient("Sqrt") def _SqrtGrad(op, grad): y = op.outputs[0] # y = x^(1/2) - # pylint: disable=protected-access - return gen_math_ops._sqrt_grad(y, grad) - # pylint: enable=protected-access + return gen_math_ops.sqrt_grad(y, grad) @ops.RegisterGradient("SqrtGrad") @@ -446,9 +461,7 @@ def _SqrtGradGrad(op, grad): def _RsqrtGrad(op, grad): """Returns -0.5 * grad * conj(y)^3.""" y = op.outputs[0] # y = x^(-1/2) - # pylint: disable=protected-access - return gen_math_ops._rsqrt_grad(y, grad) - # pylint: enable=protected-access + return gen_math_ops.rsqrt_grad(y, grad) @ops.RegisterGradient("RsqrtGrad") @@ -460,8 +473,7 @@ def _RsqrtGradGrad(op, grad): ca = math_ops.conj(a) cg = math_ops.conj(grad) grad_a = -1.5 * cg * b * math_ops.square(ca) - # pylint: disable=protected-access - grad_b = gen_math_ops._rsqrt_grad(ca, grad) + grad_b = gen_math_ops.rsqrt_grad(ca, grad) return grad_a, grad_b @@ -526,8 +538,7 @@ def _TanhGrad(op, grad): y = op.outputs[0] # y = tanh(x) with ops.control_dependencies([grad]): y = math_ops.conj(y) - # pylint: disable=protected-access - return gen_math_ops._tanh_grad(y, grad) + return gen_math_ops.tanh_grad(y, grad) @ops.RegisterGradient("Asinh") @@ -565,8 +576,7 @@ def _TanhGradGrad(op, grad): with ops.control_dependencies([grad]): a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) - # pylint: disable=protected-access - return grad * -2.0 * b * a, gen_math_ops._tanh_grad(a, grad) + return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad) @ops.RegisterGradient("Erf") @@ -616,9 +626,7 @@ def _IgammaGrad(op, grad): x = op.inputs[1] sa = array_ops.shape(a) sx = array_ops.shape(x) - # pylint: disable=protected-access - unused_ra, rx = gen_array_ops._broadcast_gradient_args(sa, sx) - # pylint: enable=protected-access + unused_ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) # Perform operations in log space before summing, because Gamma(a) # and Gamma'(a) can grow large. @@ -645,9 +653,7 @@ def _BetaincGrad(op, grad): # versa; so its sufficient to check against shape(a). sa = array_ops.shape(a) sx = array_ops.shape(x) - # pylint: disable=protected-access - _, rx = gen_array_ops._broadcast_gradient_args(sa, sx) - # pylint: enable=protected-access + _, rx = gen_array_ops.broadcast_gradient_args(sa, sx) # Perform operations in log space before summing, because terms # can grow large. @@ -673,9 +679,7 @@ def _ZetaGrad(op, grad): # Broadcast gradients sx = array_ops.shape(x) sq = array_ops.shape(q) - # pylint: disable=protected-access - unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq) - # pylint: enable=protected-access + unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq) # Evaluate gradient with ops.control_dependencies([grad]): x = math_ops.conj(x) @@ -695,9 +699,7 @@ def _PolygammaGrad(op, grad): # Broadcast gradients sn = array_ops.shape(n) sx = array_ops.shape(x) - # pylint: disable=protected-access - unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx) - # pylint: enable=protected-access + unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx) # Evaluate gradient with ops.control_dependencies([grad]): n = math_ops.conj(n) @@ -714,8 +716,7 @@ def _SigmoidGrad(op, grad): y = op.outputs[0] # y = sigmoid(x) with ops.control_dependencies([grad]): y = math_ops.conj(y) - # pylint: disable=protected-access - return gen_math_ops._sigmoid_grad(y, grad) + return gen_math_ops.sigmoid_grad(y, grad) @ops.RegisterGradient("SigmoidGrad") @@ -724,8 +725,7 @@ def _SigmoidGradGrad(op, grad): a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) gb = grad * b - # pylint: disable=protected-access - return gb - 2.0 * gb * a, gen_math_ops._sigmoid_grad(a, grad) + return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad) @ops.RegisterGradient("Sign") @@ -839,9 +839,7 @@ def _AddGrad(op, grad): return grad, grad sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx), array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)) @@ -856,9 +854,7 @@ def _SubGrad(op, grad): return grad, -grad sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx), array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy)) @@ -868,22 +864,20 @@ def _MulGrad(op, grad): """The gradient of scalar multiplication.""" x = op.inputs[0] y = op.inputs[1] - # pylint: disable=protected-access if (isinstance(grad, ops.Tensor) and _ShapesFullySpecifiedAndEqual(x, y, grad) and grad.dtype in (dtypes.int32, dtypes.float32)): - return gen_math_ops._mul(grad, y), gen_math_ops._mul(grad, x) + return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) sx = array_ops.shape(x) sy = array_ops.shape(y) - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) return (array_ops.reshape( - math_ops.reduce_sum(gen_math_ops._mul(grad, y), rx), sx), + math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx), array_ops.reshape( - math_ops.reduce_sum(gen_math_ops._mul(x, grad), ry), sy)) - # pylint: enable=protected-access + math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)) @ops.RegisterGradient("Div") @@ -893,9 +887,7 @@ def _DivGrad(op, grad): y = op.inputs[1] sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx), @@ -918,9 +910,7 @@ def _FloorModGrad(op, grad): sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) floor_xy = math_ops.floor_div(x, y) gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) gy = array_ops.reshape( @@ -940,9 +930,7 @@ def _RealDivGrad(op, grad): y = op.inputs[1] sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) return (array_ops.reshape( @@ -960,7 +948,7 @@ def _PowGrad(op, grad): z = op.outputs[0] sx = array_ops.shape(x) sy = array_ops.shape(y) - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) x = math_ops.conj(x) y = math_ops.conj(y) z = math_ops.conj(z) @@ -988,7 +976,7 @@ def _MaximumMinimumGrad(op, grad, selector_op): gradshape = array_ops.shape(grad) zeros = array_ops.zeros(gradshape, gdtype) xmask = selector_op(x, y) - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) xgrad = array_ops.where(xmask, grad, zeros) ygrad = array_ops.where(xmask, zeros, grad) gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx) @@ -1015,9 +1003,7 @@ def _SquaredDifferenceGrad(op, grad): y = op.inputs[1] sx = array_ops.shape(x) sy = array_ops.shape(y) - # pylint: disable=protected-access - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) with ops.control_dependencies([grad]): # The parens ensure that if grad is IndexedSlices, it'll get multiplied by # Tensor (not a number like 2.0) which causes it to convert to Tensor. @@ -1056,20 +1042,18 @@ def _MatMulGrad(op, grad): t_b = op.get_attr("transpose_b") a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) - # pylint: disable=protected-access if not t_a and not t_b: - grad_a = gen_math_ops._mat_mul(grad, b, transpose_b=True) - grad_b = gen_math_ops._mat_mul(a, grad, transpose_a=True) + grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) + grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) elif not t_a and t_b: - grad_a = gen_math_ops._mat_mul(grad, b) - grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True) + grad_a = gen_math_ops.mat_mul(grad, b) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) elif t_a and not t_b: - grad_a = gen_math_ops._mat_mul(b, grad, transpose_b=True) - grad_b = gen_math_ops._mat_mul(a, grad) + grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) + grad_b = gen_math_ops.mat_mul(a, grad) elif t_a and t_b: - grad_a = gen_math_ops._mat_mul(b, grad, transpose_a=True, transpose_b=True) - grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True, transpose_b=True) - # pylint: enable=protected-access + grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) + grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) return grad_a, grad_b @@ -1083,7 +1067,7 @@ def _SparseMatMulGrad(op, grad): op.inputs[0]: op.get_attr("a_is_sparse"), op.inputs[1]: op.get_attr("b_is_sparse"), # Use heuristic to figure out if grad might be sparse - grad: context.in_graph_mode() and (grad.op.type == "ReluGrad") + grad: not context.executing_eagerly() and (grad.op.type == "ReluGrad") } def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False): @@ -1183,7 +1167,7 @@ def _ComplexGrad(op, grad): y = op.inputs[1] sx = array_ops.shape(x) sy = array_ops.shape(y) - rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx), array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy)) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 2ae8b610da04d9762284d51b9c9f28a8c07e24f7..276897ab99e5e8770b72cb1eb27d07fb8dbc08bb 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -89,8 +89,6 @@ See the @{$python/math_ops} guide. @@matrix_inverse @@cholesky @@cholesky_solve -@@matrix_exponential -@@matrix_logarithm @@matrix_solve @@matrix_triangular_solve @@matrix_solve_ls @@ -129,8 +127,11 @@ See the @{$python/math_ops} guide. @@segment_min @@segment_max @@segment_mean +@@to_complex128 +@@to_complex64 @@unsorted_segment_sum @@unsorted_segment_max +@@unsorted_segment_mean @@unsorted_segment_min @@unsorted_segment_prod @@unsorted_segment_sqrt_n @@ -161,14 +162,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging as logging # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * @@ -182,6 +181,13 @@ linspace = gen_math_ops.lin_space arg_max = deprecation.deprecated(None, "Use `argmax` instead")(arg_max) # pylint: disable=used-before-assignment arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min) # pylint: disable=used-before-assignment +tf_export("arg_max")(arg_max) +tf_export("arg_min")(arg_min) + + +# This is set by resource_variable_ops.py. It is included in this way since +# there is a circular dependency between math_ops and resource_variable_ops +_resource_variable_type = None def _set_doc(doc): @@ -266,7 +272,7 @@ def abs(x, name=None): # pylint: disable=redefined-builtin with ops.name_scope(name, "Abs", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): if x.values.dtype.is_complex: - x_abs = gen_math_ops._complex_abs( + x_abs = gen_math_ops.complex_abs( x.values, Tout=x.values.dtype.real_dtype, name=name) return sparse_tensor.SparseTensor( indices=x.indices, values=x_abs, dense_shape=x.dense_shape) @@ -276,7 +282,7 @@ def abs(x, name=None): # pylint: disable=redefined-builtin else: x = ops.convert_to_tensor(x, name="x") if x.dtype.is_complex: - return gen_math_ops._complex_abs(x, Tout=x.dtype.real_dtype, name=name) + return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name) return gen_math_ops._abs(x, name=name) @@ -285,7 +291,7 @@ def abs(x, name=None): # pylint: disable=redefined-builtin # pylint: disable=redefined-builtin def _bucketize(input, boundaries, name=None): - return gen_math_ops._bucketize(input=input, boundaries=boundaries, name=name) + return gen_math_ops.bucketize(input=input, boundaries=boundaries, name=name) # pylint: enable=redefined-builtin @@ -328,10 +334,10 @@ def divide(x, y, name=None): @tf_export("multiply") def multiply(x, y, name=None): - return gen_math_ops._mul(x, y, name) + return gen_math_ops.mul(x, y, name) -multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`") +multiply.__doc__ = gen_math_ops.mul.__doc__.replace("Multiply", "`tf.multiply`") # TODO(aselle): put deprecation in after another round of global code changes @@ -339,19 +345,19 @@ multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`") "2016-12-30", "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`") def _mul(x, y, name=None): - return gen_math_ops._mul(x, y, name) + return gen_math_ops.mul(x, y, name) _mul.__doc__ = ( - gen_math_ops._mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__)) + gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__)) @tf_export("subtract") def subtract(x, y, name=None): - return gen_math_ops._sub(x, y, name) + return gen_math_ops.sub(x, y, name) -subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`") +subtract.__doc__ = gen_math_ops.sub.__doc__.replace("`Sub`", "`tf.subtract`") # TODO(aselle): put deprecation in after another round of global code changes @@ -359,11 +365,11 @@ subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`") "2016-12-30", "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`") def _sub(x, y, name=None): - return gen_math_ops._sub(x, y, name) + return gen_math_ops.sub(x, y, name) _sub.__doc__ = ( - gen_math_ops._sub.__doc__ + ("" if _sub.__doc__ is None else _sub.__doc__)) + gen_math_ops.sub.__doc__ + ("" if _sub.__doc__ is None else _sub.__doc__)) # pylint: disable=g-docstring-has-escape @@ -383,11 +389,11 @@ def negative(x, name=None): """ with ops.name_scope(name, "Neg", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): - x_neg = gen_math_ops._neg(x.values, name=name) + x_neg = gen_math_ops.neg(x.values, name=name) return sparse_tensor.SparseTensor( indices=x.indices, values=x_neg, dense_shape=x.dense_shape) else: - return gen_math_ops._neg(x, name=name) + return gen_math_ops.neg(x, name=name) # pylint: enable=g-docstring-has-escape @@ -770,16 +776,18 @@ def cast(x, dtype, name=None): with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): values_cast = cast(x.values, base_type, name=name) - return sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape) + x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape) else: # TODO(josh11b): If x is not already a Tensor, we could return # ops.convert_to_tensor(x, dtype=dtype, ...) here, but that # allows some conversions that cast() can't do, e.g. casting numbers to # strings. x = ops.convert_to_tensor(x, name="x") - if x.dtype.base_dtype == base_type: - return x - return gen_math_ops.cast(x, base_type, name=name) + if x.dtype.base_dtype != base_type: + x = gen_math_ops.cast(x, base_type, name=name) + if x.dtype.is_complex and base_type.is_floating: + logging.warn("Casting complex to real discards imaginary part.") + return x @tf_export("saturate_cast") @@ -935,7 +943,7 @@ def to_complex128(x, name="ToComplex128"): return cast(x, dtypes.complex128, name=name) -ops.Tensor._override_operator("__neg__", gen_math_ops._neg) +ops.Tensor._override_operator("__neg__", gen_math_ops.neg) ops.Tensor._override_operator("__abs__", abs) # __invert__ corresponds to the ~ operator. Here we follow the numpy convention # ~ marks an elementwise bit-wise inverse. This is only implemented for boolean @@ -1064,7 +1072,7 @@ def _truediv_python3(x, y, name=None): if dtype is not None: x = cast(x, dtype) y = cast(y, dtype) - return gen_math_ops._real_div(x, y, name=name) + return gen_math_ops.real_div(x, y, name=name) def _div_python2(x, y, name=None): @@ -1087,9 +1095,9 @@ def _div_python2(x, y, name=None): raise TypeError("x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype)) if x_dtype.is_floating or x_dtype.is_complex: - return gen_math_ops._real_div(x, y, name=name) + return gen_math_ops.real_div(x, y, name=name) else: - return gen_math_ops._floor_div(x, y, name=name) + return gen_math_ops.floor_div(x, y, name=name) @tf_export("truediv") @@ -1147,7 +1155,7 @@ def div(x, y, name=None): # TODO(aselle): This should be removed -mod = gen_math_ops._floor_mod +mod = gen_math_ops.floor_mod # TODO(aselle): Deprecate this once all internal functionality uses @@ -1180,22 +1188,27 @@ def floordiv(x, y, name=None): TypeError: If the inputs are complex. """ with ops.name_scope(name, "floordiv", [x, y]) as name: - return gen_math_ops._floor_div(x, y, name=name) + return gen_math_ops.floor_div(x, y, name=name) -realdiv = gen_math_ops._real_div -truncatediv = gen_math_ops._truncate_div +realdiv = gen_math_ops.real_div +tf_export("realdiv")(realdiv) +truncatediv = gen_math_ops.truncate_div +tf_export("truncatediv")(truncatediv) # TODO(aselle): Rename this to floordiv when we can. -floor_div = gen_math_ops._floor_div -truncatemod = gen_math_ops._truncate_mod -floormod = gen_math_ops._floor_mod +floor_div = gen_math_ops.floor_div +tf_export("floor_div")(floor_div) +truncatemod = gen_math_ops.truncate_mod +tf_export("truncatemod")(truncatemod) +floormod = gen_math_ops.floor_mod +tf_export("floormod", "mod")(floormod) def _mul_dispatch(x, y, name=None): """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse".""" is_tensor_y = isinstance(y, ops.Tensor) if is_tensor_y: - return gen_math_ops._mul(x, y, name=name) + return gen_math_ops.mul(x, y, name=name) else: assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse. new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values, @@ -1214,12 +1227,12 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", sparse_tensor.SparseTensor) _OverrideBinaryOperatorHelper(gen_math_ops.add, "add") -_OverrideBinaryOperatorHelper(gen_math_ops._sub, "sub") +_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub") _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") _OverrideBinaryOperatorHelper(_div_python2, "div") _OverrideBinaryOperatorHelper(_truediv_python3, "truediv") _OverrideBinaryOperatorHelper(floordiv, "floordiv") -_OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod") +_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod") _OverrideBinaryOperatorHelper(pow, "pow") @@ -1541,7 +1554,7 @@ def reduce_mean(input_tensor, if keepdims is None: keepdims = False return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._mean( + gen_math_ops.mean( input_tensor, _ReductionDims(input_tensor, axis, reduction_indices), @@ -1591,7 +1604,7 @@ def reduce_prod(input_tensor, if keepdims is None: keepdims = False return _may_reduce_to_scalar(keepdims, axis, reduction_indices, - gen_math_ops._prod( + gen_math_ops.prod( input_tensor, _ReductionDims(input_tensor, axis, reduction_indices), @@ -2044,8 +2057,15 @@ def matmul(a, if transpose_b and adjoint_b: raise ValueError("Only one of transpose_b and adjoint_b can be True.") - a = ops.convert_to_tensor(a, name="a") - b = ops.convert_to_tensor(b, name="b") + if context.executing_eagerly(): + if not isinstance(a, (ops.EagerTensor, _resource_variable_type)): + a = ops.convert_to_tensor(a, name="a") + if not isinstance(b, (ops.EagerTensor, _resource_variable_type)): + b = ops.convert_to_tensor(b, name="b") + else: + a = ops.convert_to_tensor(a, name="a") + b = ops.convert_to_tensor(b, name="b") + # TODO(apassos) remove _shape_tuple here when it is not needed. a_shape = a._shape_tuple() # pylint: disable=protected-access b_shape = b._shape_tuple() # pylint: disable=protected-access @@ -2060,7 +2080,7 @@ def matmul(a, if transpose_b: b = conj(b) adjoint_b = True - return gen_math_ops._batch_mat_mul( + return gen_math_ops.batch_mat_mul( a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name) # Neither matmul nor sparse_matmul support adjoint, so we conjugate @@ -2078,8 +2098,9 @@ def matmul(a, sparse_matmul_types = [dtypes.bfloat16, dtypes.float32] use_sparse_matmul = ( a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types) - if a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16: - # matmul currently doesn't handle bfloat16 inputs. + if (a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16 and + a.dtype != b.dtype): + # matmul currently doesn't handle mixed-precision inputs. use_sparse_matmul = True if use_sparse_matmul: ret = sparse_matmul( @@ -2097,13 +2118,14 @@ def matmul(a, ret = cast(ret, dtypes.bfloat16) return ret else: - return gen_math_ops._mat_mul( + return gen_math_ops.mat_mul( a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name) _OverrideBinaryOperatorHelper(matmul, "matmul") -sparse_matmul = gen_math_ops._sparse_mat_mul +sparse_matmul = gen_math_ops.sparse_mat_mul +tf_export("sparse_matmul")(sparse_matmul) @ops.RegisterStatistics("MatMul", "flops") @@ -2208,7 +2230,7 @@ def add_n(inputs, name=None): if name: return array_ops.identity(inputs[0], name=name) return inputs[0] - return gen_math_ops._add_n(inputs, name=name) + return gen_math_ops.add_n(inputs, name=name) @tf_export("accumulate_n") @@ -2218,14 +2240,12 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): Optionally, pass `shape` and `tensor_dtype` for shape and type checking, otherwise, these are inferred. - NOTE: This operation is not differentiable and cannot be used if inputs depend - on trainable variables. Please use `tf.add_n` for such cases. + `tf.accumulate_n` performs the same operation as `tf.add_n`, but does not + wait for all of its inputs to be ready before beginning to sum. This can + save memory if inputs are ready at different times, since minimum temporary + storage is proportional to the output size rather than the inputs size. - Aside from differentiability, `tf.accumulate_n` performs the same operation as - `tf.add_n`, but does not wait for all of its inputs to be ready before - beginning to sum. This can save memory if inputs are ready at different times, - since minimum temporary storage is proportional to the output size rather than - the inputs size. + `accumulate_n` is differentiable (but wasn't previous to TensorFlow 1.7). For example: @@ -2235,8 +2255,9 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): tf.accumulate_n([a, b, a]) # [[7, 4], [6, 14]] # Explicitly pass shape and type - tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) # [[7, 4], - # [6, 14]] + tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + # [[7, 4], + # [6, 14]] ``` Args: @@ -2252,20 +2273,17 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): ValueError: If `inputs` don't all have same shape and dtype or the shape cannot be inferred. """ - if context.in_eager_mode(): - # TODO(apassos) remove this once the lifetime of eager variables gets - # addressed. - raise ValueError("accumulate_n not supported in eager mode") + def _input_error(): + return ValueError( + "inputs must be a list of at least one Tensor with the " + "same dtype and shape") if not inputs or not isinstance(inputs, (list, tuple)): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) if not all(isinstance(x, ops.Tensor) for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() if not all(x.dtype == inputs[0].dtype for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() if shape is not None: shape = tensor_shape.as_shape(shape) else: @@ -2273,27 +2291,31 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): for input_tensor in inputs: if isinstance(input_tensor, ops.Tensor): shape = shape.merge_with(input_tensor.get_shape()) - if tensor_dtype is None: - tensor_dtype = inputs[0].dtype - if tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}".format( - tensor_dtype, inputs[0].dtype)) - if len(inputs) == 1: + + # tensor_dtype is for safety only; operator's output type computed in C++ + if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: + raise TypeError("tensor_dtype is {}, but input is of type {}" + .format(tensor_dtype, inputs[0].dtype)) + + if len(inputs) == 1 and name is None: return inputs[0] - with ops.name_scope(name, "AccumulateN", inputs) as name: - var = gen_state_ops._temporary_variable( - shape=tensor_shape.vector(0), dtype=tensor_dtype) - with ops.colocate_with(var): - zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0]) - zeros.set_shape(shape) - ref = state_ops.assign(var, zeros, validate_shape=False) - update_ops = [ - state_ops.assign_add(ref, input_tensor, use_locking=True) - for input_tensor in inputs - ] - with ops.control_dependencies(update_ops): - return gen_state_ops._destroy_temporary_variable( - ref, var_name=var.op.name, name=name) + elif len(inputs) == 1 and name is not None: + return array_ops.identity(inputs[0], name=name) + elif context.executing_eagerly(): + # TemporaryVariable not currently supported in eager mode; fall back + # onto AddN for now. + # TODO(frreiss) remove this once the lifetime of eager variables gets + # addressed + return add_n(inputs, name=name) + else: + return gen_math_ops.accumulate_nv2(inputs, name=name, shape=shape) # pylint: disable=protected-access + + +@ops.RegisterGradient("AccumulateNV2") +def _accumulate_n_grad(op, grad): + """Same as gradient for AddN. Copies the gradient to all inputs.""" + # Not broadcasting. + return [grad] * len(op.inputs) @tf_export("nn.sigmoid", "sigmoid") @@ -2316,7 +2338,7 @@ def sigmoid(x, name=None): """ with ops.name_scope(name, "Sigmoid", [x]) as name: x = ops.convert_to_tensor(x, name="x") - return gen_math_ops._sigmoid(x, name=name) + return gen_math_ops.sigmoid(x, name=name) @tf_export("log_sigmoid") @@ -2335,7 +2357,7 @@ def log_sigmoid(x, name=None): """ with ops.name_scope(name, "LogSigmoid", [x]) as name: x = ops.convert_to_tensor(x, name="x") - return gen_math_ops._neg(gen_nn_ops.softplus(-x), name=name) + return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name) @tf_export("nn.tanh", "tanh") @@ -2352,11 +2374,11 @@ def tanh(x, name=None): """ with ops.name_scope(name, "Tanh", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): - x_tanh = gen_math_ops._tanh(x.values, name=name) + x_tanh = gen_math_ops.tanh(x.values, name=name) return sparse_tensor.SparseTensor( indices=x.indices, values=x_tanh, dense_shape=x.dense_shape) else: - return gen_math_ops._tanh(x, name=name) + return gen_math_ops.tanh(x, name=name) @tf_export("bincount") @@ -2545,7 +2567,7 @@ def conj(x, name=None): with ops.name_scope(name, "Conj", [x]) as name: x = ops.convert_to_tensor(x, name="x") if x.dtype.is_complex or x.dtype == dtypes.variant: - return gen_math_ops._conj(x, name=name) + return gen_math_ops.conj(x, name=name) elif x.dtype.is_floating or x.dtype.is_integer: return x else: diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index d314124ccd9bc8b7676e6926830a8eb1e0315f5f..9f85188b3513563a7444f7a0e908f11af985498b 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -60,7 +60,7 @@ class ReduceTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testReduceInvalidAxis(self): - if context.in_eager_mode(): + if context.executing_eagerly(): # The shape check is in run a graph construction time. In eager mode, # it misses the check, magically return result given wrong shape. return @@ -249,7 +249,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testAcceptsRefs(self): - if context.in_eager_mode(): + if context.executing_eagerly(): var = resource_variable_ops.ResourceVariable(10, name="var") else: var = variables.Variable(10) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 043c0e30cd8476b1a91e136df60edfbedf85ab24..9ec49545796cfa7a603b31c23bfd0d495639898d 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -308,7 +308,7 @@ def mean(values, or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean is not supported when eager execution ' 'is enabled.') @@ -394,7 +394,7 @@ def accuracy(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.accuracy is not supported when eager ' 'execution is enabled.') @@ -644,7 +644,7 @@ def auc(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.auc is not supported when eager execution ' 'is enabled.') @@ -672,7 +672,7 @@ def auc(labels, x = fp_rate y = rec else: # curve == 'PR'. - prec = math_ops.div(tp, tp + fp + epsilon) + prec = math_ops.div(tp + epsilon, tp + fp + epsilon) x = rec y = prec if summation_method == 'trapezoidal': @@ -758,7 +758,7 @@ def mean_absolute_error(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_absolute_error is not supported ' 'when eager execution is enabled.') @@ -818,7 +818,7 @@ def mean_cosine_distance(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when ' 'eager execution is enabled.') @@ -891,7 +891,7 @@ def mean_per_class_accuracy(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported ' 'when eager execution is enabled.') @@ -923,8 +923,8 @@ def mean_per_class_accuracy(labels, weights = array_ops.reshape(weights, [-1]) weights = math_ops.to_float(weights) - is_correct *= weights - ones *= weights + is_correct = is_correct * weights + ones = ones * weights update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) @@ -996,7 +996,7 @@ def mean_iou(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_iou is not supported when ' 'eager execution is enabled.') @@ -1098,7 +1098,7 @@ def mean_relative_error(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_relative_error is not supported when ' 'eager execution is enabled.') @@ -1165,7 +1165,7 @@ def mean_squared_error(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_squared_error is not supported when ' 'eager execution is enabled.') @@ -1223,7 +1223,7 @@ def mean_tensor(values, or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.mean_tensor is not supported when ' 'eager execution is enabled.') @@ -1304,7 +1304,7 @@ def percentage_below(values, or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.percentage_below is not supported when ' 'eager execution is enabled.') @@ -1397,7 +1397,7 @@ def false_negatives(labels, or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.false_negatives is not supported when ' 'eager execution is enabled.') @@ -1453,7 +1453,7 @@ def false_negatives_at_thresholds(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' 'supported when eager execution is enabled.') @@ -1507,7 +1507,7 @@ def false_positives(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.false_positives is not supported when ' 'eager execution is enabled.') @@ -1563,7 +1563,7 @@ def false_positives_at_thresholds(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.false_positives_at_thresholds is not ' 'supported when eager execution is enabled.') @@ -1617,7 +1617,7 @@ def true_negatives(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.true_negatives is not ' 'supported when eager execution is enabled.') @@ -1673,7 +1673,7 @@ def true_negatives_at_thresholds(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' 'supported when eager execution is enabled.') @@ -1727,7 +1727,7 @@ def true_positives(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.true_positives is not ' 'supported when eager execution is enabled.') @@ -1783,7 +1783,7 @@ def true_positives_at_thresholds(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' 'supported when eager execution is enabled.') @@ -1851,7 +1851,7 @@ def precision(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.precision is not ' 'supported when eager execution is enabled.') @@ -1947,7 +1947,7 @@ def precision_at_thresholds(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.precision_at_thresholds is not ' 'supported when eager execution is enabled.') @@ -2023,7 +2023,7 @@ def recall(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.recall is not supported is not ' 'supported when eager execution is enabled.') @@ -2400,7 +2400,7 @@ def recall_at_k(labels, are not a list or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.recall_at_k is not ' 'supported when eager execution is enabled.') @@ -2549,7 +2549,7 @@ def recall_at_thresholds(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.recall_at_thresholds is not ' 'supported when eager execution is enabled.') @@ -2626,7 +2626,7 @@ def root_mean_squared_error(labels, tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.root_mean_squared_error is not ' 'supported when eager execution is enabled.') @@ -2707,7 +2707,7 @@ def sensitivity_at_specificity(labels, or `updates_collections` are not a list or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' 'supported when eager execution is enabled.') @@ -3098,7 +3098,7 @@ def average_precision_at_k(labels, ValueError: if k is invalid. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not ' 'supported when eager execution is enabled.') @@ -3267,7 +3267,7 @@ def precision_at_top_k(labels, are not a list or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.precision_at_top_k is not ' 'supported when eager execution is enabled.') @@ -3396,7 +3396,7 @@ def precision_at_k(labels, are not a list or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.sparse_precision_at_k is not ' 'supported when eager execution is enabled.') @@ -3473,7 +3473,7 @@ def specificity_at_sensitivity(labels, or `updates_collections` are not a list or tuple. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError('tf.metrics.specificity_at_sensitivity is not ' 'supported when eager execution is enabled.') diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py index eebfb17085a568f48769f6df7dddd3ae2f799efc..3ac2c8eb17ef31b46638ce50e0e9f9705adce189 100644 --- a/tensorflow/python/ops/nn_batchnorm_test.py +++ b/tensorflow/python/ops/nn_batchnorm_test.py @@ -57,7 +57,6 @@ class BatchNormalizationTest(test.TestCase): test_util.set_producer_version(ops.get_default_graph(), 8) return gen_nn_ops._batch_norm_with_global_normalization( x, m, v, beta, gamma, epsilon, scale_after_normalization) - # pylint: enable=protected-access def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon, scale_after_normalization): @@ -223,7 +222,7 @@ class BatchNormalizationTest(test.TestCase): for scale_after_normalization in [True, False]: # _batch_norm_with_global_normalization_grad is deprecated in v9 test_util.set_producer_version(ops.get_default_graph(), 8) - grad = gen_nn_ops._batch_norm_with_global_normalization_grad( + grad = gen_nn_ops.batch_norm_with_global_normalization_grad( x, m, v, gamma, backprop, epsilon, scale_after_normalization) dx, dm, dv, db, dg = grad self.assertEqual(grad.dx, dx) diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index dc24b821a5580e3581f153f3cbf63ad2868b8a18..4af5bd26dd80b984b1c898411c2a23827bed1b4b 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -150,7 +150,7 @@ def _Conv3DBackpropFilterGrad(op, grad): @ops.RegisterGradient("AvgPool3D") def _AvgPool3DGrad(op, grad): - return gen_nn_ops._avg_pool3d_grad( + return gen_nn_ops.avg_pool3d_grad( array_ops.shape(op.inputs[0]), grad, ksize=op.get_attr("ksize"), @@ -172,7 +172,7 @@ def _AvgPool3DGradGrad(op, grad): @ops.RegisterGradient("MaxPool3D") def _MaxPool3DGrad(op, grad): - return gen_nn_ops._max_pool3d_grad( + return gen_nn_ops.max_pool3d_grad( op.inputs[0], op.outputs[0], grad, @@ -188,7 +188,7 @@ def _MaxPool3DGradGrad(op, grad): shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), array_ops.zeros( shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), - gen_nn_ops._max_pool3d_grad_grad( + gen_nn_ops.max_pool3d_grad_grad( op.inputs[0], op.inputs[1], grad, @@ -204,7 +204,7 @@ def _MaxPool3DGradGradGrad(op, grad): shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), array_ops.zeros( shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), - gen_nn_ops._max_pool3d_grad( + gen_nn_ops.max_pool3d_grad( op.inputs[0], op.inputs[1], grad, @@ -352,13 +352,13 @@ def _BiasAddGradV1(unused_bias_op, received_grad): @ops.RegisterGradient("Relu") def _ReluGrad(op, grad): - return gen_nn_ops._relu_grad(grad, op.outputs[0]) + return gen_nn_ops.relu_grad(grad, op.outputs[0]) @ops.RegisterGradient("EluGrad") def _EluGradGrad(op, grad): elu_x = op.inputs[1] - return (gen_nn_ops._elu_grad(grad, op.outputs[0]), + return (gen_nn_ops.elu_grad(grad, op.outputs[0]), array_ops.where(elu_x < 0, grad * op.inputs[0], array_ops.zeros( shape=array_ops.shape(elu_x), dtype=elu_x.dtype))) @@ -368,63 +368,63 @@ def _EluGradGrad(op, grad): def _SeluGradGrad(op, grad): x = op.inputs[1] scale_alpha = 1.7580993408473768599402175208123 - return (gen_nn_ops._elu_grad(grad, op.outputs[0]), + return (gen_nn_ops.elu_grad(grad, op.outputs[0]), array_ops.where(x < 0., - gen_nn_ops._elu_grad(grad, - op.outputs[0] + scale_alpha), + gen_nn_ops.elu_grad(grad, + op.outputs[0] + scale_alpha), array_ops.zeros( shape=array_ops.shape(x), dtype=x.dtype))) @ops.RegisterGradient("Relu6") def _Relu6Grad(op, grad): - return gen_nn_ops._relu6_grad(grad, op.outputs[0]) # pylint: disable=protected-access + return gen_nn_ops.relu6_grad(grad, op.outputs[0]) @ops.RegisterGradient("Relu6Grad") def _Relu6GradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops._relu6_grad(grad, x), + return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) @ops.RegisterGradient("Elu") def _EluGrad(op, grad): - return gen_nn_ops._elu_grad(grad, op.outputs[0]) + return gen_nn_ops.elu_grad(grad, op.outputs[0]) @ops.RegisterGradient("Selu") def _SeluGrad(op, grad): - return gen_nn_ops._selu_grad(grad, op.outputs[0]) + return gen_nn_ops.selu_grad(grad, op.outputs[0]) @ops.RegisterGradient("Softplus") def _SoftplusGrad(op, grad): - return gen_nn_ops._softplus_grad(grad, op.inputs[0]) + return gen_nn_ops.softplus_grad(grad, op.inputs[0]) @ops.RegisterGradient("SoftplusGrad") def _SoftplusGradGrad(op, grad): # Let: # y = tf.nn.softplus(x) - # dx = gen_nn_ops._softplus_grad(dy, x) = dy / (1 + exp(-x)) + # dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x)) # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx. dy, x = op.inputs with ops.control_dependencies([grad]): - ddy = gen_nn_ops._softplus_grad(grad, x) # pylint: disable=protected-access + ddy = gen_nn_ops.softplus_grad(grad, x) d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x)) return (ddy, d2x) @ops.RegisterGradient("Softsign") def _SoftsignGrad(op, grad): - return gen_nn_ops._softsign_grad(grad, op.inputs[0]) + return gen_nn_ops.softsign_grad(grad, op.inputs[0]) @ops.RegisterGradient("ReluGrad") def _ReluGradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops._relu_grad(grad, x), + return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) @@ -456,7 +456,7 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad): def IsZero(g): # Some introspection to check if the gradient is feeding zeros - if context.in_eager_mode(): + if context.executing_eagerly(): # TODO(apassos) add an efficient way to detect eager zeros here. return False if g.op.type in ("ZerosLike", "Zeros"): @@ -565,14 +565,14 @@ def _LRNGrad(op, grad): alpha = op.get_attr("alpha") beta = op.get_attr("beta") return [ - gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, - bias, alpha, beta) + gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias, + alpha, beta) ] @ops.RegisterGradient("AvgPool") def _AvgPoolGrad(op, grad): - return gen_nn_ops._avg_pool_grad( + return gen_nn_ops.avg_pool_grad( array_ops.shape(op.inputs[0]), grad, op.get_attr("ksize"), @@ -584,7 +584,7 @@ def _AvgPoolGrad(op, grad): @ops.RegisterGradient("AvgPoolGrad") def _AvgPoolGradGrad(op, grad): return (array_ops.stop_gradient(op.inputs[0]), - gen_nn_ops._avg_pool( + gen_nn_ops.avg_pool( grad, op.get_attr("ksize"), op.get_attr("strides"), @@ -594,7 +594,7 @@ def _AvgPoolGradGrad(op, grad): @ops.RegisterGradient("MaxPool") def _MaxPoolGrad(op, grad): - return gen_nn_ops._max_pool_grad( + return gen_nn_ops.max_pool_grad( op.inputs[0], op.outputs[0], grad, @@ -620,7 +620,7 @@ def _MaxPoolGradV2(op, grad): @ops.RegisterGradient("MaxPoolWithArgmax") def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad): - return gen_nn_ops._max_pool_grad_with_argmax( + return gen_nn_ops.max_pool_grad_with_argmax( op.inputs[0], grad, op.outputs[1], @@ -635,7 +635,7 @@ def _MaxPoolGradGrad(op, grad): shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), array_ops.zeros( shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), - gen_nn_ops._max_pool_grad_grad( + gen_nn_ops.max_pool_grad_grad( op.inputs[0], op.inputs[1], grad, @@ -669,7 +669,7 @@ def _MaxPoolGradGradGrad(op, grad): shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), array_ops.zeros( shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), - gen_nn_ops._max_pool_grad( + gen_nn_ops.max_pool_grad( op.inputs[0], op.inputs[1], grad, @@ -696,8 +696,7 @@ def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2): Returns: Input backprop for FractionalMaxPool op. """ - # pylint: disable=protected-access - return gen_nn_ops._fractional_max_pool_grad( + return gen_nn_ops.fractional_max_pool_grad( op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2], op.get_attr("overlapping")) @@ -719,10 +718,9 @@ def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2): Returns: Input backprop for FractionalAvgPool op. """ - # pylint: disable=protected-access - return gen_nn_ops._fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0, - op.outputs[1], op.outputs[2], - op.get_attr("overlapping")) + return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0, + op.outputs[1], op.outputs[2], + op.get_attr("overlapping")) @ops.RegisterGradient("BatchNormWithGlobalNormalization") @@ -746,7 +744,7 @@ def _BatchNormWithGlobalNormalizationGrad(op, grad): last dimension. dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon))) """ - dx, dm, dv, db, dg = gen_nn_ops._batch_norm_with_global_normalization_grad( + dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad( op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad, op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization")) return dx, dm, dv, db, dg diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 5fa5708114fd5cda6afbca78fa0debf68f0252cc..47cc4da7f2abd1f5b00e193a76c8391be94ca27d 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -303,12 +303,12 @@ def _swish_grad(features, grad): # @Defun decorator with noinline=True so that sigmoid(features) is re-computed # during backprop, and we can free the sigmoid(features) expression immediately # after use during the forward pass. +@tf_export("nn.swish") @function.Defun( grad_func=_swish_grad, shape_func=_swish_shape, func_name="swish", noinline=True) -@tf_export("nn.swish") def swish(features): # pylint: disable=g-doc-args """Computes the Swish activation function: `x * sigmoid(x)`. @@ -888,12 +888,10 @@ def fused_batch_norm( # TODO(reedwm): In a few weeks, switch to using the V2 version exclusively. We # currently only use the V2 version for float16 inputs, which is not supported # by the V1 version. - # pylint: disable=protected-access if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16: - fused_batch_norm_func = gen_nn_ops._fused_batch_norm_v2 + fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2 else: - fused_batch_norm_func = gen_nn_ops._fused_batch_norm - # pylint: enable=protected-access + fused_batch_norm_func = gen_nn_ops._fused_batch_norm # pylint: disable=protected-access y, batch_mean, batch_var, _, _ = fused_batch_norm_func( x, scale, diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 8fbe698914e5f2fa8f821feed82c33fc77e35e21..a74de39eab34a1a27df90f70adf0f4c68ec29465 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -149,14 +150,12 @@ class _NonAtrousConvolution(object): conv_dims)) if conv_dims == 1: # conv1d uses the 2-d data format names - if data_format is None or data_format == "NWC": - data_format_2d = "NHWC" - elif data_format == "NCW": - data_format_2d = "NCHW" - else: + if data_format is None: + data_format = "NWC" + elif data_format not in {"NCW", "NWC", "NCHW", "NHWC"}: raise ValueError("data_format must be \"NWC\" or \"NCW\".") self.strides = strides[0] - self.data_format = data_format_2d + self.data_format = data_format self.conv_op = self._conv1d elif conv_dims == 2: if data_format is None or data_format == "NHWC": @@ -698,7 +697,7 @@ def convolution( `padded_input` is obtained by zero padding the input using an effective spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and output striding `strides` as described in the - @{tf.nn.convolution$comment here}. + @{$python/nn#Convolution$comment here}. In the case that `data_format` does start with `"NC"`, the `input` and output (but not the `filter`) are simply transposed as follows: @@ -1042,9 +1041,7 @@ def pool( @tf_export("nn.atrous_conv2d") def atrous_conv2d(value, filters, rate, padding, name=None): - """Atrous convolution (a.k.a. - - convolution with holes or dilated convolution). + """Atrous convolution (a.k.a. convolution with holes or dilated convolution). This function is a simpler wrapper around the more general @{tf.nn.convolution}, and exists only for backwards compatibility. You can @@ -1481,7 +1478,6 @@ def conv3d_transpose( name=name) -# pylint: disable=protected-access @tf_export("nn.bias_add") def bias_add(value, bias, data_format=None, name=None): """Adds `bias` to `value`. @@ -1504,12 +1500,12 @@ def bias_add(value, bias, data_format=None, name=None): A `Tensor` with the same type as `value`. """ with ops.name_scope(name, "BiasAdd", [value, bias]) as name: - value = ops.convert_to_tensor(value, name="input") - bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias") - return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name) + if not context.executing_eagerly(): + value = ops.convert_to_tensor(value, name="input") + bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias") + return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name) -# pylint: disable=protected-access def bias_add_v1(value, bias, name=None): """Adds `bias` to `value`. @@ -1534,7 +1530,7 @@ def bias_add_v1(value, bias, name=None): with ops.name_scope(name, "BiasAddV1", [value, bias]) as name: value = ops.convert_to_tensor(value, name="input") bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias") - return gen_nn_ops._bias_add_v1(value, bias, name=name) + return gen_nn_ops.bias_add_v1(value, bias, name=name) @tf_export("nn.crelu") @@ -1580,7 +1576,7 @@ def relu6(features, name=None): """ with ops.name_scope(name, "Relu6", [features]) as name: features = ops.convert_to_tensor(features, name="features") - return gen_nn_ops._relu6(features, name=name) + return gen_nn_ops.relu6(features, name=name) @tf_export("nn.leaky_relu") @@ -1616,7 +1612,7 @@ def _flatten_outer_dims(logits): output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0)) # Set output shape if known. - if context.in_graph_mode(): + if not context.executing_eagerly(): shape = logits.get_shape() if shape is not None and shape.dims is not None: shape = shape.as_list() @@ -1645,7 +1641,7 @@ def _softmax(logits, compute_op, dim=-1, name=None): Args: logits: A non-empty `Tensor`. Must be one of the following types: `half`, `float32`, `float64`. - compute_op: Either gen_nn_ops._softmax or gen_nn_ops._log_softmax + compute_op: Either gen_nn_ops.softmax or gen_nn_ops.log_softmax dim: The dimension softmax would be performed on. The default is -1 which indicates the last dimension. name: A name for the operation (optional). @@ -1739,7 +1735,7 @@ def softmax(logits, axis=None, name=None, dim=None): axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) if axis is None: axis = -1 - return _softmax(logits, gen_nn_ops._softmax, axis, name) + return _softmax(logits, gen_nn_ops.softmax, axis, name) @tf_export("nn.log_softmax") @@ -1769,7 +1765,7 @@ def log_softmax(logits, axis=None, name=None, dim=None): axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) if axis is None: axis = -1 - return _softmax(logits, gen_nn_ops._log_softmax, axis, name) + return _softmax(logits, gen_nn_ops.log_softmax, axis, name) def _ensure_xent_args(name, sentinel, labels, logits): @@ -1871,7 +1867,7 @@ def softmax_cross_entropy_with_logits_v2( # Do the actual op computation. # The second output tensor contains the gradients. We use it in # _CrossEntropyGrad() in nn_grad but not here. - cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits( + cost, unused_backprop = gen_nn_ops.softmax_cross_entropy_with_logits( precise_logits, labels, name=name) # The output cost shape should be the input minus dim. @@ -1881,7 +1877,8 @@ def softmax_cross_entropy_with_logits_v2( # Make shape inference work since reshape and transpose may erase its static # shape. - if context.in_graph_mode() and shape is not None and shape.dims is not None: + if not context.executing_eagerly( + ) and shape is not None and shape.dims is not None: shape = shape.as_list() del shape[dim] cost.set_shape(shape) @@ -2027,6 +2024,9 @@ def sparse_softmax_cross_entropy_with_logits( # Store label shape for result later. labels_static_shape = labels.get_shape() labels_shape = array_ops.shape(labels) + static_shapes_fully_defined = ( + labels_static_shape.is_fully_defined() and + logits.get_shape()[:-1].is_fully_defined()) if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0: raise ValueError( "Logits cannot be scalars - received shape %s." % logits.get_shape()) @@ -2036,29 +2036,44 @@ def sparse_softmax_cross_entropy_with_logits( raise ValueError("Rank mismatch: Rank of labels (received %s) should " "equal rank of logits minus 1 (received %s)." % (labels_static_shape.ndims, logits.get_shape().ndims)) + if (static_shapes_fully_defined and + labels_static_shape != logits.get_shape()[:-1]): + raise ValueError("Shape mismatch: The shape of labels (received %s) " + "should equal the shape of logits except for the last " + "dimension (received %s)." % (labels_static_shape, + logits.get_shape())) # Check if no reshapes are required. if logits.get_shape().ndims == 2: - cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits( + cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( precise_logits, labels, name=name) if logits.dtype == dtypes.float16: return math_ops.cast(cost, dtypes.float16) else: return cost - # Reshape logits to 2 dim, labels to 1 dim. - num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1] - precise_logits = array_ops.reshape(precise_logits, [-1, num_classes]) - labels = array_ops.reshape(labels, [-1]) - # The second output tensor contains the gradients. We use it in - # _CrossEntropyGrad() in nn_grad but not here. - cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits( - precise_logits, labels, name=name) - cost = array_ops.reshape(cost, labels_shape) - cost.set_shape(labels_static_shape) - if logits.dtype == dtypes.float16: - return math_ops.cast(cost, dtypes.float16) - else: - return cost + # Perform a check of the dynamic shapes if the static shapes are not fully + # defined. + shape_checks = [] + if not static_shapes_fully_defined: + shape_checks.append( + check_ops.assert_equal( + array_ops.shape(labels), + array_ops.shape(logits)[:-1])) + with ops.control_dependencies(shape_checks): + # Reshape logits to 2 dim, labels to 1 dim. + num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1] + precise_logits = array_ops.reshape(precise_logits, [-1, num_classes]) + labels = array_ops.reshape(labels, [-1]) + # The second output tensor contains the gradients. We use it in + # _CrossEntropyGrad() in nn_grad but not here. + cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( + precise_logits, labels, name=name) + cost = array_ops.reshape(cost, labels_shape) + cost.set_shape(labels_static_shape) + if logits.dtype == dtypes.float16: + return math_ops.cast(cost, dtypes.float16) + else: + return cost @tf_export("nn.avg_pool") @@ -2086,7 +2101,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """ with ops.name_scope(name, "AvgPool", [value]) as name: value = ops.convert_to_tensor(value, name="input") - return gen_nn_ops._avg_pool( + return gen_nn_ops.avg_pool( value, ksize=ksize, strides=strides, @@ -2116,12 +2131,13 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """ with ops.name_scope(name, "MaxPool", [value]) as name: value = ops.convert_to_tensor(value, name="input") - return gen_nn_ops._max_pool(value, - ksize=ksize, - strides=strides, - padding=padding, - data_format=data_format, - name=name) + return gen_nn_ops.max_pool( + value, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format, + name=name) @ops.RegisterStatistics("Conv2D", "flops") @@ -2299,7 +2315,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor - if context.in_graph_mode(): + if not context.executing_eagerly(): ret.set_shape(x.get_shape()) return ret @@ -2331,7 +2347,7 @@ def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-buil values: The `k` largest elements along each last dimensional slice. indices: The indices of `values` within the last dimension of `input`. """ - return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name) + return gen_nn_ops.top_kv2(input, k=k, sorted=sorted, name=name) def nth_element(input, n, reverse=False, name=None): # pylint: disable=redefined-builtin @@ -2650,4 +2666,4 @@ def in_top_k(predictions, targets, k, name=None): A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`. """ with ops.name_scope(name, "in_top_k"): - return gen_nn_ops._in_top_kv2(predictions, targets, k, name=name) + return gen_nn_ops.in_top_kv2(predictions, targets, k, name=name) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 21eea3db25af0d1bcfbc7496665f5535c3f660ea..af9dae2aa64f0994f403ac81dcba800699d3c960 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -1049,6 +1049,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase): y_val = sess.run(y) self.assertAllEqual(y_val, [7, 9, 3, 4]) + def testNHWCToHWNC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, [9, 7, 4, 3]) + def testNHWCToNCHW2D(self): x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] x = constant_op.constant(x_val) @@ -1057,6 +1073,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase): y_val = sess.run(y) self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) + def testNHWCToHWNC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) + def testNCHWToNHWC2D(self): x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] x = constant_op.constant(x_val) diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index b4ce1cbf25346412e2781a520b7e2cdcf720bcd5..d348e47f57b703138aabfc3463e750b795113335 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -74,7 +74,7 @@ def add_check_numerics_ops(): the checked operations. @enc_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "add_check_numerics_ops() is not compatible with eager execution. " "To check for Inf's and NaN's under eager execution, call " diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index b0315ceee268be8ac1813dae5a262a7d9496e154..075b38d743d13329e646c0b268e938b5c5704e47 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -700,8 +700,7 @@ def _parse_example_raw(serialized, # Finally, convert dense_shapes to TensorShapeProto dense_shapes = [shape.as_proto() for shape in dense_shapes] - # pylint: disable=protected-access - outputs = gen_parsing_ops._parse_example( + outputs = gen_parsing_ops.parse_example( serialized=serialized, names=names, dense_defaults=dense_defaults_vec, @@ -710,7 +709,6 @@ def _parse_example_raw(serialized, dense_keys=dense_keys, dense_shapes=dense_shapes, name=name) - # pylint: enable=protected-access (sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs @@ -1132,8 +1130,7 @@ def _parse_single_sequence_example_raw(serialized, feature_list_dense_shapes = [tensor_shape.as_shape(shape).as_proto() for shape in feature_list_dense_shapes] - # pylint: disable=protected-access - outputs = gen_parsing_ops._parse_single_sequence_example( + outputs = gen_parsing_ops.parse_single_sequence_example( serialized=serialized, debug_name=debug_name, context_dense_defaults=context_dense_defaults_vec, @@ -1149,7 +1146,6 @@ def _parse_single_sequence_example_raw(serialized, feature_list_dense_missing_assumed_empty=( feature_list_dense_missing_assumed_empty), name=name) - # pylint: enable=protected-access (context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, @@ -1182,7 +1178,6 @@ def _parse_single_sequence_example_raw(serialized, @tf_export("decode_csv") def decode_csv(records, record_defaults, field_delim=",", use_quote_delim=True, name=None, na_value=""): - # pylint: disable=protected-access """Convert CSV records to tensors. Each column maps to one tensor. RFC 4180 format is expected for the CSV records. @@ -1211,11 +1206,13 @@ def decode_csv(records, record_defaults, field_delim=",", Each tensor will have the same shape as records. """ # TODO(martinwicke), remove the wrapper when new Python API generator is done. - return gen_parsing_ops._decode_csv( - records=records, record_defaults=record_defaults, - field_delim=field_delim, use_quote_delim=use_quote_delim, - na_value=na_value, name=name) - # pylint: enable=protected-access + return gen_parsing_ops.decode_csv( + records=records, + record_defaults=record_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + name=name) # TODO(b/70890287): Combine the implementation of this op and @@ -1391,7 +1388,6 @@ def _parse_single_example_v2_raw(serialized, sparse_keys, sparse_types, # Finally, convert dense_shapes to TensorShapeProto dense_shapes = [shape.as_proto() for shape in dense_shapes] - # pylint: disable=protected-access outputs = gen_parsing_ops.parse_single_example( serialized=serialized, dense_defaults=dense_defaults_vec, @@ -1401,7 +1397,6 @@ def _parse_single_example_v2_raw(serialized, sparse_keys, sparse_types, dense_keys=dense_keys, dense_shapes=dense_shapes, name=name) - # pylint: enable=protected-access (sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 2c86358d21b1c280b8d7ade625fd4b7a44c5de26..6a2dd3f1cd55eea1d3b652a31cd2784c411c2ce0 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -43,7 +43,6 @@ def _ShapeTensor(shape): return ops.convert_to_tensor(shape, dtype=dtype, name="shape") -# pylint: disable=protected-access @tf_export("random_normal") def random_normal(shape, mean=0.0, @@ -74,7 +73,7 @@ def random_normal(shape, mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean") stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") seed1, seed2 = random_seed.get_seed(seed) - rnd = gen_random_ops._random_standard_normal( + rnd = gen_random_ops.random_standard_normal( shape_tensor, dtype, seed=seed1, seed2=seed2) mul = rnd * stddev_tensor value = math_ops.add(mul, mean_tensor, name=name) @@ -126,7 +125,7 @@ def parameterized_truncated_normal(shape, minvals_tensor = ops.convert_to_tensor(minvals, dtype=dtype, name="minvals") maxvals_tensor = ops.convert_to_tensor(maxvals, dtype=dtype, name="maxvals") seed1, seed2 = random_seed.get_seed(seed) - rnd = gen_random_ops._parameterized_truncated_normal( + rnd = gen_random_ops.parameterized_truncated_normal( shape_tensor, means_tensor, stddevs_tensor, @@ -171,7 +170,7 @@ def truncated_normal(shape, mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean") stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") seed1, seed2 = random_seed.get_seed(seed) - rnd = gen_random_ops._truncated_normal( + rnd = gen_random_ops.truncated_normal( shape_tensor, dtype, seed=seed1, seed2=seed2) mul = rnd * stddev_tensor value = math_ops.add(mul, mean_tensor, name=name) @@ -210,7 +209,7 @@ def random_uniform(shape, maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the range of random values to generate. Defaults to 1 if `dtype` is floating point. - dtype: The type of the output: 'float16`, `float32`, `float64`, `int32`, + dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or `int64`. seed: A Python integer. Used to create a random seed for the distribution. See @{tf.set_random_seed} @@ -237,11 +236,10 @@ def random_uniform(shape, maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") seed1, seed2 = random_seed.get_seed(seed) if dtype.is_integer: - return gen_random_ops._random_uniform_int( + return gen_random_ops.random_uniform_int( shape, minval, maxval, seed=seed1, seed2=seed2, name=name) else: - rnd = gen_random_ops._random_uniform( - shape, dtype, seed=seed1, seed2=seed2) + rnd = gen_random_ops.random_uniform(shape, dtype, seed=seed1, seed2=seed2) return math_ops.add(rnd * (maxval - minval), minval, name=name) @@ -275,7 +273,7 @@ def random_shuffle(value, seed=None, name=None): dimension. """ seed1, seed2 = random_seed.get_seed(seed) - return gen_random_ops._random_shuffle( + return gen_random_ops.random_shuffle( value, seed=seed1, seed2=seed2, name=name) @@ -420,7 +418,7 @@ def random_gamma(shape, seed1, seed2 = random_seed.get_seed(seed) return math_ops.maximum( np.finfo(dtype.as_numpy_dtype).tiny, - gen_random_ops._random_gamma( + gen_random_ops.random_gamma( shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta) ops.NotDifferentiable("RandomGamma") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 2d6d0672e03d9435175b0accd7c20dfddae16bcc..df873da98e7fac7accc99a229ffb53a60a74c9bb 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import dtypes @@ -30,6 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -44,10 +46,6 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): container = ops.get_default_graph()._container # pylint: disable=protected-access if container is None: container = "" - if not graph_mode: - # When in eager mode use a uid for the shared_name, to prevent accidental - # sharing. - shared_name = str(ops.uid()) handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, shared_name=shared_name, name=name, @@ -133,10 +131,10 @@ class EagerResourceDeleter(object): # valid, and so on. Printing warnings in these cases is silly # (exceptions raised from __del__ are printed as warnings to stderr). pass # 'NoneType' object is not callable when the handle has been - # partially unloaded. + # partially unloaded. except AttributeError: pass # 'NoneType' object has no attribute 'eager_mode' when context has - # been unloaded. Will catch other module unloads as well. + # been unloaded. Will catch other module unloads as well. def shape_safe_assign_variable_handle(handle, shape, value, name=None): @@ -151,7 +149,7 @@ def shape_safe_assign_variable_handle(handle, shape, value, name=None): class ResourceVariable(variables.Variable): """Variable based on resource handles. - See the ${variables} documentation for more details. + See the @{$variables$Variables How To} for a high level overview. A `ResourceVariable` allows you to maintain state across subsequent calls to session.run. @@ -181,24 +179,20 @@ class ResourceVariable(variables.Variable): by edges in the graph. Consider the following example, in which two writes can cause tf.Variable and tf.ResourceVariable to behave differently: - ```python - a = tf.ResourceVariable(1.0) - a.initializer.run() - - assign = a.assign(2.0) - with tf.control_dependencies([assign]): - b = a.read_value() - with tf.control_dependencies([b]): - other_assign = a.assign(3.0) - with tf.control_dependencies([other_assign]): - # Will print 2.0 because the value was read before other_assign ran. If - # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. - tf.Print(b, [b]).eval() + ```python + a = tf.ResourceVariable(1.0) + a.initializer.run() + + assign = a.assign(2.0) + with tf.control_dependencies([assign]): + b = a.read_value() + with tf.control_dependencies([b]): + other_assign = a.assign(3.0) + with tf.control_dependencies([other_assign]): + # Will print 2.0 because the value was read before other_assign ran. If + # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. + tf.Print(b, [b]).eval() ``` - - To enforce these consistency properties tf.ResourceVariable might make more - copies than an equivalent tf.Variable under the hood, so tf.Variable is still - not deprecated. """ def __init__(self, @@ -265,9 +259,9 @@ class ResourceVariable(variables.Variable): if initial_value is not None: raise ValueError("variable_def and initial_value are mutually " "exclusive.") - if not context.in_graph_mode(): - raise ValueError("Creating ResourceVariable from variable_def" - " only supported in GRAPH mode.") + if context.executing_eagerly(): + raise ValueError("Creating ResourceVariable from variable_def is " + "not supported when eager execution is enabled.") self._init_from_proto(variable_def, import_scope=import_scope) else: self._init_from_args( @@ -361,11 +355,17 @@ class ResourceVariable(variables.Variable): # this graph. self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with ops.init_scope(): - self._in_graph_mode = context.in_graph_mode() + self._in_graph_mode = not context.executing_eagerly() with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: # pylint: disable=protected-access handle_name = ops._name_from_scope_name(name) + if self._in_graph_mode: + shared_name = handle_name + else: + # When in eager mode use a uid for the shared_name, to prevent + # accidental sharing. + shared_name = "%s_%d" % (handle_name, ops.uid()) if init_from_fn: # Use attr_scope and device(None) to simulate the behavior of # colocate_with when the variable we want to colocate with doesn't @@ -381,12 +381,9 @@ class ResourceVariable(variables.Variable): self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, - shared_name=handle_name, + shared_name=shared_name, name=name, graph_mode=self._in_graph_mode) - self._handle_device = ( - self._handle.device if self._in_graph_mode else - context.get_default_context().device_name) self._shape = initial_value.get_shape() else: initial_value = initial_value() @@ -396,12 +393,9 @@ class ResourceVariable(variables.Variable): self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, - shared_name=handle_name, + shared_name=shared_name, name=name, graph_mode=False) - self._handle_device = ( - self._handle.device if self._in_graph_mode else - context.get_default_context().device_name) self._shape = initial_value.get_shape() # pylint: enable=protected-access @@ -422,13 +416,12 @@ class ResourceVariable(variables.Variable): self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, - shared_name=handle_name, + shared_name=shared_name, name=name, graph_mode=self._in_graph_mode) - self._handle_device = (self._handle.device if self._in_graph_mode else - context.get_default_context().device_name) self._shape = initial_value.get_shape() + self._unique_id = shared_name self._initial_value = initial_value if self._in_graph_mode else None self._handle_name = handle_name + ":0" self._dtype = initial_value.dtype.base_dtype @@ -449,7 +442,7 @@ class ResourceVariable(variables.Variable): with ops.name_scope("Read"), ops.colocate_with(self._handle): # Manually assign reads to the handle's device to avoid log # messages. - with ops.device(self._handle_device): + with ops.device(self._handle.device): value = self._read_variable_op() self._graph_element = value if caching_device is not None: @@ -476,7 +469,7 @@ class ResourceVariable(variables.Variable): self._cached_value = self._read_variable_op() else: self._cached_value = None - if context.in_graph_mode(): + if not context.executing_eagerly(): ops.add_to_collections(collections, self) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) @@ -489,12 +482,13 @@ class ResourceVariable(variables.Variable): # cycles being uncollectable, and means that no __del__ will be defined at # all in graph mode. self._handle_deleter = EagerResourceDeleter( - handle=self._handle, handle_device=self._handle_device) + handle=self._handle, handle_device=self._handle.device) + self._cached_shape_as_list = None def _init_from_proto(self, variable_def, import_scope=None): """Initializes from `VariableDef` proto.""" # Note that init_from_proto is currently not supported in Eager mode. - assert context.in_graph_mode() + assert not context.executing_eagerly() self._in_graph_mode = True assert isinstance(variable_def, variable_pb2.VariableDef) if not variable_def.is_resource: @@ -507,8 +501,8 @@ class ResourceVariable(variables.Variable): variable_def.variable_name, import_scope=import_scope)) self._shape = tensor_shape.TensorShape( self._handle.op.get_attr("shape")) - self._handle_device = self._handle.device self._handle_name = self._handle.name + self._unique_id = self._handle_name self._initializer_op = g.as_graph_element( ops.prepend_name_scope( variable_def.initializer_name, import_scope=import_scope)) @@ -534,8 +528,10 @@ class ResourceVariable(variables.Variable): self._save_slice_info = None self._caching_device = None self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) - self._graph_element = self.value() + self._graph_element = g.get_tensor_by_name( + self._handle.op.name + "/Read/ReadVariableOp:0") self._constraint = None + self._cached_shape_as_list = None def __nonzero__(self): return self.__bool__() @@ -551,7 +547,7 @@ class ResourceVariable(variables.Variable): @property def device(self): """The device this variable is on.""" - return self._handle_device + return self._handle.device @property def graph(self): @@ -568,11 +564,26 @@ class ResourceVariable(variables.Variable): """The shape of this variable.""" return self._shape + def _shape_as_list(self): + if self._cached_shape_as_list: + return self._cached_shape_as_list + if self.shape.ndims is None: + return None + self._cached_shape_as_list = [dim.value for dim in self.shape.dims] + return self._cached_shape_as_list + + def _shape_tuple(self): + shape = self._shape_as_list() + if shape is None: + return None + return tuple(shape) + @property def create(self): """The op responsible for initializing this variable.""" if not self._in_graph_mode: - raise RuntimeError("Calling create in EAGER mode not supported.") + raise RuntimeError("Calling create is not supported when eager execution" + " is enabled.") return self._initializer_op @property @@ -585,7 +596,7 @@ class ResourceVariable(variables.Variable): if self._cached_value is not None: return self._cached_value with ops.colocate_with(None, ignore_existing=True): - with ops.device(self._handle_device): + with ops.device(self._handle.device): return self._read_variable_op() def _as_graph_element(self): @@ -600,7 +611,7 @@ class ResourceVariable(variables.Variable): @property def initial_value(self): """Returns the Tensor used as the initial value for the variable.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("initial_value not supported in EAGER mode.") return self._initial_value @@ -621,15 +632,15 @@ class ResourceVariable(variables.Variable): def eval(self, session=None): """Evaluates and returns the value of this variable.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Trying to eval in EAGER mode") return self._graph_element.eval(session=session) def numpy(self): - if context.in_graph_mode(): - raise NotImplementedError( - "numpy() is only available when eager execution is enabled.") - return self.read_value().numpy() + if context.executing_eagerly(): + return self.read_value().numpy() + raise NotImplementedError( + "numpy() is only available when eager execution is enabled.") def count_up_to(self, limit): """Increments this variable until it reaches `limit`. @@ -682,7 +693,7 @@ class ResourceVariable(variables.Variable): """ with ops.name_scope("Read"): # Ensure we read the variable in the same device as the handle. - with ops.device(self._handle_device): + with ops.device(self._handle.device): value = self._read_variable_op() # Return an identity so it can get placed on whatever device the context # specifies instead of the device where the variable is. @@ -710,7 +721,7 @@ class ResourceVariable(variables.Variable): A `VariableDef` protocol buffer, or `None` if the `Variable` is not in the specified name scope. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("to_proto not supported in EAGER mode.") if export_scope is None or self.handle.name.startswith(export_scope): var_def = variable_pb2.VariableDef() @@ -737,7 +748,7 @@ class ResourceVariable(variables.Variable): @staticmethod def from_proto(variable_def, import_scope=None): - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("from_proto not supported in EAGER mode.") return ResourceVariable( variable_def=variable_def, import_scope=import_scope) @@ -788,37 +799,84 @@ class ResourceVariable(variables.Variable): __array_priority__ = 100 - def assign_sub(self, delta, use_locking=None, name=None): + def assign_sub(self, delta, use_locking=None, name=None, read_value=True): + """Subtracts a value from this variable. + + Args: + delta: A `Tensor`. The value to subtract from this variable. + use_locking: If `True`, use locking during the operation. + name: The name to use for the operation. + read_value: A `bool`. Whether to read and return the new value of the + variable or not. + + Returns: + If `read_value` is `True`, this method will return the new value of the + variable after the assignment has completed. Otherwise, when in graph mode + it will return the `Operation` that does the assignment, and when in eager + mode it will return `None`. + """ # TODO(apassos): this here and below is not atomic. Consider making it # atomic if there's a way to do so without a performance cost for those who # don't need it. - return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op( - self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), - name=name)) + assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( + self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name) + if read_value: + return self._lazy_read(assign_sub_op) + return assign_sub_op + + def assign_add(self, delta, use_locking=None, name=None, read_value=True): + """Adds a value to this variable. + + Args: + delta: A `Tensor`. The value to add to this variable. + use_locking: If `True`, use locking during the operation. + name: The name to use for the operation. + read_value: A `bool`. Whether to read and return the new value of the + variable or not. - def assign_add(self, delta, use_locking=None, name=None): - return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op( - self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), - name=name)) + Returns: + If `read_value` is `True`, this method will return the new value of the + variable after the assignment has completed. Otherwise, when in graph mode + it will return the `Operation` that does the assignment, and when in eager + mode it will return `None`. + """ + assign_add_op = gen_resource_variable_ops.assign_add_variable_op( + self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name) + if read_value: + return self._lazy_read(assign_add_op) + return assign_add_op def _lazy_read(self, op): if hasattr(self, "_trainable") and self._trainable: tape.watch_variable(self) return _UnreadVariable( - self._handle, self.dtype, self._handle_device, self._shape, - self._in_graph_mode, - self._handle_deleter if not self._in_graph_mode else None, op) + self._handle, self.dtype, self._shape, self._in_graph_mode, + self._handle_deleter if not self._in_graph_mode else None, op, + self._unique_id) + + def assign(self, value, use_locking=None, name=None, read_value=True): + """Assigns a new value to this variable. + + Args: + value: A `Tensor`. The new value for this variable. + use_locking: If `True`, use locking during the assignment. + name: The name to use for the assignment. + read_value: A `bool`. Whether to read and return the new value of the + variable or not. - def assign(self, value, use_locking=None, name=None): + Returns: + If `read_value` is `True`, this method will return the new value of the + variable after the assignment has completed. Otherwise, when in graph mode + it will return the `Operation` that does the assignment, and when in eager + mode it will return `None`. + """ value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) self._shape.assert_is_compatible_with(value_tensor.shape) - return self._lazy_read( - gen_resource_variable_ops.assign_variable_op( - self.handle, - value_tensor, - name=name)) + assign_op = gen_resource_variable_ops.assign_variable_op( + self.handle, value_tensor, name=name) + if read_value: + return self._lazy_read(assign_op) + return assign_op def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, end_mask, ellipsis_mask, new_axis_mask, @@ -894,6 +952,10 @@ class ResourceVariable(variables.Variable): "Tensor object.") +pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable) +math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access + + def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access @@ -904,31 +966,31 @@ class _UnreadVariable(ResourceVariable): Pretends to be the tensor if anyone looks. """ - def __init__(self, handle, dtype, handle_device, # pylint: disable=super-init-not-called - shape, in_graph_mode, deleter, parent_op): + def __init__(self, handle, dtype, # pylint: disable=super-init-not-called + shape, in_graph_mode, deleter, parent_op, unique_id): # We do not call super init on purpose. self._trainable = False self._save_slice_info = None self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access self._in_graph_mode = in_graph_mode self._handle = handle - self._handle_device = handle_device self._shape = shape self._initial_value = None if isinstance(self._handle, ops.EagerTensor): self._handle_name = "" else: self._handle_name = self._handle.name + self._unique_id = unique_id self._dtype = dtype self._constraint = None self._cached_value = None self._is_initialized_op = None self._initializer_op = None self._parent_op = parent_op - if context.in_graph_mode(): - self._graph_element = self.read_value() - else: + if context.executing_eagerly(): self._graph_element = None + else: + self._graph_element = self.read_value() self._handle_deleter = deleter def value(self): @@ -944,6 +1006,7 @@ class _UnreadVariable(ResourceVariable): def set_shape(self, shape): self._shape = shape + self._cached_shape_as_list = None @property def op(self): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index aa8d4327d2f0e93768728744d5cce3fed385393f..1dd464d51d9d1b17bf9e2741668117bf014d9453 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -45,29 +45,25 @@ from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access _concat = rnn_cell_impl._concat -_like_rnncell = rnn_cell_impl._like_rnncell # pylint: enable=protected-access def _transpose_batch_time(x): - """Transpose the batch and time dimensions of a Tensor. + """Transposes the batch and time dimensions of a Tensor. - Retains as much of the static shape information as possible. + If the input tensor has rank < 2 it returns the original tensor. Retains as + much of the static shape information as possible. Args: - x: A tensor of rank 2 or higher. + x: A Tensor. Returns: x transposed along the first two dimensions. - - Raises: - ValueError: if `x` is rank 1 or lower. """ x_static_shape = x.get_shape() if x_static_shape.ndims is not None and x_static_shape.ndims < 2: - raise ValueError( - "Expected input tensor %s to have rank at least 2, but saw shape: %s" % - (x, x_static_shape)) + return x + x_rank = array_ops.rank(x) x_t = array_ops.transpose( x, array_ops.concat( @@ -403,11 +399,8 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, Raises: TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. """ - - if not _like_rnncell(cell_fw): - raise TypeError("cell_fw must be an instance of RNNCell") - if not _like_rnncell(cell_bw): - raise TypeError("cell_bw must be an instance of RNNCell") + rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) + rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) with vs.variable_scope(scope or "bidirectional_rnn"): # Forward direction @@ -568,14 +561,13 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, TypeError: If `cell` is not an instance of RNNCell. ValueError: If inputs is None or an empty list. """ - if not _like_rnncell(cell): - raise TypeError("cell must be an instance of RNNCell") + rnn_cell_impl.assert_like_rnncell("cell", cell) with vs.variable_scope(scope or "rnn") as varscope: # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - if context.in_graph_mode(): + if not context.executing_eagerly(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -616,7 +608,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, ["Expected shape for Tensor %s is " % x.name, packed_shape, " but saw shape: ", x_shape]) - if context.in_graph_mode() and sequence_length is not None: + if not context.executing_eagerly() and sequence_length is not None: # Perform some shape validation with ops.control_dependencies( [_assert_has_shape(sequence_length, [batch_size])]): @@ -742,7 +734,7 @@ def _dynamic_rnn_loop(cell, element_shape=element_shape, tensor_array_name=base_name + name) - in_graph_mode = context.in_graph_mode() + in_graph_mode = not context.executing_eagerly() if in_graph_mode: output_ta = tuple( _create_ta( @@ -872,7 +864,7 @@ def raw_rnn(cell, loop_fn, ```python time = tf.constant(0, dtype=tf.int32) - (finished, next_input, initial_state, _, loop_state) = loop_fn( + (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn( time=time, cell_output=None, cell_state=None, loop_state=None) emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) state = initial_state @@ -883,7 +875,7 @@ def raw_rnn(cell, loop_fn, loop_state=loop_state) # Emit zeros and copy forward state for minibatch entries that are finished. state = tf.where(finished, state, next_state) - emit = tf.where(finished, tf.zeros_like(emit), emit) + emit = tf.where(finished, tf.zeros_like(emit_structure), emit) emit_ta = emit_ta.write(time, emit) # If any new minibatch entries are marked as finished, mark these. finished = tf.logical_or(finished, next_finished) @@ -943,10 +935,15 @@ def raw_rnn(cell, loop_fn, and `emit_output`: the output to store for this iteration. Note that `emit_output` should be a `Tensor` or (possibly nested) - tuple of tensors with shapes and structure matching `cell.output_size` - and `cell_output` above. The parameter `cell_state` and output - `next_cell_state` may be either a single or (possibly nested) tuple - of tensors. The parameter `loop_state` and + tuple of tensors which is aggregated in the `emit_ta` inside the + `while_loop`. For the first call to `loop_fn`, the `emit_output` + corresponds to the `emit_structure` which is then used to determine the + size of the `zero_tensor` for the `emit_ta` (defaults to + `cell.output_size`). For the subsequent calls to the `loop_fn`, the + `emit_output` corresponds to the actual output tensor + that is to be aggregated in the `emit_ta`. The parameter `cell_state` + and output `next_cell_state` may be either a single or (possibly nested) + tuple of tensors. The parameter `loop_state` and output `next_loop_state` may be either a single or (possibly nested) tuple of `Tensor` and `TensorArray` objects. This last parameter may be ignored by `loop_fn` and the return value may be `None`. If it @@ -1015,9 +1012,8 @@ def raw_rnn(cell, loop_fn, TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not a `callable`. """ + rnn_cell_impl.assert_like_rnncell("cell", cell) - if not _like_rnncell(cell): - raise TypeError("cell must be an instance of RNNCell") if not callable(loop_fn): raise TypeError("loop_fn must be a callable") @@ -1027,7 +1023,7 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if context.in_graph_mode(): + if not context.executing_eagerly(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1229,9 +1225,7 @@ def static_rnn(cell, ValueError: If `inputs` is `None` or an empty list, or if the input depth (column size) cannot be inferred from inputs via shape inference. """ - - if not _like_rnncell(cell): - raise TypeError("cell must be an instance of RNNCell") + rnn_cell_impl.assert_like_rnncell("cell", cell) if not nest.is_sequence(inputs): raise TypeError("inputs must be a sequence") if not inputs: @@ -1242,7 +1236,7 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if context.in_graph_mode(): + if not context.executing_eagerly(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1469,11 +1463,8 @@ def static_bidirectional_rnn(cell_fw, TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. ValueError: If inputs is None or an empty list. """ - - if not _like_rnncell(cell_fw): - raise TypeError("cell_fw must be an instance of RNNCell") - if not _like_rnncell(cell_bw): - raise TypeError("cell_bw must be an instance of RNNCell") + rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) + rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) if not nest.is_sequence(inputs): raise TypeError("inputs must be a sequence") if not inputs: diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 923348ea44e18a87e09fe1c0424f0323eb967e3d..fe380c44dafdad6dc25d50102bacba610132674d 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -54,6 +55,8 @@ _BIAS_VARIABLE_NAME = "bias" _WEIGHTS_VARIABLE_NAME = "kernel" +# TODO(jblespiau): Remove this function when we are sure there are no longer +# any usage (even if protected, it is being used). Prefer assert_like_rnncell. def _like_rnncell(cell): """Checks that a given object is an RNNCell by using duck typing.""" conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"), @@ -61,6 +64,45 @@ def _like_rnncell(cell): return all(conditions) +# This can be used with self.assertRaisesRegexp for assert_like_rnncell. +ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" + + +def assert_like_rnncell(cell_name, cell): + """Raises a TypeError if cell is not like an RNNCell. + + NOTE: Do not rely on the error message (in particular in tests) which can be + subject to change to increase readability. Use + ASSERT_LIKE_RNNCELL_ERROR_REGEXP. + + Args: + cell_name: A string to give a meaningful error referencing to the name + of the functionargument. + cell: The object which should behave like an RNNCell. + + Raises: + TypeError: A human-friendly exception. + """ + conditions = [ + hasattr(cell, "output_size"), + hasattr(cell, "state_size"), + hasattr(cell, "zero_state"), + callable(cell), + ] + errors = [ + "'output_size' property is missing", + "'state_size' property is missing", + "'zero_state' method is missing", + "is not callable" + ] + + if not all(conditions): + + errors = [error for error, cond in zip(errors, conditions) if not cond] + raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( + cell_name, cell, ", ".join(errors))) + + def _concat(prefix, suffix, static=False): """Concat that enables int, Tensor, or TensorShape values. @@ -127,7 +169,7 @@ def _zero_state_tensors(state_size, batch_size, dtype): """Combine s with batch_size to get a proper tensor shape.""" c = _concat(batch_size, s) size = array_ops.zeros(c, dtype=dtype) - if context.in_graph_mode(): + if not context.executing_eagerly(): c_static = _concat(batch_size, s, static=True) size.set_shape(c_static) return size @@ -191,12 +233,13 @@ class RNNCell(base_layer.Layer): def _rnn_get_variable(self, getter, *args, **kwargs): variable = getter(*args, **kwargs) - if context.in_graph_mode(): - trainable = (variable in tf_variables.trainable_variables() or - (isinstance(variable, tf_variables.PartitionedVariable) and - list(variable)[0] in tf_variables.trainable_variables())) - else: + if context.executing_eagerly(): trainable = variable._trainable # pylint: disable=protected-access + else: + trainable = ( + variable in tf_variables.trainable_variables() or + (isinstance(variable, tf_variables.PartitionedVariable) and + list(variable)[0] in tf_variables.trainable_variables())) if trainable and variable not in self._trainable_weights: self._trainable_weights.append(variable) elif not trainable and variable not in self._non_trainable_weights: @@ -240,7 +283,7 @@ class RNNCell(base_layer.Layer): # Try to use the last cached zero_state. This is done to avoid recreating # zeros, especially when eager execution is enabled. state_size = self.state_size - is_eager = context.in_eager_mode() + is_eager = context.executing_eagerly() if is_eager and hasattr(self, "_last_zero_state"): (last_state_size, last_batch_size, last_dtype, last_output) = getattr(self, "_last_zero_state") @@ -912,8 +955,8 @@ class DropoutWrapper(RNNCell): but not `callable`. ValueError: if any of the keep_probs are not between 0 and 1. """ - if not _like_rnncell(cell): - raise TypeError("The parameter cell is not a RNNCell.") + assert_like_rnncell("cell", cell) + if (dropout_state_filter_visitor is not None and not callable(dropout_state_filter_visitor)): raise TypeError("dropout_state_filter_visitor must be callable") @@ -1187,6 +1230,12 @@ class MultiRNNCell(RNNCell): "cells must be a list or tuple, but saw: %s." % cells) self._cells = cells + for cell_number, cell in enumerate(self._cells): + # Add Checkpointable dependencies on these cells so their variables get + # saved with this object when using object-based saving. + if isinstance(cell, checkpointable.CheckpointableBase): + # TODO(allenl): Track down non-Checkpointable callers. + self._track_checkpointable(cell, name="cell-%d" % (cell_number,)) self._state_is_tuple = state_is_tuple if not state_is_tuple: if any(nest.is_sequence(c.state_size) for c in self._cells): diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 6fe2f61016775b410045fefcc8764907b8ea39f3..1b4111bca630ffa122ed590b0e3d54b796ab6b7a 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -25,6 +25,9 @@ from __future__ import print_function import threading +# Used by py_util.cc to get tracebacks. +import traceback # pylint: disable=unused-import + import numpy as np import six @@ -33,6 +36,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_script_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -51,6 +55,16 @@ class EagerFunc(object): self._func = func self._out_dtypes = Tout + def _convert(self, value, dtype): + if isinstance(value, resource_variable_ops.ResourceVariable): + raise RuntimeError( + "Attempting to return a variable from an eagerly executed py_func. " + "Only numeric data structures like Tensors or NumPy arrays should " + "be returned; to return the value of a variable, make sure to obtain " + "the Tensor backing it by calling `.read_value()` on the variable in " + "question: %s" % value) + return ops.convert_to_tensor(value, dtype=dtype) + def __call__(self, on_gpu, args): """Passes `args` to `self._func`, which is executed eagerly.""" with context.eager_mode(): @@ -58,14 +72,13 @@ class EagerFunc(object): maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() if isinstance(ret, (tuple, list)): return [ - maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype)) + maybe_copy_to_gpu(self._convert(x, dtype=dtype)) for (x, dtype) in zip(ret, self._out_dtypes) ] elif ret is None: return ret else: - return maybe_copy_to_gpu( - ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])) + return maybe_copy_to_gpu(self._convert(ret, dtype=self._out_dtypes[0])) class FuncRegistry(object): @@ -219,18 +232,16 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): graph._cleanup_py_funcs_used_in_graph.append(cleanup) # pylint: enable=protected-access - # pylint: disable=protected-access if eager: - result = gen_script_ops._eager_py_func( + result = gen_script_ops.eager_py_func( input=inp, token=token, Tout=Tout, name=name) else: if stateful: - result = gen_script_ops._py_func( + result = gen_script_ops.py_func( input=inp, token=token, Tout=Tout, name=name) else: - result = gen_script_ops._py_func_stateless( + result = gen_script_ops.py_func_stateless( input=inp, token=token, Tout=Tout, name=name) - # pylint: enable=protected-access return result if is_list_or_tuple else result[0] @@ -319,7 +330,7 @@ def py_func(func, inp, Tout, stateful=True, name=None): Returns: A list of `Tensor` or a single `Tensor` which `func` computes. """ - if context.in_eager_mode(): + if context.executing_eagerly(): result = func(*[x.numpy() for x in inp]) result = nest.flatten(result) diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index cedd36c1deed541adcf601ff9447345e2279e8f9..ad38845153c94e9bb31e6e3ee05ebed0a4313efc 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -16,7 +16,6 @@ """Tensor Handle Operations. See the @{$python/session_ops} guide. @@get_session_handle -@@get_session_handle_v2 @@get_session_tensor @@delete_session_tensor """ @@ -182,7 +181,7 @@ def get_session_handle(data, name=None): # Colocate this operation with data. with ops.colocate_with(data): - return gen_data_flow_ops._get_session_handle(data, name=name) # pylint: disable=protected-access + return gen_data_flow_ops.get_session_handle(data, name=name) @tf_export("get_session_tensor") @@ -222,7 +221,7 @@ def get_session_tensor(handle, dtype, name=None): with ops.device(handle_device): holder = array_ops.placeholder(dtypes.string) _register_handle_feeder(holder.graph, holder, dtype) - tensor = gen_data_flow_ops._get_session_tensor(holder, dtype, name=name) + tensor = gen_data_flow_ops.get_session_tensor(holder, dtype, name=name) return (holder, tensor) @@ -246,7 +245,7 @@ def delete_session_tensor(handle, name=None): handle_device = TensorHandle._get_device_name(handle) with ops.device(handle_device): holder = array_ops.placeholder(dtypes.string) - deleter = gen_data_flow_ops._delete_session_tensor(holder, name=name) + deleter = gen_data_flow_ops.delete_session_tensor(holder, name=name) return (holder, deleter) @@ -268,7 +267,7 @@ def _get_handle_reader(graph, handle, dtype): with graph.as_default(), graph.device(handle_device): holder = array_ops.placeholder(dtypes.string) _register_handle_feeder(holder.graph, holder, dtype) - reader = gen_data_flow_ops._get_session_tensor(holder, dtype) + reader = gen_data_flow_ops.get_session_tensor(holder, dtype) result = (holder, reader) graph._handle_readers[graph_key] = result return result @@ -289,7 +288,7 @@ def _get_handle_mover(graph, feeder, handle): # Create mover if we haven't done it. holder, reader = _get_handle_reader(graph, handle, dtype) with graph.as_default(), graph.device(feeder.op.device): - mover = gen_data_flow_ops._get_session_handle(reader) # pylint: disable=protected-access + mover = gen_data_flow_ops.get_session_handle(reader) result = (holder, mover) graph._handle_movers[graph_key] = result return result @@ -303,7 +302,7 @@ def _get_handle_deleter(graph, deleter_key, handle): handle_device = TensorHandle._get_device_name(handle) with graph.as_default(), graph.device(handle_device): holder = array_ops.placeholder(dtypes.string) - deleter = gen_data_flow_ops._delete_session_tensor(holder) + deleter = gen_data_flow_ops.delete_session_tensor(holder) result = (holder, deleter) graph._handle_deleters[deleter_key] = result return result diff --git a/tensorflow/python/ops/sets_impl.py b/tensorflow/python/ops/sets_impl.py index b0eecd8a1e812857de8f47e1370e4fc5f1004bc0..21e08d03d213c173d12dfc6676fe7f009811e93f 100644 --- a/tensorflow/python/ops/sets_impl.py +++ b/tensorflow/python/ops/sets_impl.py @@ -247,7 +247,7 @@ def set_difference(a, b, aminusb=True, validate_indices=True): # # collections.OrderedDict([ # ((0, 0, 0), 2), - # ((0, 0, 1), 3), + # ((0, 1, 0), 3), # ]) ``` diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 5295e7d21c2b5810422ec36f5aced63c9039feca..97353d6c747cb7e4d3c1fa92ad61af24fb17de91 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -88,10 +88,8 @@ def _SparseAddGrad(op, *grads): # the non-zero elements of the sum, and we will peek into `sum_indices` in the # gradient op. - # pylint: disable=protected-access - a_val_grad, b_val_grad = gen_sparse_ops._sparse_add_grad(val_grad, a_indices, - b_indices, - sum_indices) + a_val_grad, b_val_grad = gen_sparse_ops.sparse_add_grad( + val_grad, a_indices, b_indices, sum_indices) a_val_grad.set_shape(op.inputs[1].get_shape()) b_val_grad.set_shape(op.inputs[4].get_shape()) # (a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh) @@ -151,7 +149,7 @@ def _SparseTensorDenseMatMulGrad(op, grad): "complex gradients.") # gradient w.r.t. dense - b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( # pylint: disable=protected-access + b_grad = gen_sparse_ops.sparse_tensor_dense_mat_mul( a_indices, a_values, a_shape, grad, adjoint_a=not adj_a) if adj_b: b_grad = array_ops.transpose(b_grad) @@ -278,8 +276,7 @@ def _SparseFillEmptyRowsGrad(op, unused_grad_output_indices, output_grad_values, """Gradients for SparseFillEmptyRows.""" reverse_index_map = op.outputs[3] - # pylint: disable=protected-access - d_values, d_default_value = gen_sparse_ops._sparse_fill_empty_rows_grad( + d_values, d_default_value = gen_sparse_ops.sparse_fill_empty_rows_grad( reverse_index_map=reverse_index_map, grad_values=output_grad_values) # d_indices, d_values, d_dense_shape, d_default_value. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 0fbbf5a805f1439d85ad53f02bdb665c04248606..c580052c32c8b61467b857af3d237be41718c1a1 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -234,7 +234,7 @@ def sparse_concat(axis, ] output_ind, output_val, output_shape = ( - gen_sparse_ops._sparse_concat(inds, vals, shapes, axis, name=name)) + gen_sparse_ops.sparse_concat(inds, vals, shapes, axis, name=name)) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) @@ -302,8 +302,8 @@ def sparse_add(a, b, thresh=0): thresh = ops.convert_to_tensor( thresh, dtype=a.values.dtype.real_dtype.base_dtype, name="thresh") output_ind, output_val, output_shape = ( - gen_sparse_ops._sparse_add(a.indices, a.values, a.dense_shape, - b.indices, b.values, b.dense_shape, thresh)) + gen_sparse_ops.sparse_add(a.indices, a.values, a.dense_shape, + b.indices, b.values, b.dense_shape, thresh)) # Attempt to get output_shape statically. a.get_shape().assert_is_compatible_with(b.get_shape()) @@ -317,8 +317,8 @@ def sparse_add(a, b, thresh=0): # swap to make `a` the SparseTensor. if isinstance(b, sparse_classes): a, b = b, a - return gen_sparse_ops._sparse_tensor_dense_add(a.indices, a.values, - a.dense_shape, b) + return gen_sparse_ops.sparse_tensor_dense_add(a.indices, a.values, + a.dense_shape, b) def _sparse_cross(inputs, name=None): @@ -402,7 +402,7 @@ def _sparse_cross_internal(inputs, num_buckets=0, hash_key=None, name=None): - """See gen_sparse_ops._sparse_cross.""" + """See gen_sparse_ops.sparse_cross.""" if not isinstance(inputs, list): raise TypeError("Inputs must be a list") if not all( @@ -432,7 +432,7 @@ def _sparse_cross_internal(inputs, dense_inputs[i] = math_ops.to_int64(dense_inputs[i]) internal_type = dtypes.int64 - indices_out, values_out, shape_out = gen_sparse_ops._sparse_cross( + indices_out, values_out, shape_out = gen_sparse_ops.sparse_cross( indices=indices, values=values, shapes=shapes, @@ -511,7 +511,7 @@ def sparse_reorder(sp_input, name=None): sp_input = _convert_to_sparse_tensor(sp_input) reordered_ind, reordered_val = ( - gen_sparse_ops._sparse_reorder( + gen_sparse_ops.sparse_reorder( sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)) if sp_input.get_shape().is_fully_defined(): @@ -575,7 +575,7 @@ def sparse_reshape(sp_input, shape, name=None): shape = math_ops.cast(shape, dtype=dtypes.int64) with ops.name_scope(name, "SparseReshape", [sp_input]) as name: - reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape( + reshaped_ind, reshaped_shape = gen_sparse_ops.sparse_reshape( sp_input.indices, sp_input.dense_shape, shape, name=name) reshaped_shape_const = tensor_util.constant_value(shape) @@ -671,7 +671,7 @@ def sparse_split(keyword_required=KeywordRequired(), sp_input = _convert_to_sparse_tensor(sp_input) output_inds, output_vals, output_shapes = ( - gen_sparse_ops._sparse_split( + gen_sparse_ops.sparse_split( axis, sp_input.indices, sp_input.values, @@ -782,7 +782,7 @@ def sparse_to_dense(sparse_indices, Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`. """ - return gen_sparse_ops._sparse_to_dense( + return gen_sparse_ops.sparse_to_dense( sparse_indices, output_shape, sparse_values, @@ -1412,7 +1412,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): default_value = ops.convert_to_tensor( default_value, dtype=sp_input.values.dtype) (output_indices, output_values, empty_row_indicator, - unused_reverse_index_map) = gen_sparse_ops._sparse_fill_empty_rows( + unused_reverse_index_map) = gen_sparse_ops.sparse_fill_empty_rows( indices=sp_input.indices, values=sp_input.values, dense_shape=sp_input.dense_shape, @@ -1441,7 +1441,7 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string): """ sp_input = _convert_to_sparse_tensor(sp_input) - return gen_sparse_ops._serialize_sparse( + return gen_sparse_ops.serialize_sparse( sp_input.indices, sp_input.values, sp_input.dense_shape, @@ -1476,7 +1476,7 @@ def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string): """ sp_input = _convert_to_sparse_tensor(sp_input) - return gen_sparse_ops._serialize_many_sparse( + return gen_sparse_ops.serialize_many_sparse( sp_input.indices, sp_input.values, sp_input.dense_shape, @@ -1541,7 +1541,7 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None): """ output_indices, output_values, output_shape = ( - gen_sparse_ops._deserialize_sparse(serialized_sparse, dtype, name=name)) + gen_sparse_ops.deserialize_sparse(serialized_sparse, dtype, name=name)) # Feed rank data back in, if available output_indices.set_shape([None, rank]) @@ -1610,7 +1610,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): All of the serialized `SparseTensor`s must have had the same rank and type. """ output_indices, output_values, output_shape = ( - gen_sparse_ops._deserialize_many_sparse( + gen_sparse_ops.deserialize_many_sparse( serialized_sparse, dtype, name=name)) # Feed rank data back in, if available @@ -1828,7 +1828,7 @@ def sparse_tensor_dense_matmul(sp_a, with ops.name_scope(name, "SparseTensorDenseMatMul", [sp_a.indices, sp_a.values, b]) as name: b = ops.convert_to_tensor(b, name="b") - return gen_sparse_ops._sparse_tensor_dense_mat_mul( + return gen_sparse_ops.sparse_tensor_dense_mat_mul( a_indices=sp_a.indices, a_values=sp_a.values, a_shape=sp_a.dense_shape, @@ -2046,7 +2046,7 @@ def _add_sparse_to_tensors_map(sp_input, """ sp_input = _convert_to_sparse_tensor(sp_input) - return gen_sparse_ops._add_sparse_to_tensors_map( + return gen_sparse_ops.add_sparse_to_tensors_map( sp_input.indices, sp_input.values, sp_input.dense_shape, @@ -2086,7 +2086,7 @@ def _add_many_sparse_to_tensors_map(sp_input, """ sp_input = _convert_to_sparse_tensor(sp_input) - return gen_sparse_ops._add_many_sparse_to_tensors_map( + return gen_sparse_ops.add_many_sparse_to_tensors_map( sp_input.indices, sp_input.values, sp_input.dense_shape, @@ -2167,7 +2167,7 @@ def _take_many_sparse_from_tensors_map(sparse_map_op, with ops.colocate_with(sparse_map_op): shared_name = sparse_map_op.get_attr("shared_name") or sparse_map_op.name output_indices, output_values, output_shape = ( - gen_sparse_ops._take_many_sparse_from_tensors_map( + gen_sparse_ops.take_many_sparse_from_tensors_map( sparse_handles, dtype=sparse_map_op.get_attr("T"), container=sparse_map_op.get_attr("container"), diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 6d7eaababcd94d687ff20dddc35c68a98320a19b..5e2146b79f08e6671c429f388b05634b1727b4ed 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -163,7 +163,7 @@ def einsum(equation, *inputs, **kwargs): if '...' in equation: raise ValueError('Subscripts with ellipses are not yet supported.') - match = re.match('([a-z,]+)(->[a-z]*)?', equation) + match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation) if not match: raise ValueError('Indices have incorrect format: %s' % equation) @@ -402,7 +402,7 @@ def _exponential_space_einsum(equation, *inputs): if '...' in equation: raise ValueError('Subscripts with ellipses are not yet supported.') - match = re.match('([a-z,]+)(->[a-z]*)?', equation) + match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation) if not match: raise ValueError('Indices have incorrect format: %s' % equation) diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 2c212f45483eacfd3fd27eecb8d7b2c846b5fe96..d7c3a7e8dc7c2ad611cf47718dddcf38700ce304 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -192,6 +192,9 @@ class EinsumTest(test.TestCase): 'abc,cba', 'dba,ead,cad->bce', 'aef,fbc,dca->bde', + 'iJ,Jk->ik', + 'iJ,Ki->JK', + 'iJk,Jklm->Jk' ] long_cases = [ @@ -208,6 +211,8 @@ class EinsumTest(test.TestCase): 'ijk ijk', 'ij.jk->ik', 'ij...,jk...->ik...', + 'ij,k ->kji', + 'ij,k-> kji', # axis in output that does not exist 'ij,jk->im', diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index b62e556967753dac4418add2864ce4e641dc6b58..e90ff0746a8e86b4b462b71028fd677632c9075d 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -80,6 +80,8 @@ from tensorflow.python.ops.state_ops import scatter_add from tensorflow.python.ops.state_ops import scatter_div from tensorflow.python.ops.state_ops import scatter_mul from tensorflow.python.ops.state_ops import scatter_sub +from tensorflow.python.ops.state_ops import scatter_min +from tensorflow.python.ops.state_ops import scatter_max from tensorflow.python.ops.state_ops import scatter_update from tensorflow.python.ops.state_ops import scatter_nd_add from tensorflow.python.ops.state_ops import scatter_nd_sub @@ -186,7 +188,6 @@ _allowed_symbols_array_ops = [ "quantize_and_dequantize", # to-doc # TODO(drpng): legacy symbols to be removed. - "list_diff", # Use tf.listdiff instead. "batch_matrix_diag", "batch_matrix_band_part", "batch_matrix_diag_part", @@ -219,6 +220,8 @@ _allowed_symbols_gradients = [ # Documented in training.py: # Not importing training.py to avoid complex graph dependencies. "AggregationMethod", + "GradientTape", + "custom_gradient", "gradients", # tf.gradients = gradients.gradients "hessians", ] diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 6c0a090d16bb328de40f02edf9865a0e0a62d385..01fc3182bc6f7b4f85d0df540bb26308d9fec72f 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -63,6 +63,8 @@ @@scatter_nd_update @@scatter_sub @@scatter_update +@@scatter_min +@@scatter_max @@sparse_mask @@tables_initializer @@trainable_variables @@ -99,8 +101,8 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="", """Deprecated. Used variable_op_v2 instead.""" if not set_shape: shape = tensor_shape.unknown_shape() - ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name, - container=container, shared_name=shared_name) + ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name, + container=container, shared_name=shared_name) # TODO(mrry): Move this to where it is used, so we can get rid of this op # wrapper? if set_shape: @@ -127,11 +129,12 @@ def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""): Returns: A variable tensor. """ - return gen_state_ops._variable_v2(shape=shape, - dtype=dtype, - name=name, - container=container, - shared_name=shared_name) + return gen_state_ops.variable_v2( + shape=shape, + dtype=dtype, + name=name, + container=container, + shared_name=shared_name) def init_variable(v, init, name="init"): @@ -185,7 +188,7 @@ def is_variable_initialized(ref, name=None): if ref.dtype._is_ref_dtype: return gen_state_ops.is_variable_initialized(ref=ref, name=name) # Handle resource variables. - if context.in_eager_mode() or ref.op.type == "VarHandleOp": + if context.executing_eagerly() or ref.op.type == "VarHandleOp": return gen_resource_variable_ops.var_is_initialized_op(ref.handle, name=name) diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index b8c39d91b41790c6441594b175e8eaa03620e1ec..5bd75b9215fdbccd5882ea39c2b35ccbbe29d5b0 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -17,6 +17,7 @@ See the @{$python/string_ops} guide. +@@regex_replace @@string_to_hash_bucket_fast @@string_to_hash_bucket_strong @@string_to_hash_bucket @@ -93,10 +94,8 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string) source = ops.convert_to_tensor(source, dtype=dtypes.string) - # pylint: disable=protected-access - indices, values, shape = gen_string_ops._string_split( + indices, values, shape = gen_string_ops.string_split( source, delimiter=delimiter, skip_empty=skip_empty) - # pylint: enable=protected-access indices.set_shape([None, 2]) values.set_shape([None]) shape.set_shape([2]) @@ -141,6 +140,7 @@ def reduce_join(inputs, axis=None, reduce_join.__doc__ = deprecation.rewrite_argument_docstring( gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis") +ops.NotDifferentiable("RegexReplace") ops.NotDifferentiable("StringToHashBucket") ops.NotDifferentiable("StringToHashBucketFast") ops.NotDifferentiable("StringToHashBucketStrong") diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py index 7f4f4ce5ab4ee2bd309932cb81f05775996371d6..037bc9845a3f734f65b73b0c4b4ca19fb653731d 100644 --- a/tensorflow/python/ops/summary_ops.py +++ b/tensorflow/python/ops/summary_ops.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Summary Operations.""" -# pylint: disable=protected-access from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -74,7 +73,7 @@ def tensor_summary(name, with summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - val = gen_logging_ops._tensor_summary_v2( + val = gen_logging_ops.tensor_summary_v2( tensor=tensor, tag=tag, name=scope, diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 424582b348d87d8a5b043ec9b771d8f2768a5994..0294ecee548d1e7f507a5e4195e4ee320a0b9918 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -26,6 +26,7 @@ from tensorflow.python.eager import function from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util.deprecation import deprecated @@ -203,7 +204,7 @@ def make_template_internal(name_, if kwargs: func_ = tf_decorator.make_decorator(func_, functools.partial( func_, **kwargs)) - if context.in_eager_mode(): + if context.executing_eagerly(): if unique_name_ is not None: raise ValueError( "unique_name_ cannot be used when eager exeuction is enabled.") @@ -230,7 +231,7 @@ def _skip_common_stack_elements(stacktrace, base_case): return stacktrace[-1:] -class Template(object): +class Template(checkpointable.CheckpointableBase): """Wrap a function to aid in variable sharing. Templates are functions that create variables the first time they are called @@ -294,12 +295,115 @@ class Template(object): # which is not the same as whether the scope has been created. self._variables_created = False + @property + def _checkpoint_dependencies(self): + """Sanity checking for object-based saving. + + Does not override Checkpointable dependency tracking, but checks that + variables accessible through Checkpointable dependencies on other `Template` + objects include all of the variable_scope-filtered `Template.variables`. + + Returns: + A list of checkpointable.CheckpointableReference objects. + Raises: + ValueError: If this object is not compatible with object-based saving. + """ + dependencies = super(Template, self)._checkpoint_dependencies + dependency_variables = [] + for _, dependency in dependencies: + if isinstance(dependency, Template): + dependency_variables.extend(dependency.variables) + else: + dependency_variables.append(dependency) + dependency_variables = set(dependency_variables) + not_included_variables = [] + for expected_variable in sorted(self.variables, key=lambda v: v.name): + if expected_variable not in dependency_variables: + not_included_variables.append(expected_variable) + if not_included_variables: + # Trying to save a Template which improperly tracks its variables. + raise ValueError( + ("The Template '%s' references variables which are not included via " + "object-based dependency tracking. Most likely a custom " + "getter/creator was registered which does not call Template's " + "custom variable creator (which is responsible for tracking " + "dependencies).\n\nExpected these variables to be dependencies: %s") + % (self, not_included_variables)) + return dependencies + + def _checkpointable_custom_creator(self, next_creator, name, initial_value, + checkpointable_parent=None, **kwargs): + """A variable creation hook which adds Checkpointable dependencies. + + Set during the `Template`'s first wrapped function execution. Ensures that + (a) `Template` objects depend on `Template`s created inside them which + create variables, and (b) that any variables not in a more deeply nested + `Template` are added as dependencies directly. + + The `checkpointable_parent` argument is passed between `Template` custom + creators but ignored when the variable object itself is created. This + argument indicates (if not `None`) that a more deeply nested `Template` has + already added the variable as a dependency, and that parent `Template`s + should add a dependency on that `Template` rather than on the variable + directly. + + Args: + next_creator: See `variable_scope.variable_creator_scope`; the next + creator in the chain. + name: The (full, scope-influenced) name of the variable. The scope name + for the Template itself is stripped for the purposes of object-based + dependency tracking, but scopes within Templates are respected. + initial_value: See `variable_scope.variable_creator_scope`. Taken + explicitly so the argument can be re-named and used with + `Checkpointable._add_variable_with_custom_getter`. + checkpointable_parent: If not None, a more deeply nested Template object + to add a dependency on (rather than depending on the variable directly). + **kwargs: Passed through to the next creator. + Returns: + The output of `next_creator`: the fetched/created variable object. + """ + def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): + inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which + # we don't want to propagate. + return next_creator( + initial_value=initializer, + name=name, + **inner_kwargs) + if name.startswith(self._variable_scope.name): + scope_stripped_name = name[len(self._variable_scope.name) + 1:] + if not checkpointable_parent: + return self._add_variable_with_custom_getter( + initializer=initial_value, + name=scope_stripped_name, + getter=_call_next_creator_renaming_initializer, + # Disable error checking for Checkpointable. Exceptions are instead + # raised if necessary when the object-based saver tries to + # save/restore the object. + overwrite=True, + checkpointable_parent=self, + **kwargs) + else: + self._track_checkpointable( + checkpointable_parent, + name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access + len(self._variable_scope.name) + 1:], + overwrite=True) + return next_creator(name=name, initial_value=initial_value, + checkpointable_parent=self, **kwargs) + def _call_func(self, args, kwargs): try: vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) trainable_at_start = len( ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - result = self._func(*args, **kwargs) + if self._variables_created: + result = self._func(*args, **kwargs) + else: + # The first time we run, restore variables if necessary (via + # Checkpointable). + with variable_scope.variable_creator_scope( + self._checkpointable_custom_creator): + result = self._func(*args, **kwargs) if self._variables_created: # Variables were previously created, implying this is not the first @@ -479,7 +583,7 @@ class _EagerTemplateVariableStore(object): if self._variable_scope_name is None: raise RuntimeError("A variable scope must be set before an " "_EagerTemplateVariableStore object exits.") - self._eager_variable_store._store.close_variable_subscopes( # pylint: disable=protected-access + variable_scope.get_variable_scope_store().close_variable_subscopes( self._variable_scope_name) def _variables_in_scope(self, variable_list): @@ -543,7 +647,7 @@ class EagerTemplate(Template): Raises: RuntimeError: if eager execution is not enabled. """ - if not context.in_eager_mode(): + if not context.executing_eagerly(): raise RuntimeError( "{} objects can only be used when eager execution is enabled, use " "tf.Template for graph construction". @@ -563,7 +667,14 @@ class EagerTemplate(Template): try: vars_at_start = self._template_store.variables() trainable_at_start = self._template_store.trainable_variables() - result = self._func(*args, **kwargs) + if self._variables_created: + result = self._func(*args, **kwargs) + else: + # The first time we run, restore variables if necessary (via + # Checkpointable). + with variable_scope.variable_creator_scope( + self._checkpointable_custom_creator): + result = self._func(*args, **kwargs) if self._variables_created: # Variables were previously created, implying this is not the first diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 3c08870146e447d84d4a5f620cbead633d94751f..2f6badcb532c0ef9d82b211d0c7b11a67e8e3010 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -148,7 +148,7 @@ class _GraphTensorArray(object): # will retroactively set the device value of this op. def create(): """Create the TensorArray op.""" - return gen_data_flow_ops._tensor_array_v3( + return gen_data_flow_ops.tensor_array_v3( dtype=dtype, size=size, element_shape=element_shape, @@ -237,7 +237,7 @@ class _GraphTensorArray(object): flow = self.flow with ops.name_scope(name, "TensorArrayGrad", [self._handle]): with ops.colocate_with(self._handle): - g_handle, unused_flow = gen_data_flow_ops._tensor_array_grad_v3( + g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3( handle=self._handle, source=source, flow_in=flow, name=name) with ops.control_dependencies([g_handle]): flow = array_ops.identity(flow, name="gradient_flow") @@ -252,7 +252,7 @@ class _GraphTensorArray(object): def read(self, index, name=None): """See TensorArray.""" - value = gen_data_flow_ops._tensor_array_read_v3( + value = gen_data_flow_ops.tensor_array_read_v3( handle=self._handle, index=index, flow_in=self._flow, @@ -270,7 +270,7 @@ class _GraphTensorArray(object): if self._infer_shape: self._merge_element_shape(value.shape) with self._maybe_colocate_with(value): - flow_out = gen_data_flow_ops._tensor_array_write_v3( + flow_out = gen_data_flow_ops.tensor_array_write_v3( handle=self._handle, index=index, value=value, @@ -296,7 +296,7 @@ class _GraphTensorArray(object): element_shape = self._element_shape[0] else: element_shape = tensor_shape.TensorShape(None) - value = gen_data_flow_ops._tensor_array_gather_v3( + value = gen_data_flow_ops.tensor_array_gather_v3( handle=self._handle, indices=indices, flow_in=self._flow, @@ -314,7 +314,7 @@ class _GraphTensorArray(object): tensor_shape.TensorShape(self._element_shape[0].dims[1:])) else: element_shape_except0 = tensor_shape.TensorShape(None) - value, _ = gen_data_flow_ops._tensor_array_concat_v3( + value, _ = gen_data_flow_ops.tensor_array_concat_v3( handle=self._handle, flow_in=self._flow, dtype=self._dtype, @@ -338,10 +338,10 @@ class _GraphTensorArray(object): with ops.name_scope(name, "TensorArrayScatter", [self._handle, value, indices]): value = ops.convert_to_tensor(value, name="value") - if self._infer_shape and context.in_graph_mode(): + if self._infer_shape and not context.executing_eagerly(): self._merge_element_shape(value.shape[1:]) with self._maybe_colocate_with(value): - flow_out = gen_data_flow_ops._tensor_array_scatter_v3( + flow_out = gen_data_flow_ops.tensor_array_scatter_v3( handle=self._handle, indices=indices, value=value, @@ -363,14 +363,14 @@ class _GraphTensorArray(object): value = ops.convert_to_tensor(value, name="value") with self._maybe_colocate_with(value): lengths_64 = math_ops.to_int64(lengths) - if self._infer_shape and context.in_graph_mode(): + if self._infer_shape and not context.executing_eagerly(): clengths = tensor_util.constant_value(lengths_64) if value.shape.dims is not None: if clengths is not None and clengths.max() == clengths.min(): self._merge_element_shape( tensor_shape.TensorShape([clengths[0]]).concatenate( value.shape[1:])) - flow_out = gen_data_flow_ops._tensor_array_split_v3( + flow_out = gen_data_flow_ops.tensor_array_split_v3( handle=self._handle, value=value, lengths=lengths_64, @@ -386,13 +386,13 @@ class _GraphTensorArray(object): def size(self, name=None): """See TensorArray.""" - return gen_data_flow_ops._tensor_array_size_v3( + return gen_data_flow_ops.tensor_array_size_v3( handle=self._handle, flow_in=self.flow, name=name) @tf_should_use.should_use_result def close(self, name=None): """See TensorArray.""" - return gen_data_flow_ops._tensor_array_close_v3( + return gen_data_flow_ops.tensor_array_close_v3( handle=self._handle, name=name) # pylint: enable=protected-access @@ -774,10 +774,10 @@ class TensorArray(object): ValueError: if both handle and tensor_array_name are provided. TypeError: if handle is provided but is not a Tensor. """ - if context.in_graph_mode(): - implementation = _GraphTensorArray - else: + if context.executing_eagerly(): implementation = _EagerTensorArray + else: + implementation = _GraphTensorArray self._implementation = implementation( dtype, diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 81565a63774da49628d100ef071b02f6311f6af2..c35735ca656b21d43f758830e68e5777d654f271 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -24,6 +24,7 @@ import copy import enum # pylint: disable=g-bad-import-order import functools import sys +import threading import traceback import six @@ -211,23 +212,8 @@ class _VariableStore(object): """Create a variable store.""" self._vars = {} # A dictionary of the stored TensorFlow variables. self._partitioned_vars = {} # A dict of the stored PartitionedVariables. - self.variable_scopes_count = {} # Count re-used variable scopes. self._store_eager_variables = False - def open_variable_scope(self, scope_name): - if scope_name in self.variable_scopes_count: - self.variable_scopes_count[scope_name] += 1 - else: - self.variable_scopes_count[scope_name] = 1 - - def close_variable_subscopes(self, scope_name): - for k in self.variable_scopes_count: - if not scope_name or k.startswith(scope_name + "/"): - self.variable_scopes_count[k] = 0 - - def variable_scope_count(self, scope_name): - return self.variable_scopes_count.get(scope_name, 0) - def get_variable(self, name, shape=None, dtype=dtypes.float32, initializer=None, regularizer=None, reuse=None, trainable=True, collections=None, caching_device=None, @@ -321,7 +307,7 @@ class _VariableStore(object): raise ValueError( "Passed a custom_getter which is not callable: %s" % custom_getter) - if context.in_eager_mode(): + if context.executing_eagerly(): if not self._store_eager_variables and reuse: raise RuntimeError( "When eager execution is enabled variable reuse is only supported" @@ -518,7 +504,7 @@ class _VariableStore(object): when violating reuse during variable creation, or if an existing sharded variable exists for the given name but with different sharding. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise NotImplementedError("Partitioned variables are not yet supported " "when eager execution is enabled.") @@ -798,7 +784,7 @@ class _VariableStore(object): validate_shape=validate_shape, constraint=constraint, use_resource=use_resource) - if context.in_graph_mode() or self._store_eager_variables: + if not context.executing_eagerly() or self._store_eager_variables: # In eager mode we do not want to keep default references to Variable # objects as this will prevent their memory from being released. self._vars[name] = v @@ -811,12 +797,12 @@ class _VariableStore(object): with ops.name_scope(name + "/Regularizer/"): loss = regularizer(v) if loss is not None: - if context.in_graph_mode(): - v_name = v.name - loss_name = loss.name - else: + if context.executing_eagerly(): v_name = "v_%s" % type(v) loss_name = "loss_%s" % type(loss) + else: + v_name = v.name + loss_name = loss.name logging.vlog(1, "Applied regularizer to %s and added the result %s " "to REGULARIZATION_LOSSES.", v_name, loss_name) ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss) @@ -920,7 +906,7 @@ class VariableScope(object): self._dtype = dtype self._use_resource = use_resource self._constraint = constraint - if context.in_eager_mode(): + if context.executing_eagerly(): if self._caching_device is not None: raise NotImplementedError("Caching devices is not yet supported " "when eager execution is enabled.") @@ -988,7 +974,7 @@ class VariableScope(object): def set_use_resource(self, use_resource): """Sets whether to use ResourceVariables for this scope.""" - if context.in_eager_mode() and not use_resource: + if context.executing_eagerly() and not use_resource: raise ValueError("When eager execution is enabled, " "use_resource cannot be set to false.") self._use_resource = use_resource @@ -999,14 +985,14 @@ class VariableScope(object): def set_caching_device(self, caching_device): """Set caching_device for this scope.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise NotImplementedError("Caching devices are not yet supported " "when eager execution is enabled.") self._caching_device = caching_device def set_partitioner(self, partitioner): """Set partitioner for this scope.""" - if partitioner and context.in_eager_mode(): + if partitioner and context.executing_eagerly(): raise NotImplementedError("Partitioned variables are not yet supported " "when eager execution is enabled.") self._partitioner = partitioner @@ -1057,14 +1043,14 @@ class VariableScope(object): partitioner = self._partitioner if custom_getter is None: custom_getter = self._custom_getter - if context.in_graph_mode(): + if context.executing_eagerly(): + reuse = False + use_resource = True + else: if reuse is None: reuse = self._reuse if use_resource is None: use_resource = self._use_resource - else: - reuse = False - use_resource = True full_name = self.name + "/" + name if self.name else name # Variable names only depend on variable_scope (full_name here), @@ -1107,7 +1093,7 @@ class VariableScope(object): use_resource=None, constraint=None): """Gets an existing variable with this name or create a new one.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise NotImplementedError("Partitioned variables are not yet supported " "when eager execution is enabled.") if initializer is None: @@ -1160,18 +1146,49 @@ class VariableScope(object): _VARSTORE_KEY = ("__variable_store",) -_VARSCOPE_KEY = ("__varscope",) +_VARSCOPESTORE_KEY = ("__varscope",) + + +class _VariableScopeStore(threading.local): + """A thread local store for the current variable scope and scope counts.""" + + def __init__(self): + super(_VariableScopeStore, self).__init__() + self.current_scope = VariableScope(False) + self.variable_scopes_count = {} + + def open_variable_scope(self, scope_name): + if scope_name in self.variable_scopes_count: + self.variable_scopes_count[scope_name] += 1 + else: + self.variable_scopes_count[scope_name] = 1 + + def close_variable_subscopes(self, scope_name): + for k in self.variable_scopes_count: + if not scope_name or k.startswith(scope_name + "/"): + self.variable_scopes_count[k] = 0 + + def variable_scope_count(self, scope_name): + return self.variable_scopes_count.get(scope_name, 0) + + +def get_variable_scope_store(): + """Returns the variable scope store for current thread.""" + scope_store = ops.get_collection(_VARSCOPESTORE_KEY) + + if not scope_store: + scope_store = _VariableScopeStore() + ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store) + else: + scope_store = scope_store[0] + + return scope_store @tf_export("get_variable_scope") def get_variable_scope(): """Returns the current variable scope.""" - scope = ops.get_collection(_VARSCOPE_KEY) - if scope: # This collection has at most 1 element, the default scope at [0]. - return scope[0] - scope = VariableScope(False) - ops.add_to_collection(_VARSCOPE_KEY, scope) - return scope + return get_variable_scope_store().current_scope def _get_default_variable_store(): @@ -1274,6 +1291,9 @@ class EagerVariableStore(object): # pylint: enable=protected-access +# The argument list for get_variable must match arguments to get_local_variable. +# So, if you are updating the arguments, also update arguments to +# get_local_variable below. @tf_export("get_variable") def get_variable(name, shape=None, @@ -1385,15 +1405,32 @@ get_variable.__doc__ = get_variable_or_local_docstring % ( "GraphKeys.GLOBAL_VARIABLES") -@functools.wraps(get_variable) +# The argument list for get_local_variable must match arguments to get_variable. +# So, if you are updating the arguments, also update arguments to get_variable. @tf_export("get_local_variable") -def get_local_variable(*args, **kwargs): - kwargs["trainable"] = False - if "collections" in kwargs: - kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES] +def get_local_variable(name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=False, # pylint: disable=unused-argument + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + custom_getter=None, + constraint=None): + if collections: + collections += [ops.GraphKeys.LOCAL_VARIABLES] else: - kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES] - return get_variable(*args, **kwargs) + collections = [ops.GraphKeys.LOCAL_VARIABLES] + return get_variable( + name, shape=shape, dtype=dtype, initializer=initializer, + regularizer=regularizer, trainable=False, collections=collections, + caching_device=caching_device, partitioner=partitioner, + validate_shape=validate_shape, use_resource=use_resource, + custom_getter=custom_getter, constraint=constraint) get_local_variable.__doc__ = get_variable_or_local_docstring % ( "Gets an existing *local* variable or creates a new one.", "Behavior is the same as in `get_variable`, except that variables are\n" @@ -1555,10 +1592,8 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name self._dtype = dtype self._use_resource = use_resource self._constraint = constraint - get_variable_scope() # Ensure that a default exists, then get a pointer. - # Get the reference to the collection as we want to modify it in place. - self._default_varscope = ops.get_collection_ref(_VARSCOPE_KEY) self._var_store = _get_default_variable_store() + self._var_scope_store = get_variable_scope_store() if isinstance(self._name_or_scope, VariableScope): self._new_name = self._name_or_scope.name name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access @@ -1606,10 +1641,11 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name a reuse scope, or if reuse is not `None` or `True`. TypeError: when the types of some arguments are not appropriate. """ - self._old = self._default_varscope[0] + self._old = self._var_scope_store.current_scope if isinstance(self._name_or_scope, VariableScope): - self._var_store.open_variable_scope(self._new_name) - self._old_subscopes = copy.copy(self._var_store.variable_scopes_count) + self._var_scope_store.open_variable_scope(self._new_name) + self._old_subscopes = copy.copy( + self._var_scope_store.variable_scopes_count) variable_scope_object = self._cached_variable_scope_object else: # Handler for the case when we just prolong current variable scope. @@ -1652,17 +1688,17 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name variable_scope_object.set_dtype(self._dtype) if self._use_resource is not None: variable_scope_object.set_use_resource(self._use_resource) - self._var_store.open_variable_scope(self._new_name) - self._default_varscope[0] = variable_scope_object + self._var_scope_store.open_variable_scope(self._new_name) + self._var_scope_store.current_scope = variable_scope_object return variable_scope_object def __exit__(self, type_arg, value_arg, traceback_arg): # If jumping out from a non-prolonged scope, restore counts. if isinstance(self._name_or_scope, VariableScope): - self._var_store.variable_scopes_count = self._old_subscopes + self._var_scope_store.variable_scopes_count = self._old_subscopes else: - self._var_store.close_variable_subscopes(self._new_name) - self._default_varscope[0] = self._old + self._var_scope_store.close_variable_subscopes(self._new_name) + self._var_scope_store.current_scope = self._old def _maybe_wrap_custom_getter(custom_getter, old_getter): @@ -1687,13 +1723,13 @@ def _maybe_wrap_custom_getter(custom_getter, old_getter): def _get_unique_variable_scope(prefix): """Get a name with the given prefix unique in the current variable scope.""" - var_store = _get_default_variable_store() + var_scope_store = get_variable_scope_store() current_scope = get_variable_scope() name = current_scope.name + "/" + prefix if current_scope.name else prefix - if var_store.variable_scope_count(name) == 0: + if var_scope_store.variable_scope_count(name) == 0: return prefix idx = 1 - while var_store.variable_scope_count(name + ("_%d" % idx)) > 0: + while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0: idx += 1 return prefix + ("_%d" % idx) @@ -1709,9 +1745,10 @@ class variable_scope(object): graph, ensures that graph is the default graph, and pushes a name scope and a variable scope. - If `name_or_scope` is not None, it is used as is. If `scope` is None, then - `default_name` is used. In that case, if the same name has been previously - used in the same scope, it will be made unique by appending `_N` to it. + If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None, + then `default_name` is used. In that case, if the same name has been + previously used in the same scope, it will be made unique by appending `_N` + to it. Variable scope allows you to create new variables and to share already created ones while providing checks to not create or share by accident. For details, @@ -1790,6 +1827,32 @@ class variable_scope(object): discouraged) to pass False to the reuse argument, yielding undocumented behaviour slightly different from None. Starting at 1.1.0 passing None and False as reuse has exactly the same effect. + + A note about using variable scopes in multi-threaded environment: Variable + scopes are thread local, so one thread will not see another thread's current + scope. Also, when using `default_name`, unique scopes names are also generated + only on a per thread basis. If the same name was used within a different + thread, that doesn't prevent a new thread from creating the same scope. + However, the underlying variable store is shared across threads (within the + same graph). As such, if another thread tries to create a new variable with + the same name as a variable created by a previous thread, it will fail unless + reuse is True. + + Further, each thread starts with an empty variable scope. So if you wish to + preserve name prefixes from a scope from the main thread, you should capture + the main thread's scope and re-enter it in each thread. For e.g. + + ``` + main_thread_scope = variable_scope.get_variable_scope() + + # Thread's target function: + def thread_target_fn(captured_scope): + with variable_scope.variable_scope(captured_scope): + # .... regular code for this thread + + + thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,)) + ``` """ def __init__(self, @@ -1871,7 +1934,7 @@ class variable_scope(object): raise ValueError("The reuse parameter must be True or False or None.") if self._values is None: self._values = [] - self._in_graph_mode = not context.in_eager_mode() + self._in_graph_mode = not context.executing_eagerly() if self._in_graph_mode: self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access self._cached_pure_variable_scope = None @@ -2111,13 +2174,13 @@ def default_variable_creator(next_creator=None, **kwargs): use_resource = kwargs.get("use_resource", None) if use_resource is None: use_resource = get_variable_scope().use_resource - if use_resource or (use_resource is None and context.in_eager_mode()): + if use_resource or (use_resource is None and context.executing_eagerly()): return resource_variable_ops.ResourceVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, constraint=constraint) - elif not use_resource and context.in_eager_mode(): + elif not use_resource and context.executing_eagerly(): raise RuntimeError( "VariableScope should use resource variable when eager execution is" " enabled, but use_resource is False." @@ -2145,7 +2208,7 @@ def variable(initial_value=None, constraint=None, use_resource=None): previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) - for getter in ops.get_default_graph()._get_variable_creator_stack(): # pylint: disable=protected-access + for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access previous_getter = _make_getter(getter, previous_getter) return previous_getter(initial_value=initial_value, trainable=trainable, diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index d382683858be5d91755ef1a15ebbc6ae2287f8a7..c646f795896f0abfce3eb9a57cadc27299714023 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -125,8 +125,8 @@ class Variable(checkpointable.CheckpointableBase): @compatibility(eager) `tf.Variable` is not compatible with eager execution. Use - `tfe.Variable` instead which is compatible with both eager execution - and graph construction. See [the TensorFlow Eager Execution + `tf.contrib.eager.Variable` instead which is compatible with both eager + execution and graph construction. See [the TensorFlow Eager Execution guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) for details on how variables work in eager execution. @end_compatibility @@ -210,10 +210,11 @@ class Variable(checkpointable.CheckpointableBase): for details on how variables work in eager execution. @end_compatibility """ - if not context.in_graph_mode(): - raise RuntimeError("tf.Variable not supported in Eager mode. " - "Please use tfe.Variable instead") - self._in_graph_mode = context.in_graph_mode() + if context.executing_eagerly(): + raise RuntimeError( + "tf.Variable not supported when eager execution is enabled. " + "Please use tf.contrib.eager.Variable instead") + self._in_graph_mode = True if variable_def: # If variable_def is provided, recreates the variable from its fields. if initial_value: @@ -234,7 +235,7 @@ class Variable(checkpointable.CheckpointableBase): constraint=constraint) def __repr__(self): - if context.in_eager_mode(): + if context.executing_eagerly(): return "" % ( self.name, self.get_shape(), self.dtype.name, ops.numpy_text(self.read_value(), is_repr=True)) @@ -292,6 +293,7 @@ class Variable(checkpointable.CheckpointableBase): Raises: ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. + RuntimeError: If lifted into the eager context. """ _ = expected_shape if initial_value is None: @@ -307,6 +309,9 @@ class Variable(checkpointable.CheckpointableBase): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") + # Store the graph key so optimizers know how to only retrieve variables from + # this graph. + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access if isinstance(initial_value, checkpointable.CheckpointInitialValue): self._maybe_initialize_checkpointable() self._update_uid = initial_value.checkpoint_position.restore_uid @@ -315,6 +320,11 @@ class Variable(checkpointable.CheckpointableBase): if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.init_scope(): + # Ensure that we weren't lifted into the eager context. + if context.executing_eagerly(): + raise RuntimeError( + "tf.Variable not supported when eager execution is enabled. " + "Please use tf.contrib.eager.Variable instead") with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -737,15 +747,15 @@ class Variable(checkpointable.CheckpointableBase): Raises: ValueError: Session is not passed and no default session """ - if context.in_graph_mode(): + if context.executing_eagerly(): + self.assign(value) + else: session = session or ops.get_default_session() if session is None: raise ValueError( "Either session argument should be provided or default session " "should be established") session.run(self._initializer_op, {self._initializer_op.inputs[1]: value}) - else: - self.assign(value) # Conversion to tensor. @staticmethod @@ -1245,9 +1255,9 @@ class PartitionedVariable(object): information does not match `shape`, or `partitions` has invalid values. RuntimeError: If eager execution is enabled """ - if not context.in_graph_mode(): - raise RuntimeError("tf.PartitionedVariable not supported in " - "eager mode. Please use tfe.Variable instead") + if context.executing_eagerly(): + raise RuntimeError( + "tf.PartitionedVariable not supported with eager execution enabled.") if not isinstance(variable_list, (list, tuple)): raise TypeError( "variable_list is not a list or tuple: %s" % variable_list) @@ -1538,7 +1548,7 @@ def variables_initializer(var_list, name="init"): Returns: An Op that run the initializers of all the specified variables. """ - if var_list and context.in_graph_mode(): + if var_list and not context.executing_eagerly(): return control_flow_ops.group(*[v.initializer for v in var_list], name=name) return control_flow_ops.no_op(name=name) @@ -1560,7 +1570,7 @@ def global_variables_initializer(): Returns: An Op that initializes global variables in the graph. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return control_flow_ops.no_op(name="global_variables_initializer") return variables_initializer(global_variables()) @@ -1582,7 +1592,7 @@ def local_variables_initializer(): Returns: An Op that initializes all local variables in the graph. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return control_flow_ops.no_op(name="local_variables_initializer") return variables_initializer(local_variables()) diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py index 96219faab719e28a5fa8a9a21c83f81a6f8478e6..8141cf92c568f257a5e9810318182d71f445dfa1 100644 --- a/tensorflow/python/platform/googletest.py +++ b/tensorflow/python/platform/googletest.py @@ -36,6 +36,7 @@ from tensorflow.python.platform import benchmark from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name @@ -138,6 +139,7 @@ def StatefulSessionAvailable(): return False +@tf_export('test.StubOutForTesting') class StubOutForTesting(object): """Support class for stubbing methods out for unit testing. diff --git a/tensorflow/python/platform/sysconfig.py b/tensorflow/python/platform/sysconfig.py index 5c50fa023dc3b216838390d9356a39e70e2362d2..fdd2b903fc79c40a26392714328f74756f3fff92 100644 --- a/tensorflow/python/platform/sysconfig.py +++ b/tensorflow/python/platform/sysconfig.py @@ -68,7 +68,6 @@ def get_compile_flags(): """ flags = [] flags.append('-I%s' % get_include()) - flags.append('-I%s/external/nsync/public' % get_include()) flags.append('-D_GLIBCXX_USE_CXX11_ABI=%d' % _CXX11_ABI_FLAG) return flags diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 9b7655722ac5a917f2753617f8e99bf2bd2f8d11..1660791febc9da93f3a3a977a17ca876e772a9a5 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -62,6 +62,8 @@ if sys.version_info.major == 2: else: from unittest import mock # pylint: disable=g-import-not-at-top +tf_export('test.mock')(mock) + # Import Benchmark class Benchmark = _googletest.Benchmark # pylint: disable=invalid-name diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py index 0e20ca35bba606079ed5b0f225dd3029772b5af3..acf02096fffe8b38e68824878fa698ed69d3895c 100644 --- a/tensorflow/python/profiler/model_analyzer.py +++ b/tensorflow/python/profiler/model_analyzer.py @@ -172,7 +172,7 @@ class Profiler(object): op_log: optional. tensorflow::tfprof::OpLogProto proto. Used to define extra op types. """ - if not graph and context.in_graph_mode(): + if not graph and not context.executing_eagerly(): graph = ops.get_default_graph() self._coverage = 0.0 self._graph = graph @@ -336,7 +336,7 @@ def profile(graph=None, If cmd is 'op' or 'code', returns MultiGraphNodeProto proto. Side effect: stdout/file/timeline.json depending on options['output'] """ - if not graph and context.in_graph_mode(): + if not graph and not context.executing_eagerly(): graph = ops.get_default_graph() if options == _DEFAULT_PROFILE_OPTIONS: diff --git a/tensorflow/python/profiler/tfprof_logger.py b/tensorflow/python/profiler/tfprof_logger.py index 8d121064967f2f87cd0aefaa361bfd6f387a3e6e..e651de32ea3bce32a965bfbeefc76ff08a79ac38 100644 --- a/tensorflow/python/profiler/tfprof_logger.py +++ b/tensorflow/python/profiler/tfprof_logger.py @@ -156,7 +156,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None, Returns: tmp_op_log: Merged OpLogProto proto. """ - if not graph and context.in_graph_mode(): + if not graph and not context.executing_eagerly(): graph = ops.get_default_graph() tmp_op_log = tfprof_log_pb2.OpLogProto() @@ -210,7 +210,7 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True): add_trace: Whether to add python code trace information. Used to support "code" view. """ - if not graph and context.in_graph_mode(): + if not graph and not context.executing_eagerly(): graph = ops.get_default_graph() op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 7ab0db526881109765adf83749bd01e4d543e5b2..39fabb9c1bc646a09557293c1f645a8b97f5bbdd 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -26,11 +26,15 @@ limitations under the License. %rename("%s") TFE_ContextClearCaches; %rename("%s") TFE_ContextGetDevicePlacementPolicy; %rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy; +%rename("%s") TFE_ContextSetAsyncForThread; +%rename("%s") TFE_ContextAsyncWait; +%rename("%s") TFE_ContextAsyncClearError; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_RegisterBackwardFunctionGetter; %rename("%s") TFE_Py_RegisterFallbackExceptionClass; +%rename("%s") TFE_Py_RegisterResourceVariableType; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_FastPathExecute; %rename("%s") TFE_Py_RecordGradient; @@ -50,6 +54,7 @@ limitations under the License. %rename("%s") TFE_NewContextOptions; %rename("%s") TFE_ContextOptionsSetConfig; %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; +%rename("%s") TFE_ContextOptionsSetAsync; %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 7347da75364818b95d3f2ad7dfa74a8c3614b161..3447d917e9bf2dace3de784106dadb1fcc3a9647 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -193,7 +193,8 @@ class SavedModelBuilder(object): def _validate_tensor_info(self, tensor_info): """Validates the `TensorInfo` proto. - Checks if the `name` and `dtype` fields exist and are non-empty. + Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist + and are non-empty. Args: tensor_info: `TensorInfo` protocol buffer to validate. @@ -206,10 +207,12 @@ class SavedModelBuilder(object): raise AssertionError( "All TensorInfo protos used in the SignatureDefs must have the name " "and dtype fields set.") - if not tensor_info.name: + if tensor_info.WhichOneof("encoding") is None: + # TODO(soergel) validate each of the fields of coo_sparse raise AssertionError( - "All TensorInfo protos used in the SignatureDefs must have the name " - "field set: %s" % tensor_info) + "All TensorInfo protos used in the SignatureDefs must have one of " + "the 'encoding' fields (e.g., name or coo_sparse) set: %s" + % tensor_info) if tensor_info.dtype is types_pb2.DT_INVALID: raise AssertionError( "All TensorInfo protos used in the SignatureDefs must have the dtype " diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index d9d316882584470769c14cf0c5f265b58e37ab43..804255375e7c5215597a5dcca02f3b32f2c0a497 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -94,7 +94,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(expected_asset_file_name, asset.filename) self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name) - def _validate_inputs_tensor_info(self, builder, tensor_info): + def _validate_inputs_tensor_info_fail(self, builder, tensor_info): with self.test_session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) @@ -107,7 +107,18 @@ class SavedModelTest(test.TestCase): sess, ["foo"], signature_def_map={"foo_key": foo_signature}) - def _validate_outputs_tensor_info(self, builder, tensor_info): + def _validate_inputs_tensor_info_accept(self, builder, tensor_info): + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + foo_signature = signature_def_utils.build_signature_def({ + "foo_inputs": tensor_info + }, dict(), "foo") + builder.add_meta_graph_and_variables( + sess, ["foo"], + signature_def_map={"foo_key": foo_signature}) + + def _validate_outputs_tensor_info_fail(self, builder, tensor_info): with self.test_session(graph=ops.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) @@ -119,6 +130,16 @@ class SavedModelTest(test.TestCase): sess, ["foo"], signature_def_map={"foo_key": foo_signature}) + def _validate_outputs_tensor_info_accept(self, builder, tensor_info): + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + foo_signature = signature_def_utils.build_signature_def( + dict(), {"foo_outputs": tensor_info}, "foo") + builder.add_meta_graph_and_variables( + sess, ["foo"], + signature_def_map={"foo_key": foo_signature}) + def testMaybeSavedModelDir(self): base_path = test.test_src_dir_path("/python/saved_model") self.assertFalse(loader.maybe_saved_model_directory(base_path)) @@ -538,23 +559,50 @@ class SavedModelTest(test.TestCase): self.assertEqual("bar", bar_signature["bar_key"].method_name) self.assertEqual("foo_new", bar_signature["foo_key"].method_name) - def testSignatureDefValidation(self): - export_dir = self._get_export_dir("test_signature_def_validation") + def testSignatureDefValidationFails(self): + export_dir = self._get_export_dir("test_signature_def_validation_fail") builder = saved_model_builder.SavedModelBuilder(export_dir) - tensor_without_name = meta_graph_pb2.TensorInfo() - tensor_without_name.dtype = types_pb2.DT_FLOAT - self._validate_inputs_tensor_info(builder, tensor_without_name) - self._validate_outputs_tensor_info(builder, tensor_without_name) + tensor_without_encoding = meta_graph_pb2.TensorInfo() + tensor_without_encoding.dtype = types_pb2.DT_FLOAT + self._validate_inputs_tensor_info_fail(builder, tensor_without_encoding) + self._validate_outputs_tensor_info_fail(builder, tensor_without_encoding) tensor_without_dtype = meta_graph_pb2.TensorInfo() tensor_without_dtype.name = "x" - self._validate_inputs_tensor_info(builder, tensor_without_dtype) - self._validate_outputs_tensor_info(builder, tensor_without_dtype) + self._validate_inputs_tensor_info_fail(builder, tensor_without_dtype) + self._validate_outputs_tensor_info_fail(builder, tensor_without_dtype) tensor_empty = meta_graph_pb2.TensorInfo() - self._validate_inputs_tensor_info(builder, tensor_empty) - self._validate_outputs_tensor_info(builder, tensor_empty) + self._validate_inputs_tensor_info_fail(builder, tensor_empty) + self._validate_outputs_tensor_info_fail(builder, tensor_empty) + + def testSignatureDefValidationSucceedsWithName(self): + tensor_with_name = meta_graph_pb2.TensorInfo() + tensor_with_name.name = "foo" + tensor_with_name.dtype = types_pb2.DT_FLOAT + + export_dir = self._get_export_dir("test_signature_def_validation_name_1") + builder = saved_model_builder.SavedModelBuilder(export_dir) + self._validate_inputs_tensor_info_accept(builder, tensor_with_name) + + export_dir = self._get_export_dir("test_signature_def_validation_name_2") + builder = saved_model_builder.SavedModelBuilder(export_dir) + self._validate_outputs_tensor_info_accept(builder, tensor_with_name) + + def testSignatureDefValidationSucceedsWithCoo(self): + tensor_with_coo = meta_graph_pb2.TensorInfo() + # TODO(soergel) test validation of each of the fields of coo_sparse + tensor_with_coo.coo_sparse.values_tensor_name = "foo" + tensor_with_coo.dtype = types_pb2.DT_FLOAT + + export_dir = self._get_export_dir("test_signature_def_validation_coo_1") + builder = saved_model_builder.SavedModelBuilder(export_dir) + self._validate_inputs_tensor_info_accept(builder, tensor_with_coo) + + export_dir = self._get_export_dir("test_signature_def_validation_coo_2") + builder = saved_model_builder.SavedModelBuilder(export_dir) + self._validate_outputs_tensor_info_accept(builder, tensor_with_coo) def testAssets(self): export_dir = self._get_export_dir("test_assets") diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index b80ad79074e85bdeae70148b2822c319c29468bc..97f2ddfdfc49e415bdcff428d6bd3f5b61cc3f20 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -48,10 +48,12 @@ from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.core.util.event_pb2 import TaggedRunMetadata # pylint: enable=unused-import + from tensorflow.python.eager import context as _context from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops as _ops from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops +from tensorflow.python.ops import gen_summary_ops as _gen_summary_ops # pylint: disable=unused-import from tensorflow.python.ops import summary_op_util as _summary_op_util # exports tensor-related summaries @@ -98,8 +100,7 @@ def scalar(name, tensor, collections=None, family=None): """ with _summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - # pylint: disable=protected-access - val = _gen_logging_ops._scalar_summary(tags=tag, values=tensor, name=scope) + val = _gen_logging_ops.scalar_summary(tags=tag, values=tensor, name=scope) _summary_op_util.collect(val, collections, [_ops.GraphKeys.SUMMARIES]) return val @@ -152,8 +153,7 @@ def image(name, tensor, max_outputs=3, collections=None, family=None): """ with _summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - # pylint: disable=protected-access - val = _gen_logging_ops._image_summary( + val = _gen_logging_ops.image_summary( tag=tag, tensor=tensor, max_images=max_outputs, name=scope) _summary_op_util.collect(val, collections, [_ops.GraphKeys.SUMMARIES]) return val @@ -192,8 +192,7 @@ def histogram(name, values, collections=None, family=None): with _summary_op_util.summary_scope( name, family, values=[values], default_name='HistogramSummary') as (tag, scope): - # pylint: disable=protected-access - val = _gen_logging_ops._histogram_summary( + val = _gen_logging_ops.histogram_summary( tag=tag, values=values, name=scope) _summary_op_util.collect(val, collections, [_ops.GraphKeys.SUMMARIES]) return val @@ -237,10 +236,9 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None, """ with _summary_op_util.summary_scope( name, family=family, values=[tensor]) as (tag, scope): - # pylint: disable=protected-access sample_rate = _ops.convert_to_tensor( sample_rate, dtype=_dtypes.float32, name='sample_rate') - val = _gen_logging_ops._audio_summary_v2( + val = _gen_logging_ops.audio_summary_v2( tag=tag, tensor=tensor, max_outputs=max_outputs, sample_rate=sample_rate, name=scope) _summary_op_util.collect(val, collections, [_ops.GraphKeys.SUMMARIES]) @@ -280,14 +278,13 @@ def merge(inputs, collections=None, name=None): @end_compatbility """ # pylint: enable=line-too-long - if _context.in_eager_mode(): + if _context.executing_eagerly(): raise RuntimeError( 'Merging tf.summary.* ops is not compatible with eager execution. ' 'Use tf.contrib.summary instead.') name = _summary_op_util.clean_tag(name) with _ops.name_scope(name, 'Merge', inputs): - # pylint: disable=protected-access - val = _gen_logging_ops._merge_summary(inputs=inputs, name=name) + val = _gen_logging_ops.merge_summary(inputs=inputs, name=name) _summary_op_util.collect(val, collections, []) return val @@ -314,7 +311,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): summaries under eager execution, use `tf.contrib.summary` instead. @end_compatbility """ - if _context.in_eager_mode(): + if _context.executing_eagerly(): raise RuntimeError( 'Merging tf.summary.* ops is not compatible with eager execution. ' 'Use tf.contrib.summary instead.') diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 1f3f2287043c021d636113b5a8807c9f4adf77aa..57f78c156b1334a5486b29f2ddec957e49156e73 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -343,7 +343,7 @@ class FileWriter(SummaryToEventTransformer): summaries under eager execution, use `tf.contrib.summary` instead. @end_compatbility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "tf.summary.FileWriter is not compatible with eager execution. " "Use tf.contrib.summary instead.") diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 63f16c53a29fd65c32077dd29e3b1823c11d457b..1de1adcfbc35e2b760f362cb9784dd415b9a4dc4 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -14,6 +14,7 @@ py_library( name = "tools_pip", deps = [ ":freeze_graph", + ":import_pb_to_tensorboard", ":inspect_checkpoint", ":optimize_for_inference", ":print_selective_registration_header", @@ -248,7 +249,10 @@ py_test( "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], srcs_version = "PY2AND3", - tags = ["manual"], + tags = [ + "manual", + "no-internal-py3", + ], deps = [ ":saved_model_cli", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index a52f325ddbcd90ad011c1c056965912b96f27aaa..e9f1def48c462dcd8a5acf0e3d29d562cd1b3d58 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -56,8 +56,6 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as saver_lib -FLAGS = None - def freeze_graph_with_def_protos(input_graph_def, input_saver_def, @@ -256,25 +254,24 @@ def freeze_graph(input_graph, checkpoint_version=checkpoint_version) -def main(unused_args): - if FLAGS.checkpoint_version == 1: +def main(unused_args, flags): + if flags.checkpoint_version == 1: checkpoint_version = saver_pb2.SaverDef.V1 - elif FLAGS.checkpoint_version == 2: + elif flags.checkpoint_version == 2: checkpoint_version = saver_pb2.SaverDef.V2 else: print("Invalid checkpoint version (must be '1' or '2'): %d" % - FLAGS.checkpoint_version) + flags.checkpoint_version) return -1 - freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, - FLAGS.input_checkpoint, FLAGS.output_node_names, - FLAGS.restore_op_name, FLAGS.filename_tensor_name, - FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, - FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist, - FLAGS.input_meta_graph, FLAGS.input_saved_model_dir, - FLAGS.saved_model_tags, checkpoint_version) - + freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, + flags.input_checkpoint, flags.output_node_names, + flags.restore_op_name, flags.filename_tensor_name, + flags.output_graph, flags.clear_devices, flags.initializer_nodes, + flags.variable_names_whitelist, flags.variable_names_blacklist, + flags.input_meta_graph, flags.input_saved_model_dir, + flags.saved_model_tags, checkpoint_version) -if __name__ == "__main__": +def run_main(): parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( @@ -376,5 +373,10 @@ if __name__ == "__main__": separated by \',\'. For tag-set contains multiple tags, all tags \ must be passed in.\ """) - FLAGS, unparsed = parser.parse_known_args() - app.run(main=main, argv=[sys.argv[0]] + unparsed) + flags, unparsed = parser.parse_known_args() + + my_main = lambda unused_args: main(unused_args, flags) + app.run(main=my_main, argv=[sys.argv[0]] + unparsed) + +if __name__ == '__main__': + run_main() diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index dd876cbe7fcd64a8de70eb28f67996df9de1dd7d..6504fbc10755c5c543016b8d56d6d53f3311b249 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -30,7 +30,7 @@ FLAGS = None def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, - all_tensor_names): + all_tensor_names=False): """Prints tensors in a checkpoint file. If no `tensor_name` is provided, prints the tensor names and shapes @@ -139,7 +139,7 @@ if __name__ == "__main__": const=True, type="bool", default=False, - help="If True, print the values of all the tensors.") + help="If True, print the names and values of all the tensors.") parser.add_argument( "--all_tensor_names", nargs="?", diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py index 36978b0860a423569149cd0572629f9f1f280637..4b3d98242caf683693430f08bd8cb74483f4bc74 100644 --- a/tensorflow/python/tools/print_selective_registration_header_test.py +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -24,6 +24,7 @@ import sys from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import test_util from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.tools import selective_registration_header_lib @@ -93,11 +94,16 @@ class PrintOpFilegroupTest(test.TestCase): ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( 'rawproto', self.WriteGraphFiles(graphs), default_ops) + matmul_prefix = '' + if test_util.IsMklEnabled(): + matmul_prefix = 'Mkl' + self.assertListEqual( [ ('BiasAdd', 'BiasOp'), # - ('MatMul', 'MatMulOp'), # - ('MatMul', 'MatMulOp'), # + ('MatMul', + matmul_prefix + 'MatMulOp'), # + ('MatMul', matmul_prefix + 'MatMulOp'), # ('NoOp', 'NoOp'), # ('Reshape', 'ReshapeOp'), # ('_Recv', 'RecvOp'), # @@ -112,8 +118,9 @@ class PrintOpFilegroupTest(test.TestCase): self.assertListEqual( [ ('BiasAdd', 'BiasOp'), # - ('MatMul', 'MatMulOp'), # - ('MatMul', 'MatMulOp'), # + ('MatMul', + matmul_prefix + 'MatMulOp'), # + ('MatMul', matmul_prefix + 'MatMulOp'), # ('NoOp', 'NoOp'), # ('Reshape', 'ReshapeOp'), # ('_Recv', 'RecvOp'), # diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 33f6debbcbecb652774c776be54323bbaa824822..b88be4ae04d5dc7a7641fb8dbd7e56e61035869f 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -38,11 +38,15 @@ from tensorflow.core.example import example_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.debug.wrappers import local_cli_wrapper +from tensorflow.python.framework import meta_graph as meta_graph_lib from tensorflow.python.framework import ops as ops_lib from tensorflow.python.platform import app # pylint: disable=unused-import from tensorflow.python.saved_model import loader from tensorflow.python.tools import saved_model_utils +# Set of ops to blacklist. +_OP_BLACKLIST = set(['WriteFile', 'ReadFile']) + def _show_tag_sets(saved_model_dir): """Prints the tag-sets stored in SavedModel directory. @@ -115,7 +119,7 @@ def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def, signature_def_key).outputs -def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key): +def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0): """Prints input and output TensorInfos. Prints the details of input and output TensorInfos for the SignatureDef mapped @@ -126,6 +130,7 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key): tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated by ','. For tag-set contains multiple tags, all tags must be passed in. signature_def_key: A SignatureDef key string. + indent: How far (in increments of 2 spaces) to indent each line of output. """ meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set) @@ -134,29 +139,39 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key): outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def( meta_graph_def, signature_def_key) - print('The given SavedModel SignatureDef contains the following input(s):') + indent_str = " " * indent + def in_print(s): + print(indent_str + s) + + in_print('The given SavedModel SignatureDef contains the following input(s):') for input_key, input_tensor in sorted(inputs_tensor_info.items()): - print('inputs[\'%s\'] tensor_info:' % input_key) - _print_tensor_info(input_tensor) + in_print(' inputs[\'%s\'] tensor_info:' % input_key) + _print_tensor_info(input_tensor, indent+1) - print('The given SavedModel SignatureDef contains the following output(s):') + in_print('The given SavedModel SignatureDef contains the following ' + 'output(s):') for output_key, output_tensor in sorted(outputs_tensor_info.items()): - print('outputs[\'%s\'] tensor_info:' % output_key) - _print_tensor_info(output_tensor) + in_print(' outputs[\'%s\'] tensor_info:' % output_key) + _print_tensor_info(output_tensor, indent+1) - print('Method name is: %s' % - meta_graph_def.signature_def[signature_def_key].method_name) + in_print('Method name is: %s' % + meta_graph_def.signature_def[signature_def_key].method_name) -def _print_tensor_info(tensor_info): +def _print_tensor_info(tensor_info, indent=0): """Prints details of the given tensor_info. Args: tensor_info: TensorInfo object to be printed. + indent: How far (in increments of 2 spaces) to indent each line output """ - print(' dtype: ' + - {value: key - for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype]) + indent_str = " " * indent + def in_print(s): + print(indent_str + s) + + in_print(' dtype: ' + + {value: key + for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype]) # Display shape as tuple. if tensor_info.tensor_shape.unknown_rank: shape = 'unknown_rank' @@ -164,8 +179,8 @@ def _print_tensor_info(tensor_info): dims = [str(dim.size) for dim in tensor_info.tensor_shape.dim] shape = ', '.join(dims) shape = '(' + shape + ')' - print(' shape: ' + shape) - print(' name: ' + tensor_info.name) + in_print(' shape: ' + shape) + in_print(' name: ' + tensor_info.name) def _show_all(saved_model_dir): @@ -186,7 +201,8 @@ def _show_all(saved_model_dir): signature_def_map = get_signature_def_map(saved_model_dir, tag_set) for signature_def_key in sorted(signature_def_map.keys()): print('\nsignature_def[\'' + signature_def_key + '\']:') - _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key) + _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, + indent=1) def get_meta_graph_def(saved_model_dir, tag_set): @@ -230,6 +246,27 @@ def get_signature_def_map(saved_model_dir, tag_set): return meta_graph.signature_def +def scan_meta_graph_def(meta_graph_def): + """Scans meta_graph_def and reports if there are ops on blacklist. + + Print ops if they are on black list, or print success if no blacklisted ops + found. + + Args: + meta_graph_def: MetaGraphDef protocol buffer. + """ + all_ops_set = set( + meta_graph_lib.ops_used_by_graph_def(meta_graph_def.graph_def)) + blacklisted_ops = _OP_BLACKLIST & all_ops_set + if blacklisted_ops: + # TODO(yifeif): print more warnings + print('MetaGraph with tag set %s contains the following blacklisted ops:' % + meta_graph_def.meta_info_def.tags, blacklisted_ops) + else: + print('MetaGraph with tag set %s does not contain blacklisted ops.' % + meta_graph_def.meta_info_def.tags) + + def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key, input_tensor_key_feed_dict, outdir, overwrite_flag, tf_debug=False): @@ -597,6 +634,21 @@ def run(args): args.overwrite, tf_debug=args.tf_debug) +def scan(args): + """Function triggered by scan command. + + Args: + args: A namespace parsed from command line. + """ + if args.tag_set: + scan_meta_graph_def( + saved_model_utils.get_meta_graph_def(args.dir, args.tag_set)) + else: + saved_model = reader.read_saved_model(args.dir) + for meta_graph_def in saved_model.meta_graphs: + scan_meta_graph_def(meta_graph_def) + + def create_parser(): """Creates a parser that parse the command line arguments. @@ -614,19 +666,19 @@ def create_parser(): show_msg = ( 'Usage examples:\n' 'To show all tag-sets in a SavedModel:\n' - '$saved_model_cli show --dir /tmp/saved_model\n' + '$saved_model_cli show --dir /tmp/saved_model\n\n' 'To show all available SignatureDef keys in a ' 'MetaGraphDef specified by its tag-set:\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n\n' 'For a MetaGraphDef with multiple tags in the tag-set, all tags must be ' 'passed in, separated by \';\':\n' '$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n' 'To show all inputs and outputs TensorInfo for a specific' ' SignatureDef specified by the SignatureDef key in a' ' MetaGraph.\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve ' - '--signature_def serving_default\n\n' - 'To show all available information in the SavedModel\n:' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve' + ' --signature_def serving_default\n\n' + 'To show all available information in the SavedModel:\n' '$saved_model_cli show --dir /tmp/saved_model --all') parser_show = subparsers.add_parser( 'show', @@ -658,12 +710,14 @@ def create_parser(): run_msg = ('Usage example:\n' 'To run input tensors from files through a MetaGraphDef and save' ' the output tensors to files:\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve ' - '--signature_def serving_default ' - '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy ' - '--input_exprs \'input3_key=np.ones(2)\' --input_examples ' - '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' ' - '--outdir=/out\n\n' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve \\\n' + ' --signature_def serving_default \\\n' + ' --inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy ' + '\\\n' + ' --input_exprs \'input3_key=np.ones(2)\' \\\n' + ' --input_examples ' + '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' \\\n' + ' --outdir=/out\n\n' 'For more information about input file format, please see:\n' 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') parser_run = subparsers.add_parser( @@ -716,6 +770,26 @@ def create_parser(): 'SavedModel.') parser_run.set_defaults(func=run) + # scan command + scan_msg = ('Usage example:\n' + 'To scan for blacklisted ops in SavedModel:\n' + '$saved_model_cli scan --dir /tmp/saved_model\n' + 'To scan a specific MetaGraph, pass in --tag_set\n') + parser_scan = subparsers.add_parser( + 'scan', + description=scan_msg, + formatter_class=argparse.RawTextHelpFormatter) + parser_scan.add_argument( + '--dir', + type=str, + required=True, + help='directory containing the SavedModel to execute') + parser_scan.add_argument( + '--tag_set', + type=str, + help='tag-set of graph in SavedModel to scan, separated by \',\'') + parser_scan.set_defaults(func=scan) + return parser diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index d6cbc49ba1e08a6b808b228fb8d69fc14f36e3d2..eedc893a38d3d0857dd49c7ce03f3921da48fdbd 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -61,83 +61,84 @@ class SavedModelCLITestCase(test.TestCase): exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: signature_def['classify_x2_to_y3']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x2:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['scores'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y3:0 -Method name is: tensorflow/serving/classify + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 + Method name is: tensorflow/serving/classify signature_def['classify_x_to_y']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_STRING - shape: unknown_rank - name: tf_example:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['scores'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/classify + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_STRING + shape: unknown_rank + name: tf_example:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/classify signature_def['regress_x2_to_y3']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x2:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['outputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y3:0 -Method name is: tensorflow/serving/regress + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['outputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 + Method name is: tensorflow/serving/regress signature_def['regress_x_to_y']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_STRING - shape: unknown_rank - name: tf_example:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['outputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/regress + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_STRING + shape: unknown_rank + name: tf_example:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['outputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/regress signature_def['regress_x_to_y2']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_STRING - shape: unknown_rank - name: tf_example:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['outputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y2:0 -Method name is: tensorflow/serving/regress + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_STRING + shape: unknown_rank + name: tf_example:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['outputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y2:0 + Method name is: tensorflow/serving/regress signature_def['serving_default']: -The given SavedModel SignatureDef contains the following input(s): -inputs['x'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['y'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/predict""" + The given SavedModel SignatureDef contains the following input(s): + inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/predict""" # pylint: enable=line-too-long + self.maxDiff = None # Produce a useful error msg if the comparison fails self.assertMultiLineEqual(output, exp_out) self.assertEqual(err.getvalue().strip(), '') @@ -193,11 +194,11 @@ Method name is: tensorflow/serving/predict""" output = out.getvalue().strip() expected_output = ( 'The given SavedModel SignatureDef contains the following input(s):\n' - 'inputs[\'x\'] tensor_info:\n' - ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: x:0\n' + ' inputs[\'x\'] tensor_info:\n' + ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: x:0\n' 'The given SavedModel SignatureDef contains the following output(s):\n' - 'outputs[\'y\'] tensor_info:\n' - ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: y:0\n' + ' outputs[\'y\'] tensor_info:\n' + ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: y:0\n' 'Method name is: tensorflow/serving/predict') self.assertEqual(output, expected_output) self.assertEqual(err.getvalue().strip(), '') @@ -524,6 +525,28 @@ Method name is: tensorflow/serving/predict""" y_expected = np.array([[2.5], [3.0]]) self.assertAllClose(y_expected, y_actual) + def testScanCommand(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + args = self.parser.parse_args(['scan', '--dir', base_path]) + with captured_output() as (out, _): + saved_model_cli.scan(args) + output = out.getvalue().strip() + self.assertTrue('does not contain blacklisted ops' in output) + + def testScanCommandFoundBlacklistedOp(self): + self.parser = saved_model_cli.create_parser() + base_path = test.test_src_dir_path(SAVED_MODEL_PATH) + args = self.parser.parse_args( + ['scan', '--dir', base_path, '--tag_set', 'serve']) + op_blacklist = saved_model_cli._OP_BLACKLIST + saved_model_cli._OP_BLACKLIST = set(['VariableV2']) + with captured_output() as (out, _): + saved_model_cli.scan(args) + saved_model_cli._OP_BLACKLIST = op_blacklist + output = out.getvalue().strip() + self.assertTrue('\'VariableV2\'' in output) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index c92f6fc3015960a2b821651231bb94713e0d53dd..006e360389b404a8edd97c9a8bf4b8876c828004 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -106,10 +106,10 @@ class AdamOptimizer(optimizer.Optimizer): self._updated_lr = None def _get_beta_accumulators(self): - if context.in_graph_mode(): - graph = ops.get_default_graph() - else: + if context.executing_eagerly(): graph = None + else: + graph = ops.get_default_graph() return (self._get_non_slot_variable("beta1_power", graph=graph), self._get_non_slot_variable("beta2_power", graph=graph)) diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index a521f1299e035424d1c3897a469655db732b0dcd..9be8b6aafefa33977511cde24dd2e87dd6c3b81a 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -184,7 +184,7 @@ class AdamOptimizerTest(test.TestCase): # Shouldn't return non-slot variables from other graphs. self.assertEqual(0, len(opt.variables())) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -194,7 +194,7 @@ class AdamOptimizerTest(test.TestCase): # Run 3 steps of Adam for t in range(1, 4): - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(update) elif t > 1: opt.apply_gradients(zip([grads0, grads1], [var0, var1])) @@ -319,6 +319,15 @@ class AdamOptimizerTest(test.TestCase): # fails. optimizer.apply_gradients([(grads0, var0)]) + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam.AdamOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertEqual(6, len(set(opt.variables()))) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py index 7f92d94d2be369709608d36c109863b0ebfb7bbe..a6e9662b7305a00f1fcf03245685e93b756942d3 100644 --- a/tensorflow/python/training/checkpoint_ops.py +++ b/tensorflow/python/training/checkpoint_ops.py @@ -149,7 +149,7 @@ def _load_and_remap_matrix(ckpt_path, num_rows_present = num_rows_to_load if remap_rows: row_remapping, num_rows_present = ( - gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access + gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=new_row_vocab_file, old_vocab_file=old_row_vocab_file, new_vocab_offset=new_row_vocab_offset, @@ -168,7 +168,7 @@ def _load_and_remap_matrix(ckpt_path, num_cols_present = new_col_vocab_size if remap_cols: col_remapping, num_cols_present = ( - gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access + gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=new_col_vocab_file, old_vocab_file=old_col_vocab_file, new_vocab_offset=0, # Offset is unused for cols (no partitioning). @@ -178,7 +178,7 @@ def _load_and_remap_matrix(ckpt_path, num_rows_to_load * new_col_vocab_size - num_rows_present * num_cols_present, 1 ]) - return_tensor = gen_checkpoint_ops._load_and_remap_matrix( # pylint: disable=protected-access + return_tensor = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=ckpt_path, old_tensor_name=old_tensor_name, row_remapping=row_remapping, diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 0af1cdecfa280f757b253abb2f3408bc5c9416f1..e7f88de1d2290a49f3b7bdf47417016d7e7c9cea 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -23,6 +23,7 @@ import six from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import ops from tensorflow.python.ops import io_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables @@ -289,10 +290,18 @@ def _set_checkpoint_initializer(variable, name: Name of the operation. """ base_type = variable.dtype.base_dtype - with ops.colocate_with(variable): + # Do not colocate with variable since RestoreV2 op only runs on CPU and + # colocation will force variable (and other ops that colocate with variable) + # to be on CPU as well. It is okay to place the variable's initializer op on + # CPU since it will only be run once at the start. + with ops.device(variable.device), ops.device("/cpu:0"): restore_op = io_ops.restore_v2( ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0] - variable._initializer_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access + if isinstance(variable, resource_variable_ops.ResourceVariable): + init_op = variable.assign(restore_op, read_value=False) + else: + init_op = state_ops.assign(variable, restore_op) + variable._initializer_op = init_op # pylint:disable=protected-access restore_op.set_shape(variable.shape) variable._initial_value = restore_op # pylint:disable=protected-access diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py index a461b24cbb1acafe60937f64d1cc0d35eb1bfc55..4e08a1c859fbaac75e7cd09ad498d9fea14c6338 100644 --- a/tensorflow/python/training/checkpoint_utils_test.py +++ b/tensorflow/python/training/checkpoint_utils_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -157,23 +158,23 @@ class CheckpointsTest(test.TestCase): "some_scope", initializer=init_ops.zeros_initializer()): my1 = variable_scope.get_variable("my1", [1, 10]) - # At this point, my1.initialized_value() will add ops that reference - # the zeros initializer of my1. - before = variables.Variable(my1.initialized_value(), name="before") + before = my1.initialized_value() checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1}) - # At this point, my1.initialized_value() will add ops that reference - # the newly set initializer of my1. - after = variables.Variable(my1.initialized_value(), name="after") + after = my1.initialized_value() + + self.assertAllEqual(session.run(before), [[0.0] * 10]) + self.assertAllEqual(session.run(after), v1) session.run(variables.global_variables_initializer()) + self.assertAllEqual(session.run(my1), v1) self.assertAllEqual(session.run(my1.initialized_value()), v1) - self.assertAllClose(session.run(before), [[0.0] * 10]) + self.assertAllClose(session.run(before), v1) self.assertAllClose(session.run(after), v1) with self.assertRaises(AssertionError): - self.assertAllClose(session.run(before), session.run(after)) + self.assertAllClose(v1, [[0.0] * 10]) def testInitWithScopeDoesNotCaptureSuffixes(self): checkpoint_dir = self.get_temp_dir() @@ -206,7 +207,9 @@ class CheckpointsTest(test.TestCase): checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"useful_scope/": "useful_scope/"}) - self.assertEqual(my4._initializer_op.op.inputs[1].device, "/job:ps") + # initializer runs on the same task but always on CPU. + self.assertEqual(my4._initializer_op.op.inputs[1].device, + "/job:ps/device:CPU:0") def testInitFromRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() @@ -362,6 +365,31 @@ class CheckpointsTest(test.TestCase): checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"useful_scope": "some_scope/"}) + def testNoAdditionalReadOpsForResourceVariables(self): + checkpoint_dir = self.get_temp_dir() + with self.test_session() as session: + v1, _, _, _ = _create_checkpoints(session, checkpoint_dir) + + # New graph and session. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as session: + my1 = resource_variable_ops.ResourceVariable([[0.0] * 10], name="my1") + + with ops.name_scope("init_from_checkpoint"): + checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1}) + + # Basic sanity checks: + session.run(variables.global_variables_initializer()) + self.assertAllEqual(session.run(my1), v1) + + ops_in_init_from_checkpoint_scope = [ + op for op in g.get_operations() + if (op.name.startswith("init_from_checkpoint/") and + not op.name.startswith("init_from_checkpoint/checkpoint_initializer" + ) and op.type != "AssignVariableOp") + ] + self.assertEqual(ops_in_init_from_checkpoint_scope, []) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py index 11caa761aec5d631d87a91ec876e0b5032ffdc5b..d0650eb127640a5cfb28f9c238343791bfa1746c 100644 --- a/tensorflow/python/training/checkpointable.py +++ b/tensorflow/python/training/checkpointable.py @@ -22,6 +22,7 @@ import collections from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.util import nest @@ -31,8 +32,8 @@ from tensorflow.python.util import nest # creation (avoiding double assignment when executing eagerly). VARIABLE_VALUE_KEY = "VARIABLE_VALUE" -_CheckpointableReference = collections.namedtuple( - "_CheckpointableReference", +CheckpointableReference = collections.namedtuple( + "CheckpointableReference", [ # The local name for this dependency. "name", @@ -181,13 +182,16 @@ class _CheckpointPosition(object): dtype = self._checkpoint.dtype_map[checkpoint_key] base_type = dtype.base_dtype with ops.init_scope(): - value, = io_ops.restore_v2( - prefix=self._checkpoint.save_path, - tensor_names=[checkpoint_key], - shape_and_slices=[""], - dtypes=[base_type], - name="%s_checkpoint_read" % (serialized_tensor.name,)) - value_tensors[serialized_tensor.name] = value + with ops.device("/cpu:0"): + # Run the restore itself on the CPU. + value, = io_ops.restore_v2( + prefix=self._checkpoint.save_path, + tensor_names=[checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="%s_checkpoint_read" % (serialized_tensor.name,)) + # Copy the value to the current device if necessary. + value_tensors[serialized_tensor.name] = array_ops.identity(value) return value_tensors def restore_ops(self): @@ -204,10 +208,10 @@ class _CheckpointPosition(object): # Name saveables based on the name this object had when it was checkpointed. named_saveables = {} restore_ops = [] - in_graph_mode = context.in_graph_mode() + building_graph = not context.executing_eagerly() for serialized_tensor in self.object_proto.attributes: - saveable_object = saveables.get(serialized_tensor.name, None) - if saveable_object is None: + saveable_factory = saveables.get(serialized_tensor.name, None) + if saveable_factory is None: # Purposefully does not throw an exception if attributes have been added # or deleted. Stores unused attributes so an exception can be raised if # the user decides to check that everything in the checkpoint was @@ -215,13 +219,17 @@ class _CheckpointPosition(object): self._checkpoint.unused_attributes.setdefault( self.checkpointable, []).append(serialized_tensor.name) continue - if in_graph_mode: + if building_graph: existing_ops = self._checkpoint.restore_ops_by_name.get( serialized_tensor.name, None) else: existing_ops = None if existing_ops is None: - named_saveables[serialized_tensor.checkpoint_key] = saveable_object + if callable(saveable_factory): + saveable = saveable_factory(name=serialized_tensor.checkpoint_key) + else: + saveable = saveable_factory + named_saveables[serialized_tensor.checkpoint_key] = saveable if named_saveables: validated_saveables = ( self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access @@ -241,7 +249,7 @@ class _CheckpointPosition(object): saveable_index:saveable_index + num_specs] saveable_index += num_specs restore_op = saveable.restore(saveable_tensors, restored_shapes=None) - if in_graph_mode: + if building_graph: assert saveable.name not in self._checkpoint.restore_ops_by_name self._checkpoint.restore_ops_by_name[saveable.name] = restore_op restore_ops.append(restore_op) @@ -301,17 +309,22 @@ class CheckpointableBase(object): Not __init__, since most objects will forget to call it. """ - if hasattr(self, "_checkpoint_dependencies"): + if hasattr(self, "_unconditional_checkpoint_dependencies"): # __init__ already called. This check means that we don't need # Checkpointable.__init__() in the constructor of every TensorFlow object. return - # A list of _CheckpointableReference objects. - self._checkpoint_dependencies = [] + # A list of CheckpointableReference objects. Some classes implementing + # `Checkpointable`, notably `Optimizer`s, may override the + # _checkpoint_dependencies property with conditional dependencies + # (e.g. based on the current graph when saving). + self._unconditional_checkpoint_dependencies = [] # Maps names -> Checkpointable objects - self._dependency_names = {} + self._unconditional_dependency_names = {} # Restorations for other Checkpointable objects on which this object may - # eventually depend. - self._deferred_dependencies = {} # local name -> _CheckpointPosition list + # eventually depend. Maps local name -> _CheckpointPosition list. Optimizers + # tack on conditional dependencies, and so need separate management of + # deferred dependencies too. + self._unconditional_deferred_dependencies = {} # The UID of the highest assignment to this object. Used to ensure that the # last requested assignment determines the final value of an object. if hasattr(self, "_update_uid"): @@ -320,9 +333,51 @@ class CheckpointableBase(object): "initialization code was run.") self._update_uid = -1 + @property + def _checkpoint_dependencies(self): + """All dependencies of this object. + + May be overridden to include conditional dependencies. + + Returns: + A list of `CheckpointableReference` objects indicating named + `Checkpointable` dependencies which should be saved along with this + object. + """ + return self._unconditional_checkpoint_dependencies + + @property + def _deferred_dependencies(self): + """A dictionary with deferred dependencies. + + Stores restorations for other Checkpointable objects on which this object + may eventually depend. May be overridden by sub-classes (e.g. Optimizers use + conditional dependencies based the current graph, and so need separate + management of deferred dependencies too). + + Returns: + A dictionary mapping from local name to a list of _CheckpointPosition + objects. + """ + return self._unconditional_deferred_dependencies + + def _lookup_dependency(self, name): + """Look up a dependency by name. + + May be overridden to include conditional dependencies. + + Args: + name: The local name of the dependency. + Returns: + A `Checkpointable` object, or `None` if no dependency by this name was + found. + """ + return self._unconditional_dependency_names.get(name, None) + def _add_variable_with_custom_getter( self, name, shape=None, dtype=dtypes.float32, - initializer=None, getter=None, **kwargs_for_getter): + initializer=None, getter=None, overwrite=False, + **kwargs_for_getter): """Restore-on-create for a variable be saved with this `Checkpointable`. If the user has requested that this object or another `Checkpointable` which @@ -334,12 +389,11 @@ class CheckpointableBase(object): name: A name for the variable. Must be unique within this object. shape: The shape of the variable. dtype: The data type of the variable. - initializer: The initializer to use. Ignored if there is a deferred restoration left over from a call to `_restore_from_checkpoint_position`. - getter: The getter to wrap which actually fetches the variable. + overwrite: If True, disables unique name and type checks. **kwargs_for_getter: Passed to the getter. Returns: @@ -349,13 +403,13 @@ class CheckpointableBase(object): ValueError: If the variable name is not unique. """ self._maybe_initialize_checkpointable() - if name in self._dependency_names: + if not overwrite and self._lookup_dependency(name) is not None: raise ValueError( ("A variable named '%s' already exists in this Checkpointable, but " "Checkpointable._add_variable called to create another with " "that name. Variable names must be unique within a Checkpointable " "object.") % (name,)) - if context.in_eager_mode(): + if context.executing_eagerly(): # If this is a variable with a single Tensor stored in the checkpoint, we # can set that value as an initializer rather than initializing and then # assigning (when executing eagerly). This call returns None if there is @@ -385,7 +439,13 @@ class CheckpointableBase(object): # assign again. It will add this variable to our dependencies, and if there # is a non-trivial restoration queued, it will handle that. This also # handles slot variables. - return self._track_checkpointable(new_variable, name=name) + if not overwrite or isinstance(new_variable, CheckpointableBase): + return self._track_checkpointable(new_variable, name=name, + overwrite=overwrite) + else: + # TODO(allenl): Some variable types are not yet supported. Remove this + # fallback once all get_variable() return types are Checkpointable. + return new_variable def _preload_simple_restoration(self, name, shape): """Return a dependency's value for restore-on-create. @@ -455,9 +515,10 @@ class CheckpointableBase(object): raise TypeError( ("Checkpointable._track_checkpointable() passed type %s, not a " "Checkpointable.") % (type(checkpointable),)) - new_reference = _CheckpointableReference(name=name, ref=checkpointable) - if (name in self._dependency_names - and self._dependency_names[name] is not checkpointable): + new_reference = CheckpointableReference(name=name, ref=checkpointable) + current_object = self._lookup_dependency(name) + if (current_object is not None + and current_object is not checkpointable): if not overwrite: raise ValueError( ("Called Checkpointable._track_checkpointable() with name='%s', " @@ -465,19 +526,47 @@ class CheckpointableBase(object): "dependency. Names must be unique (or overwrite=True).") % (name,)) # This is a weird thing to do, but we're not going to stop people from # using __setattr__. - for index, (old_name, _) in enumerate(self._checkpoint_dependencies): + for index, (old_name, _) in enumerate( + self._unconditional_checkpoint_dependencies): if name == old_name: - self._checkpoint_dependencies[index] = new_reference + self._unconditional_checkpoint_dependencies[index] = new_reference else: - self._checkpoint_dependencies.append(new_reference) + self._unconditional_checkpoint_dependencies.append(new_reference) - self._dependency_names[name] = checkpointable - deferred_dependency_list = self._deferred_dependencies.pop(name, None) - if deferred_dependency_list is not None: - for checkpoint_position in deferred_dependency_list: - checkpoint_position.restore(checkpointable=checkpointable) + self._unconditional_dependency_names[name] = checkpointable + self._handle_deferred_dependencies(name=name, checkpointable=checkpointable) return checkpointable + def _handle_deferred_dependencies(self, name, checkpointable): + """Pop and load any deferred checkpoint restores into `checkpointable`. + + This method does not add a new dependency on `checkpointable`, but it does + check if any outstanding/deferred dependencies have been queued waiting for + this dependency to be added (matched based on `name`). If so, + `checkpointable` and its dependencies are restored. The restorations are + considered fulfilled and so are deleted. + + `_track_checkpointable` is more appropriate for adding a + normal/unconditional dependency, and includes handling for deferred + restorations. This method allows objects such as `Optimizer` to use the same + restoration logic while managing conditional dependencies themselves, by + overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the + object's dependencies based on the context it is saved/restored in (a single + optimizer instance can have state associated with multiple graphs). + + Args: + name: The name of the dependency within this object (`self`), used to + match `checkpointable` with values saved in a checkpoint. + checkpointable: The Checkpointable object to restore (inheriting from + `CheckpointableBase`). + """ + deferred_dependencies_list = self._deferred_dependencies.pop(name, ()) + for checkpoint_position in sorted( + deferred_dependencies_list, + key=lambda restore: restore.checkpoint.restore_uid, + reverse=True): + checkpoint_position.restore(checkpointable) + def _restore_from_checkpoint_position(self, checkpoint_position): """Restore this object and its dependencies (may be deferred).""" # Attempt a breadth-first traversal, since presumably the user has more @@ -513,7 +602,7 @@ class CheckpointableBase(object): child_position = _CheckpointPosition( checkpoint=checkpoint, proto_id=child.node_id) - local_object = self._dependency_names.get(child.local_name, None) + local_object = self._lookup_dependency(child.local_name) if local_object is None: # We don't yet have a dependency registered with this name. Save it # in case we do. @@ -532,14 +621,30 @@ class CheckpointableBase(object): """Returns a dictionary of values to checkpoint with this object. Keys in the returned dictionary are local to this object and in a separate - namespace from dependencies. Values may either be `SaveableObject`s or - variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s + namespace from dependencies. Values may either be `SaveableObject` factories + or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s `var_list` constructor argument). + `SaveableObjects` have a name set, which Checkpointable needs to generate + itself. So rather than returning `SaveableObjects` directly, this method + should return a dictionary of callables which take `name` arguments and + return `SaveableObjects` with that name. + + If this object may also be passed to the global-name-based `tf.train.Saver`, + the returned callables should have a default value for their name argument + (i.e. be callable with no arguments). + Returned values must be saved only by this object; if any value may be shared, it should instead be a dependency. For example, variable objects save their own values with the key `VARIABLE_VALUE_KEY`, but objects which reference variables simply add a dependency. + + Returns: + The dictionary mapping attribute names to `SaveableObject` factories + described above. For example: + {VARIABLE_VALUE_KEY: + lambda name="global_name_for_this_object": + SaveableObject(name=name, ...)} """ return {} diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 689088bb41edfd94a1d483ed2b5f7447e9e060e7..d31c375b4ce48dcb9bc2918514707636a647c675 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -25,6 +25,15 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util.tf_export import tf_export +# This is a tuple of PS ops used by tf.estimator.Esitmator which should work in +# almost all of cases. +STANDARD_PS_OPS = ( + "Variable", "VariableV2", "AutoReloadVariable", "MutableHashTable", + "MutableHashTableV2", "MutableHashTableOfTensors", + "MutableHashTableOfTensorsV2", "MutableDenseHashTable", + "MutableDenseHashTableV2", "VarHandleOp" +) + class _RoundRobinStrategy(object): """Returns the next ps task index for placement in round-robin order. @@ -170,8 +179,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps", than overriding them. cluster: `ClusterDef` proto or `ClusterSpec`. ps_ops: List of strings representing `Operation` types that need to be - placed on `ps` devices. If `None`, defaults to - `["Variable", "VariableV2", "VarHandleOp"]`. + placed on `ps` devices. If `None`, defaults to `STANDARD_PS_OPS`. ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by `ps_ops`), that takes the `Operation` and returns the ps task index to use. If `None`, defaults to a round-robin strategy across all `ps` @@ -201,7 +209,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps", if ps_ops is None: # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be # placed in the parameter server. - ps_ops = ["Variable", "VariableV2", "VarHandleOp"] + ps_ops = list(STANDARD_PS_OPS) if not merge_devices: logging.warning( diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f1137e80ab4394333ef0f3b7982d5b55f4704d0d --- /dev/null +++ b/tensorflow/python/training/device_util.py @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Device-related support functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops + + +def canonicalize(d): + d = tf_device.DeviceSpec.from_string(d) + assert d.device_type is None or d.device_type == d.device_type.upper(), ( + "Device type '%s' must be all-caps." % (d.device_type,)) + # Fill in missing device fields using defaults. + result = tf_device.DeviceSpec( + job="localhost", replica=0, task=0, device_type="CPU", device_index=0) + result.merge_from(d) + return result.to_string() + + +class _FakeNodeDef(object): + """A fake NodeDef for _FakeOperation.""" + + def __init__(self): + self.op = "" + self.name = "" + + +class _FakeOperation(object): + """A fake Operation object to pass to device functions.""" + + def __init__(self): + self.device = "" + self.type = "" + self.name = "" + self.node_def = _FakeNodeDef() + + def _set_device(self, device): + self.device = ops._device_string(device) # pylint: disable=protected-access + + +def current(): + """Return a string (not canonicalized) for the current device.""" + # TODO(josh11b): Work out how this function interacts with ops.colocate_with. + ctx = context.context() + if ctx.executing_eagerly(): + d = ctx.device_name + else: + op = _FakeOperation() + ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access + d = op.device + return d diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py new file mode 100644 index 0000000000000000000000000000000000000000..9261e132302043e97f2adb696fbde2dd01c897ce --- /dev/null +++ b/tensorflow/python/training/distribute.py @@ -0,0 +1,1118 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class DistributionStrategy, TowerContext, and supporting APIs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses_impl +from tensorflow.python.training import device_util +from tensorflow.python.util import nest + + +# ------------------------------------------------------------------------------ +# Internal API for setting the current thread mode as being either in a +# tower or cross-tower context for a particular distribution strategy. + + +class _ThreadMode(object): + + def __init__(self, dist, cross, tower): + self.distribution_strategy = dist + self.cross_tower_context = cross + self.tower_context = tower + + +class _CrossTowerThreadMode(_ThreadMode): + + def __init__(self, distribution_strategy): + _ThreadMode.__init__( + self, distribution_strategy, distribution_strategy, None) + + +class _InTowerThreadMode(_ThreadMode): + + def __init__(self, tower_ctx): + _ThreadMode.__init__( + self, tower_ctx.distribution_strategy, None, tower_ctx) + + +_per_thread_mode = threading.local() + + +def _push_per_thread_mode(context): + if not hasattr(_per_thread_mode, "stack"): + _per_thread_mode.stack = [] + _per_thread_mode.stack.append(context) + + +def _pop_per_thread_mode(): + _per_thread_mode.stack.pop(-1) + + +class _DefaultTowerThreadMode(_ThreadMode): + """Type of default value returned by `_get_per_thread_mode()`. + + Used when the thread-local stack is empty. + """ + + def __init__(self): + # _default_distribution_strategy and _default_tower_context are + # defined at the bottom of this file. + _ThreadMode.__init__( + self, _default_distribution_strategy, None, _default_tower_context) + + +def _get_per_thread_mode(): + try: + return _per_thread_mode.stack[-1] + except (AttributeError, IndexError): + # _default_tower_mode is defined at the bottom of this file. + return _default_tower_mode + + +# ------------------------------------------------------------------------------ +# Context tracking whether in a distribution.update() or .update_non_slot() +# call. + + +_update_device = threading.local() + + +def get_update_device(): + try: + return _update_device.current + except AttributeError: + return None + + +class UpdateContext(object): + """Context manager when you are in `update()` or `update_non_slot()`.""" + + def __init__(self, device): + self._device = device + self._old_device = None + + def __enter__(self): + self._old_device = get_update_device() + _update_device.current = self._device + + def __exit__(self, exception_type, exception_value, traceback): + del exception_type, exception_value, traceback + _update_device.current = self._old_device + + +# ------------------------------------------------------------------------------ +# Public API for accessing the current thread mode + + +def get_tower_context(): + """Returns the current TowerContext or None. + + Note that execution: + 1. starts in the default (single-tower) tower context; + 2. switches to cross-tower context when entering a + `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context. + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context, in a tower context you should use the + `TowerContext` API instead. + + Returns: + The current `TowerContext` object when in a tower context scope, else None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().tower_context + + +def get_cross_tower_context(): + """Returns the current DistributionStrategy if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context; + 2. switches to cross-tower context when entering a + `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context. + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context. + + Returns: + Returns the current `DistributionStrategy` object in a cross-tower + context, or None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().cross_tower_context + + +def get_distribution_strategy(): + """Returns the current `DistributionStrategy` object. + + Returns: + A `DistributionStrategy` object. Inside a + `with distribution_strategy.scope()` block, it returns + `distribution_strategy`, otherwise it returns the default + (single-tower) `DistributionStrategy` object. + """ + return _get_per_thread_mode().distribution_strategy + + +def has_distribution_strategy(): + """Return if there is a current non-default `DistributionStrategy`. + + Returns: + True if inside a `with distribution_strategy.scope():`. + """ + return get_distribution_strategy() is not _default_distribution_strategy + + +# ------------------------------------------------------------------------------ +# Public utility functions. + + +def get_loss_reduction(): + """Reduce `method_string` corresponding to the last loss reduction.""" + loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access + if loss_reduction == losses_impl.Reduction.SUM: + return "sum" + return "mean" + + +# ------------------------------------------------------------------------------ +# Internal API for validating the current thread mode + + +def _require_cross_tower_context(distribution_strategy): + """Verify in cross-tower context for `distribution_strategy`.""" + context = _get_per_thread_mode() + if context.cross_tower_context is distribution_strategy: return + # We have an error to report, figure out the right message. + if context.distribution_strategy is not distribution_strategy: + if context.distribution_strategy is _default_distribution_strategy: + raise RuntimeError( + 'Need to be inside "with distribution_strategy.scope()" for %s' % + (distribution_strategy,)) + else: + raise RuntimeError( + "Mixing different DistributionStrategy objects: %s is not %s" % + (context.distribution_strategy, distribution_strategy)) + assert context.cross_tower_context is None + raise RuntimeError("Method requires being in cross-tower context, use " + "get_tower_context().merge_call()") + + +def require_tower_context(tower_ctx): + """Verify in `tower_ctx` tower context.""" + context = _get_per_thread_mode() + if context.tower_context is tower_ctx: return + # We have an error to report, figure out the right message. + if context.tower_context is None: + raise RuntimeError("Need to be inside `call_for_each_tower()`") + if context.distribution_strategy is tower_ctx.distribution_strategy: + # Two different TowerContexts with the same DistributionStrategy. + raise RuntimeError("Mismatching tower context.") + raise RuntimeError( + "Mismatching DistributionStrategy objects: %s is not %s." % + (context.distribution_strategy, tower_ctx.distribution_strategy)) + + +def _require_distribution_strategy_scope(distribution_strategy): + """Verify in a `distribution_strategy.scope()` in this thread.""" + context = _get_per_thread_mode() + if context.distribution_strategy is distribution_strategy: return + # We have an error to report, figure out the right message. + if context.distribution_strategy is _default_distribution_strategy: + raise RuntimeError( + 'Need to be inside "with distribution_strategy.scope()" for %s' % + (distribution_strategy,)) + else: + raise RuntimeError( + "Mixing different DistributionStrategy objects: %s is not %s" % + (context.distribution_strategy, distribution_strategy)) + + +# ------------------------------------------------------------------------------ +# Internal context managers used to implement the DistributionStrategy +# base class + + +class _CurrentDistributionContext(object): + """Context manager for setting the `DistributionStrategy` and var creator.""" + + def __init__(self, distribution_strategy, var_creator_scope): + self._context = _CrossTowerThreadMode(distribution_strategy) + self._var_creator_scope = var_creator_scope + + def __enter__(self): + _push_per_thread_mode(self._context) + self._var_creator_scope.__enter__() + return self._context.distribution_strategy + + def __exit__(self, exception_type, exception_value, traceback): + self._var_creator_scope.__exit__(exception_type, exception_value, traceback) + _pop_per_thread_mode() + + +class _SameScopeAgainContext(object): + """Trivial context manager when you are already in `scope()`.""" + + def __init__(self, distribution_strategy): + self._distribution_strategy = distribution_strategy + + def __enter__(self): + return self._distribution_strategy + + def __exit__(self, exception_type, exception_value, traceback): + del exception_type, exception_value, traceback + + +# ------------------------------------------------------------------------------ +# Base classes for all distribution strategies. + + +class DistributionStrategy(object): + """A list of devices with a state & compute distribution policy. + + The intent is that you can write an algorithm in a stylized way and + it will be usable with a variety of different `DistributionStrategy` + implementations. Each descendant will implement a different strategy + for distributing the algorithm across multiple devices/machines. + Furthermore, these changes can be hidden inside the specific layers + and other library classes that need special treatment to run in a + distributed setting, so that most users' model definition code can + run unchanged. The `DistributionStrategy` API works the same way + with eager and graph execution. + + First let's introduce a few high-level concepts: + + * _Data parallelism_ is where we run multiple copies of the model + on different slices of the input data. This is in contrast to + _model parallelism_ where we divide up a single copy of a model + across multiple devices. + Note: for now we only support data parallelism at this time, but + hope to add support for model parallelism in the future. + * A _tower_ is one copy of the model, running on one slice of the + input data. + * _Synchronous_, or more commonly _sync_, training is when the + updates from each tower are aggregated together before updating + the model variables. This is in contrast to _asynchronous_, or + _async_ training where each tower updates the model variables + independently. + * Furthermore you might run your computation on multiple devices + on one machine (or "host"), or on multiple machines/hosts. + If you are running on multiple machines, you might have a + single master host that drives computation across all of them, + or you might have multiple clients driving the computation + asynchronously. + + To distribute an algorithm, we might use some of these ingredients: + + * Parameter servers: These are hosts that hold a single copy of + parameters/variables. All towers that want to operate on a variable + retrieve it at the beginning of a step and send an update to be + applied at the end of the step. Can support either sync or async + training. + * Mirrored variables: These are variables that are copied to multiple + devices, where we keep the copies in sync by applying the same + updates to every copy. Normally would only be used with sync training. + * Reductions and Allreduce: A _reduction_ is some method of + aggregating multiple values into one value, like "sum" or + "mean". If doing sync training, we will perform a reduction on the + gradients to a parameter from each tower before applying the + update. Allreduce is an algorithm for performing a reduction on + values from multiple devices and making the result available on + all of those devices. + * TODO(josh11b): Future: partitioned variables + + We have then a few approaches we want to support: + * Code written (as if) with no knowledge of class `DistributionStrategy`. + This code should work as before, even if some of the layers, etc. + used by that code are written to be distribution-aware. This is done + by having a default `DistributionStrategy` that gives ordinary behavior, + and by default being in a single tower context. + * Ordinary model code that you want to run using a specific + `DistributionStrategy`. This can be as simple as: + + ``` + with my_distribution.scope(): + iterator = my_distribution.distribute_dataset(dataset) + # TODO(josh11b): iterator = dataset.make_one_shot_iterator() + tower_train_ops = my_distribution.call_for_each_tower( + tower_fn, iterator.get_next()) + train_op = tf.group(my_distribution.unwrap(tower_train_ops)) + ``` + + This takes an ordinary `dataset` and `tower_fn` and runs it + distributed using a particular `DistributionStrategy` in + `my_distribution`. Any variables created in `tower_fn` are created + using `my_distribution`'s policy, and library functions called by + `tower_fn` can use the `get_tower_context()` API to get enhanced + behavior in this case. + * If you want to write a distributed algorithm, you may use any of + the `DistributionStrategy` APIs inside a + `with my_distribution.scope():` block of code. + + Lower-level concepts: + + * Wrapped values: In order to represent values parallel across devices + (either towers or the devices associated with a particular value), we + wrap them in a "PerDevice" or "Mirrored" object that contains a map + from device to values. "PerDevice" is used when the value may be + different across devices, and "Mirrored" when the value are the same. + * Unwrapping and merging: Consider calling a function `fn` on + multiple devices, like `call_for_each_tower(fn, w)` with an + argument `w that is a wrapped value. This means `w` will have a + map taking tower device `d0` to `w0`, tower device `d1` to `w1`, + etc. `call_for_each_tower()` unwraps `w` before calling `fn`, so + it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges + the return values from `fn()`, which can possibly result in + wrapped values. For example, let's say `fn()` returns a tuple with + three components: (x, a, v0) from tower 0, (x, b, v1) on tower 1, + etc. If the first component is the same object `x` from every + tower, then the first component of the merged result will also be + `x`. If the second component is different (`a`, `b`, ...) from + each tower, then the merged value will have a wrapped map from + tower device to the different values. If the third component is + the members of a mirrored variable (`v` maps `d0` to `v0, `d1` to + `v1`, etc.), then the merged result will be that mirrored variable + (`v`). + * Tower context vs. Cross-tower context: _tower context_ is when we + are in some function that is being called once for each tower. + Otherwise we are in cross-tower context, which is useful for + calling `DistributionStrategy` methods which operate across the + towers (like `reduce()`). By default you start in a tower context + (the default "single tower context") and then some methods can + switch you back and forth, as described below. + * Worker devices vs. parameter devices: Most tower computations will + happen on worker devices. Since we don't yet support model + parallelism, there will be one worker device per tower. When using + parameter servers (see above), the set of devices holding + variables may be different, otherwise the parameter devices might + match the worker devices. + * Non-slot devices are some subset of the parameter devices where we + put all the non-slot variables. We need to ensure that all + non-slot variables are allocated on the same device, or mirrored + across the same set of devices. If you have some variable you want + to colocate all the non-slot variables with, you can use + `colocate_vars_with()` to get the remaining non-slot variables on + the same device. Otherwise you can use `non_slot_devices()` to + pick a consistent set of devices to pass to both + `colocate_vars_with()` and `update_non_slot()`. + + When using a `DistributionStrategy`, we have a new type dimension + called _locality_ that says what values are compatible with which + APIs: + + * T: different value for each tower (e.g. a PerDevice-wrapped value). + * M: value is "mirrored" across towers, i.e. there are copies with the + same value on each tower (e.g. a Mirrored-wrapped value). + * V(`v`): value is "mirrored" across all the devices which have a + copy of variable `v` (also a Mirrored-wrapped value, but over + parameter devices instead of worker devices). + * N: value is "mirrored" across all the "non-slot" devices + + Rules for methods with respect to locality and single-tower vs. + cross-tower context: + + * `with d.scope()`: default single-tower context -> cross-tower context for + `d` + * `with d.colocate_vars_with(v)`: in tower/cross-tower context, variables + will be created with locality V(`v`). That is, if we write + `with d.colocate_vars_with(v1): v2 = tf.get_variable(...)`, then + `v2` will have locality V(`v1`), i.e. locality V(`v2`) will equal + V(`v1`). + * `with d.colocate_vars_with(d.non_slot_devices(...))`: in + tower/cross-tower context, variables will be created with locality N + * `v = tf.get_variable(...)`: in tower/cross-tower context, creates + a variable (which by definition will have locality V(`v`), though + will match another locality if inside a `colocate_vars_with` + scope). + * `d.distribute_dataset(dataset)`: in cross-tower context, produces an + iterator with locality T + * `d.broadcast(t)`: in cross-tower context, produces a value with locality M + * `d.broadcast(t, v)`: in cross-tower context, produces a value with + locality V(`v`) + * `d.call_for_each_tower(fn, ...)`: in cross-tower context, runs + `fn()` in a tower context (and so may call `get_tower_context()` and + use its API, including `merge_call()` to get back to cross-tower + context), once for each tower. May use values with locality T or + M, and any variable. + * `d.reduce(m, t)`: in cross-tower context, accepts t with locality T + and produces a value with locality M. + * `d.reduce(m, t, v)`: in cross-tower context, accepts t with + locality T and produces a value with locality V(`v`). + * `d.batch_reduce(m, [(t, v)]): see `d.reduce()` + * `d.update(v, fn, ...)`: in cross-tower context, runs `fn()` once + for each device `v` is copied to, all inputs should have locality + V(`v`), output will have locality V(`v`) as well. + * `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower + context, like `d.update()` except with locality N. + * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device. + + The standard pattern for updating variables is to: + + 1. Wrap your input dataset in `d.distribute_dataset()`. + 2. Define each tower `d.call_for_each_tower()` up to the point of + getting a list of gradient, variable pairs. + 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the + gradients (with locality T) into values with locality V(`v`). + 4. Call `d.update(v)` for each variable to update its value. + + Steps 3 and 4 are done automatically by class `Optimizer` if you call + its `apply_gradients` method in a tower context. Otherwise you can + manually call its `distributed_apply` method in a cross-tower context. + + Another thing you might want to do in the middle of your tower function + is an all-reduce of some intermediate value, using `d.reduce()` or + `d.batch_reduce()` without supplying a variable as the destination. + + Layers should expect to be called in a tower context, and can use + the `get_tower_context()` function to get a `TowerContext` object. The + `TowerContext` object has a `merge_call()` method for entering + cross-tower context where you can use `reduce()` (or + `batch_reduce()`) and then optionally `update()` to update state. + + You may use this API whether or not a `DistributionStrategy` is + being used, since there is a default implementation of + `TowerContext` and `DistributionStrategy`. Or you can use the + `get_tower_context().is_single_tower` property to run different code + in the distributed vs. single tower cases. + """ + + # TODO(josh11b): Raise an exception if variable paritioning requested before + # we add support. + # TODO(josh11b): Also `parameter_device_index` property? + # TODO(josh11b): `map()` + # TODO(josh11b): ClusterSpec/ClusterResolver + # TODO(josh11b): Partitioned computations, state; sharding + # TODO(josh11b): Model parallelism: "towers" with multiple devices; shuffling + # TODO(josh11b): Tower-local variables + # TODO(josh11b): List of towers with their worker and parameter devices + # (where the parameter devices may overlap in the ps case). + + def scope(self): + """Returns a context manager selecting this DistributionStrategy as current. + + Inside a `with distribution_strategy.scope():` code block, this thread + will use a variable creator set by `distribution_strategy`, and will + enter its "cross-tower context". + + Returns: + A context manager. + """ + if has_distribution_strategy(): + _require_cross_tower_context(self) + return _SameScopeAgainContext(self) + + def creator_with_resource_vars(*args, **kwargs): + _require_distribution_strategy_scope(self) + kwargs["use_resource"] = True + return self._create_variable(*args, **kwargs) + + return _CurrentDistributionContext( + self, variable_scope.variable_creator_scope(creator_with_resource_vars)) + + def _create_variable(self, next_creator, *args, **kwargs): + # Note: should support "colocate_with" argument. + raise NotImplementedError("must be implemented in descendants") + + def colocate_vars_with(self, colocate_with_variable): + """Controls which devices variables will be created on. + + Note this may only be used inside `self.scope()`. + + Example usage: + + ``` + with distribution_strategy.scope(): + var1 = tf.get_variable(...) + with distribution_strategy.colocate_vars_with(v1): + # var2 and var3 will be created on the same device(s) as var1 + var2 = tf.get_variable(...) + var3 = tf.get_variable(...) + + def fn(v1, v2, v3): + # operates on v1 from var1, v2 from var2, and v3 from var3 + + # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too. + distribution_strategy.update(v1, fn, v2, v3) + ``` + + Args: + colocate_with_variable: A created in `self.scope()`. Variables created + while in the returned context manager will be on the same set of + devices as `colocate_with_variable`. + + Returns: + A context manager. + """ + def create_colocated_variable(next_creator, *args, **kwargs): + _require_distribution_strategy_scope(self) + kwargs["use_resource"] = True + kwargs["colocate_with"] = colocate_with_variable + return next_creator(*args, **kwargs) + + _require_distribution_strategy_scope(self) + return variable_scope.variable_creator_scope(create_colocated_variable) + + # TODO(josh11b): Currently this returns an iterator, but should return + # something implementing (a subset of) the Dataset API. + def distribute_dataset(self, dataset): + """Return an iterator into `dataset` split across all towers. + + Suitable for providing input to for `call_for_each_tower()`, as in: + + ``` + with distribution_strategy.scope(): + iterator = distribution_strategy.distribute_dataset(dataset) + tower_results = distribution_strategy.call_for_each_tower( + tower_fn, iterator.get_next()) + ``` + + Args: + dataset: A `tf.data.Dataset`. + + Returns: + A Dataset iterator that will produce separate splits for each tower. + """ + raise NotImplementedError("must be implemented in descendants") + + def broadcast(self, tensor, destinations=None): + """Mirror a tensor on one device to all worker devices. + + Args: + tensor: A Tensor value to broadcast. + destinations: An optional mirrored variable, device string, or + list of device strings, specifying the destination devices + to copy `tensor` to. Defaults to `self.worker_devices`. + + Returns: + A value mirrored to `destinations` devices. + """ + # TODO(josh11b): More docstring + _require_cross_tower_context(self) + return self._broadcast(tensor, destinations) + + def _broadcast(self, tensor, destinations): + raise NotImplementedError("must be implemented in descendants") + + def call_for_each_tower(self, fn, *args, **kwargs): + """Run `fn` once per tower. + + `fn` may call `tf.get_tower_context()` to access methods such as + `tower_id()` and `merge_call()`. + + `merge_call()` is used to communicate betwen the towers and + re-enter the cross-tower context. All towers pause their execution + having encountered a `merge_call()` call. After that the + `merge_fn`-function is executed. Its results are then unwrapped and + given back to each tower call. After that execution resumes until + `fn` is complete or encounters another `merge_call()`. Example: + + ```python + # Called once in "cross-tower" context. + def merge_fn(distribution, three_plus_tower_id): + # sum the values across towers + return sum(distribution.unwrap(three_plus_tower_id)) + + # Called once per tower in `distribution`, in a "tower" context. + def fn(three): + tower_ctx = tf.get_tower_context() + v = three + tower_ctx.tower_id + # Computes the sum of the `v` values across all towers. + s = tower_ctx.merge_call(merge_fn, v) + return s + v + + with distribution.scope(): + # in "cross-tower" context + ... + merged_results = distribution.call_for_each_tower(fn, 3) + # merged_results has the values from every tower execution of `fn`. + print(distribution.unwrap(merged_results)) # Prints a list + ``` + + Args: + fn: function to run (will be run once per tower). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + """ + _require_cross_tower_context(self) + return self._call_for_each_tower(fn, *args, **kwargs) + + def _call_for_each_tower(self, fn, *args, **kwargs): + raise NotImplementedError("must be implemented in descendants") + + def reduce(self, method_string, value, destinations=None): + """Combine (via e.g. sum or mean) values across towers. + + Args: + method_string: A string indicating how to combine values, either + "sum" or "mean". + value: A per-device value with one value per tower. + destinations: An optional mirrored variable, a device string, + list of device strings. The return value will be copied to all + destination devices (or all the devices where the mirrored + variable resides). If `None` or unspecified, the destinations + will match the devices `value` resides on. + + Returns: + A value mirrored to `destinations`. + """ + # TODO(josh11b): More docstring + # TODO(josh11b): Return an unwrapped value if colocate_with is a + # single device. + _require_cross_tower_context(self) + return self._reduce(method_string, value, destinations) + + def _reduce(self, method_string, value, destinations): + raise NotImplementedError("must be implemented in descendants") + + def batch_reduce(self, method_string, value_destination_pairs): + """Combine multiple `reduce` calls into one for faster execution. + + Args: + method_string: A string indicating how to combine values, either + "sum" or "mean". + value_destination_pairs: A sequence of (value, destinations) + pairs. See `reduce()` for a description. + + Returns: + A list of mirrored values, one per pair in `value_destination_pairs`. + """ + # TODO(josh11b): More docstring + _require_cross_tower_context(self) + assert method_string in ("sum", "mean") + return self._batch_reduce(method_string, value_destination_pairs) + + def _batch_reduce(self, method_string, value_destination_pairs): + return [self.reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs] + + def update(self, var, fn, *args, **kwargs): + """Run `fn` to update `var` using inputs mirrored to the same devices. + + If `var` is mirrored across multiple devices, then this implements + logic like: + + ``` + results = {} + for device, v in var: + with tf.device(device): + # *args and **kwargs will be unwrapped if they are mirrored. + results[device] = fn(v, *args, **kwargs) + return merged(results) + ``` + + Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.' + + Neither *args nor **kwargs may contain per-device values. + If they contain mirrored values, they will be unwrapped before + calling `fn`. + + Args: + var: Variable, possibly mirrored to multiple devices, to operate on. + fn: Function to call. Should take the variable as the first argument. + *args: Additional positional arguments to pass to `fn()`. + **kwargs: Keyword arguments to pass to `fn()`. + + Returns: + Merged return value of `fn` across all towers. + """ + _require_cross_tower_context(self) + return self._update(var, fn, *args, **kwargs) + + def _update(self, var, fn, *args, **kwargs): + raise NotImplementedError("must be implemented in descendants") + + def update_non_slot(self, colocate_with, fn, *args, **kwargs): + """Runs `fn(*args, **kwargs)` on `colocate_with` devices. + + Args: + colocate_with: The return value of `non_slot_devices()`. + fn: Function to execute. + *args: Positional arguments to pass to `fn()`. + **kwargs: Keyword arguments to pass to `fn()`. + + Returns: + Return value of `fn`, possibly merged across devices. + """ + _require_cross_tower_context(self) + return self._update_non_slot(colocate_with, fn, *args, **kwargs) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + raise NotImplementedError("must be implemented in descendants") + + def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x): + """Return a copy of `val` or `fn(val)` on `destination`. + + This is useful for getting a mirrored value onto a device. It + will attempt to avoid a copy by checking if the value is already + on the destination device. + + Args: + val: Value (which may be mirrored) to copy. + destination: A device string to copy the value to. + fn: An optional function to apply to the value on the source + device, before copying. + + Returns: + A `Tensor` on `destination`. + """ + _require_cross_tower_context(self) + return self._fetch(val, destination, fn) + + def _fetch(self, val, destination, fn): + raise NotImplementedError("must be implemented in descendants") + + def unwrap(self, value): + """Returns the list of all per-device values contained in `value`. + + Args: + value: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. + + Returns: + A list of values contained in `value`. If `value` represents a single + value, this returns `[value].` + """ + _require_cross_tower_context(self) + return self._unwrap(value) + + def _unwrap(self, distributed_value): + raise NotImplementedError("must be implemented in descendants") + + def group(self, value, name=None): + """Shortcut for `tf.group(distribution.unwrap(value))`.""" + value = nest.flatten(self.unwrap(value)) + + if len(value) != 1 or name is not None: + return control_flow_ops.group(value, name=name) + # Special handling for the common case of one op. + v, = value + if isinstance(v, ops.Tensor): + v = v.op + return v + + @property + def is_single_tower(self): + """Returns whether there is a single tower or multiple. + + Returns: + A boolean. If `True`, `call_for_each_tower(fn)` will only call `fn` once. + If `False`, `call_for_each_tower(fn)` may call `fn` multiple times. + """ + raise NotImplementedError("must be implemented in descendants") + + @property + def num_towers(self): + """Returns number of towers, for purposes of averaging across towers.""" + raise NotImplementedError("must be implemented in descendants") + + @property + def worker_devices(self): + """Returns the list of devices used to run `call_for_each_tower()` calls.""" + # TODO(josh11b): More docstring + raise NotImplementedError("must be implemented in descendants") + + @property + def parameter_devices(self): + """Returns the list of devices used for variable and `update` placement.""" + # TODO(josh11b): More docstring + raise NotImplementedError("must be implemented in descendants") + + def non_slot_devices(self, var_list): + """Device(s) for non-slot variables. + + Create variables on these devices in a + `with colocate_vars_with(non_slot_devices(...)):` block. + Update those using `update_non_slot()`. + + Args: + var_list: The list of variables being optimized, needed with the + default `DistributionStrategy`. + """ + raise NotImplementedError("must be implemented in descendants") + + @property + def worker_device_index(self): + """An object mapping worker device to an id. + + This might be passed as an argument to `call_for_each_tower()`, as in: + + ``` + with distribution_strategy.scope(): + + def fn(device_id): + # device_id is an integer. `fn` is being executed on device: + # distribution_strategy.worker_devices[device_id]. + + distribution_strategy.call_for_each_tower( + fn, distribution_strategy.worker_device_index) + ``` + + Returns: + An index object, or the integer 0 if there is only a single tower. + """ + _require_cross_tower_context(self) + return self._worker_device_index() + + def _worker_device_index(self): + raise NotImplementedError("must be implemented in descendants") + + +# A note about the difference between the context managers +# `TowerContext` (defined here) and `_CurrentDistributionContext` +# (defined above) used by `DistributionStrategy.scope()`: +# +# * a TowerContext is only present during a `call_for_each_tower()` +# call (except during a `merge_run` call) and in such a scope it +# will be returned by calls to `get_tower_context()`. Implementers of new +# DistributionStrategy descendants will frequently also need to +# define a descendant of TowerContext, and are responsible for +# entering and exiting this context. +# +# * DistributionStrategy.scope() sets up a variable_creator scope that +# changes variable creation calls (e.g. to make mirrored +# variables). This is intended as an outer scope that users enter once +# around their model creation and graph definition. There is no +# anticipated need to define descendants of _CurrentDistributionContext. +# It sets the current DistributionStrategy for purposes of +# `get_distribution_strategy()` and `has_distribution_strategy()` +# and switches the thread mode to a "cross-tower context". +class TowerContext(object): + """DistributionStrategy API inside a `call_for_each_tower()` call.""" + + def __init__(self, distribution_strategy, tower_id): + self._distribution_strategy = distribution_strategy + self._thread_context = _InTowerThreadMode(self) + self._tower_id = tower_id + + def __enter__(self): + _push_per_thread_mode(self._thread_context) + + def __exit__(self, exception_type, exception_value, traceback): + _pop_per_thread_mode() + + def merge_call(self, merge_fn, *args, **kwargs): + """Merge args across towers and run `merge_fn` in a cross-tower context. + + This allows communication and coordination when there are multiple calls + to a model function triggered by a call to + `distribution.call_for_each_tower(model_fn, ...)`. + + See `MirroredDistribution.call_for_each_tower()` for an explanation. + + Otherwise, this is equivalent to: + + ``` + distribution = get_distribution_strategy() + with cross-tower-context(distribution): + return merge_fn(distribution, *args, **kwargs) + ``` + + Args: + merge_fn: function that joins arguments from threads that are given as + PerDevice. It accepts `DistributionStrategy` object as the first + argument. + *args: positional per-thread arguments for `merge_fn` + **kwargs: keyword per-thread arguments for `merge_fn`. + + Returns: + The return value of `merge_fn`, except for `PerDevice` values which are + unpacked. + """ + require_tower_context(self) + return self._merge_call(merge_fn, *args, **kwargs) + + def _merge_call(self, merge_fn, *args, **kwargs): + """Default implementation for single tower.""" + _push_per_thread_mode( # thread-local, so not needed with multiple threads + _CrossTowerThreadMode(self._distribution_strategy)) + try: + return merge_fn(self._distribution_strategy, *args, **kwargs) + finally: + _pop_per_thread_mode() + + @property + def is_single_tower(self): + """Returns whether there is a single tower or multiple.""" + require_tower_context(self) + return self._distribution_strategy.is_single_tower + + @property + def num_towers(self): + """Returns number of towers, for purposes of averaging across towers.""" + return self._distribution_strategy.num_towers + + @property + def tower_id(self): + """Which tower is being defined, a number from 0 to `num_towers - 1`.""" + require_tower_context(self) + return self._tower_id + + @property + def distribution_strategy(self): + """The current `DistributionStrategy` object.""" + return self._distribution_strategy + + @property + def device(self): + """The device this tower is to be executed on, as a string.""" + require_tower_context(self) + return device_util.current() + + # TODO(josh11b): Implement `start_all_reduce(method, t)` that returns + # a function returning the result of reducing `t` across all + # towers. Most likely can be implemented in terms of `merge_call()` + # and `batch_reduce()`. + +# ------------------------------------------------------------------------------ + + +class _DefaultDistributionStrategy(DistributionStrategy): + """Default `DistributionStrategy` if none is explicitly selected.""" + + def scope(self): + """Context manager setting a variable creator and `self` as current.""" + if has_distribution_strategy(): + raise RuntimeError("Must not nest DistributionStrategy scopes.") + + def creator(next_creator, *args, **kwargs): + _require_distribution_strategy_scope(self) + return next_creator(*args, **kwargs) + + return _CurrentDistributionContext( + self, variable_scope.variable_creator_scope(creator)) + + def colocate_vars_with(self, colocate_with_variable): + """Does not require `self.scope`.""" + def create_colocated_variable(next_creator, *args, **kwargs): + _require_distribution_strategy_scope(self) + with ops.colocate_with(colocate_with_variable): + return next_creator(*args, **kwargs) + + _require_distribution_strategy_scope(self) + return variable_scope.variable_creator_scope(create_colocated_variable) + + def distribute_dataset(self, dataset): + # TODO(josh11b): Support for this when executing eagerly is currently only + # in contrib. + return dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + if destinations is None: + return tensor + else: + raise NotImplementedError("TODO") + + def _call_for_each_tower(self, fn, *args, **kwargs): + # We don't run `fn` in multiple threads in _DefaultDistributionStrategy. + kwargs.pop("run_concurrently", None) + with TowerContext(self, tower_id=0): + return fn(*args, **kwargs) + + def _reduce(self, method_string, value, destinations): + # TODO(josh11b): Use destinations? + del method_string, destinations + return value + + def _update(self, var, fn, *args, **kwargs): + # TODO(josh11b): Figure out what we should be passing to UpdateContext() + # once that value is used for something. + with ops.colocate_with(var), UpdateContext(var): + return fn(var, *args, **kwargs) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + # TODO(josh11b): Figure out what we should be passing to UpdateContext() + # once that value is used for something. + with ops.colocate_with(colocate_with), UpdateContext(colocate_with): + return fn(*args, **kwargs) + + def _fetch(self, var, destination, fn): + with ops.colocate_with(var): + var = fn(var) + with ops.device(destination): + return array_ops.identity(var) + + def _unwrap(self, distributed_value): + return [distributed_value] + + @property + def is_single_tower(self): + return True + + @property + def num_towers(self): + return 1 + + @property + def worker_devices(self): + raise RuntimeError( + "worker_devices() method unsupported by _DefaultDistributionStrategy.") + + @property + def parameter_devices(self): + raise RuntimeError("parameter_devices() method unsupported by " + "_DefaultDistributionStrategy.") + + def non_slot_devices(self, var_list): + return min(var_list, key=lambda x: x.name) + + def _worker_device_index(self): + raise RuntimeError("worker_device_index() method unsupported by " + "_DefaultDistributionStrategy.") + + +# ------------------------------------------------------------------------------ +# Singletons + +_default_distribution_strategy = _DefaultDistributionStrategy() +_default_tower_context = TowerContext( + _default_distribution_strategy, tower_id=0) +_default_tower_mode = _DefaultTowerThreadMode() diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4f19c31f6714e1211f9deed9703c02192cc2c0 --- /dev/null +++ b/tensorflow/python/training/distribute_test.py @@ -0,0 +1,104 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test DistributionStrategy, TowerContext, and supporting APIs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test +from tensorflow.python.training import distribute + + +class _TestTowerContext(distribute.TowerContext): + + def merge_call(self, fn, *args, **kwargs): + return kwargs["test_arg"] + + +class _TestStrategy(distribute.DistributionStrategy): + + def _call_for_each_tower(self, fn, *args, **kwargs): + with _TestTowerContext(self, tower_id=0): + return fn(*args, **kwargs) + + def _create_variable(self, next_creator, *args, **kwargs): + return kwargs["name"] + + +def _assert_in_default_state(t): + t.assertIs(distribute._default_tower_context, + distribute.get_tower_context()) + t.assertIs(None, distribute.get_cross_tower_context()) + t.assertIs(distribute._default_distribution_strategy, + distribute.get_distribution_strategy()) + t.assertFalse(distribute.has_distribution_strategy()) + + +class TestStrategyTest(test.TestCase): + + def testCallForEachTower(self): + _assert_in_default_state(self) + dist = _TestStrategy() + + def run_fn(): + tower_context = distribute.get_tower_context() + self.assertTrue(tower_context is not None) + self.assertIs(None, distribute.get_cross_tower_context()) + self.assertTrue(distribute.has_distribution_strategy()) + self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo")) + self.assertEqual("bar", variable_scope.variable(1.0, name="bar")) + + with self.assertRaises(RuntimeError): + dist.call_for_each_tower(run_fn) + with dist.scope(): + dist.call_for_each_tower(run_fn) + _assert_in_default_state(self) + + def testScope(self): + _assert_in_default_state(self) + dist = _TestStrategy() + with dist.scope(): + self.assertIs(None, distribute.get_tower_context()) + self.assertIs(dist, distribute.get_cross_tower_context()) + self.assertTrue(distribute.has_distribution_strategy()) + self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertEqual("baz", variable_scope.variable(1.0, name="baz")) + _assert_in_default_state(self) + + +class DefaultDistributionStrategyTest(test.TestCase): + + def testMergeCall(self): + _assert_in_default_state(self) + + def merge_fn(dist, s): + self.assertIs(distribute._default_distribution_strategy, dist) + self.assertIs(None, distribute.get_tower_context()) + self.assertIs(dist, distribute.get_cross_tower_context()) + self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertFalse(distribute.has_distribution_strategy()) + return "foo_" + s + + tower_ctx = distribute.get_tower_context() + self.assertIs(distribute._default_tower_context, tower_ctx) + self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar")) + _assert_in_default_state(self) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py index 9d02e694db15637126f37ee5575638908b351def..4fa081fab72df62107cf4957d4ff68240ced9ee0 100644 --- a/tensorflow/python/training/ftrl.py +++ b/tensorflow/python/training/ftrl.py @@ -53,7 +53,7 @@ class FtrlOptimizer(optimizer.Optimizer): learning_rate: A float value or a constant float `Tensor`. learning_rate_power: A float value, must be less or equal to zero. initial_accumulator_value: The starting value for accumulators. - Only positive values are allowed. + Only zero or positive values are allowed. l1_regularization_strength: A float value, must be greater than or equal to zero. l2_regularization_strength: A float value, must be greater than or @@ -84,9 +84,10 @@ class FtrlOptimizer(optimizer.Optimizer): """ super(FtrlOptimizer, self).__init__(use_locking, name) - if initial_accumulator_value <= 0.0: - raise ValueError("initial_accumulator_value %f needs to be positive" % - initial_accumulator_value) + if initial_accumulator_value < 0.0: + raise ValueError( + "initial_accumulator_value %f needs to be be positive or zero" % + initial_accumulator_value) if learning_rate_power > 0.0: raise ValueError("learning_rate_power %f needs to be negative or zero" % learning_rate_power) diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py index 380e14e02497fbe3681d6bae03fe9c636c5d13aa..6caf29d83af546f821314179e17f7bf1a693ff1a 100644 --- a/tensorflow/python/training/gradient_descent.py +++ b/tensorflow/python/training/gradient_descent.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -43,6 +44,7 @@ class GradientDescentOptimizer(optimizer.Optimizer): """ super(GradientDescentOptimizer, self).__init__(use_locking, name) self._learning_rate = learning_rate + self._learning_rate_tensor = None def _apply_dense(self, grad, var): return training_ops.apply_gradient_descent( @@ -69,5 +71,6 @@ class GradientDescentOptimizer(optimizer.Optimizer): return var.scatter_sub(delta, use_locking=self._use_locking) def _prepare(self): - self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate, - name="learning_rate") + if not context.executing_eagerly() or self._learning_rate_tensor is None: + self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate, + name="learning_rate") diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index bd9985a7c5c181c0431e0c0a91186bc36b11c787..44f00a96deff64012705c4c81b185a9c4fac2295 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -159,7 +159,7 @@ def input_producer(input_tensor, enabled. Please use the `tf.data` API to ingest data under eager execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "Input pipelines based on Queues are not supported when eager execution" " is enabled. Please use tf.data to ingest data into your model" @@ -737,7 +737,7 @@ def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32, allow_smaller_final_batch=False, shared_name=None, name=None): """Helper function for `batch` and `maybe_batch`.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError( "Input pipelines based on Queues are not supported when eager execution" " is enabled. Please use tf.data to ingest data into your model" @@ -775,7 +775,7 @@ def _batch_join(tensors_list, batch_size, keep_input, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): """Helper function for `batch_join` and `maybe_batch_join`.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError( "Input pipelines based on Queues are not supported when eager execution" " is enabled. Please use tf.data to ingest data into your model" @@ -810,7 +810,7 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None): """Helper function for `shuffle_batch` and `maybe_shuffle_batch`.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError( "Input pipelines based on Queues are not supported when eager execution" " is enabled. Please use tf.data to ingest data into your model" @@ -855,7 +855,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity, allow_smaller_final_batch=False, shared_name=None, name=None): """Helper function for `shuffle_batch_join` and `maybe_shuffle_batch_join`.""" - if context.in_eager_mode(): + if context.executing_eagerly(): raise ValueError( "Input pipelines based on Queues are not supported when eager execution" " is enabled. Please use tf.data to ingest data into your model" diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py index 1ce8c156a0b126f680bad62267f90e31a23febed..60306e4f1239a759ea1f68492a1211d5f0858997 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/training/learning_rate_decay_test.py @@ -43,8 +43,8 @@ class LRDecayTest(test_util.TensorFlowTestCase): def testStaircase(self): with self.test_session(): - step = gen_state_ops._variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") + step = gen_state_ops.variable(shape=[], dtype=dtypes.int32, + name="step", container="", shared_name="") assign_100 = state_ops.assign(step, 100) assign_1 = state_ops.assign(step, 1) assign_2 = state_ops.assign(step, 2) @@ -113,7 +113,7 @@ class LRDecayTest(test_util.TensorFlowTestCase): learning_rate_decay.piecewise_constant(x, boundaries, values) # Test that ref types are valid. - if context.in_graph_mode(): + if not context.executing_eagerly(): x = variables.Variable(0.0) x_ref = x.op.outputs[0] # float32_ref tensor should be accepted boundaries, values = [1.0, 2.0], [1, 2, 3] @@ -264,8 +264,8 @@ class ExponentialDecayTest(test_util.TensorFlowTestCase): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops._variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") + step = gen_state_ops.variable( + shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") assign_step = state_ops.assign(step, 0) increment_step = state_ops.assign_add(step, 1) decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, @@ -281,8 +281,8 @@ class ExponentialDecayTest(test_util.TensorFlowTestCase): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops._variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") + step = gen_state_ops.variable( + shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") assign_step = state_ops.assign(step, 0) increment_step = state_ops.assign_add(step, 1) decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, @@ -304,8 +304,8 @@ class InverseDecayTest(test_util.TensorFlowTestCase): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops._variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") + step = gen_state_ops.variable( + shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") assign_step = state_ops.assign(step, 0) increment_step = state_ops.assign_add(step, 1) decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, @@ -323,8 +323,8 @@ class InverseDecayTest(test_util.TensorFlowTestCase): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops._variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") + step = gen_state_ops.variable( + shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") assign_step = state_ops.assign(step, 0) increment_step = state_ops.assign_add(step, 1) decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index cda421cef837fa6ab25898208a8dc94d70561048..297a8bbde5447cff9465be36c0bb71f2490c60fc 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -66,7 +66,7 @@ class MomentumOptimizerTest(test.TestCase): mom_update = mom_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -78,13 +78,13 @@ class MomentumOptimizerTest(test.TestCase): self.assertEquals(slot0.get_shape(), var0.get_shape()) slot1 = mom_opt.get_slot(var1, "momentum") self.assertEquals(slot1.get_shape(), var1.get_shape()) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertFalse(slot0 in variables.trainable_variables()) self.assertFalse(slot1 in variables.trainable_variables()) # Step 1: the momentum accumulators where 0. So we should see a normal # update: v -= grad * learning_rate - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(mom_update) # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), @@ -99,10 +99,10 @@ class MomentumOptimizerTest(test.TestCase): np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. - if context.in_graph_mode(): - self.evaluate(mom_update) - else: + if context.executing_eagerly(): mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + else: + self.evaluate(mom_update) # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), @@ -142,7 +142,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var0") var1 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var1") - if context.in_eager_mode(): + if context.executing_eagerly(): loss = lambda: math_ops.reduce_sum(var0 + var1) else: loss = math_ops.reduce_sum(var0 + var1) @@ -157,7 +157,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var2") var3 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var3") - if context.in_eager_mode(): + if context.executing_eagerly(): loss = lambda: math_ops.reduce_sum(var2 + var3) else: loss = math_ops.reduce_sum(var2 + var3) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 6c5c9e01a76d539b550420134b09090b89beed46..2d4f09a60a518471b4f1c8104bf606953f0f296d 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -281,13 +281,14 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name scaffold=None, hooks=None, chief_only_hooks=None, - save_checkpoint_secs=600, + save_checkpoint_secs=USE_DEFAULT, save_summaries_steps=USE_DEFAULT, save_summaries_secs=USE_DEFAULT, config=None, stop_grace_period_secs=120, log_step_count_steps=100, - max_wait_secs=7200): + max_wait_secs=7200, + save_checkpoint_steps=USE_DEFAULT): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -310,8 +311,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if `is_chief==True`, ignore otherwise. save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved - using a default checkpoint saver. If `save_checkpoint_secs` is set to - `None`, then the default checkpoint saver isn't used. + using a default checkpoint saver. If both `save_checkpoint_steps` and + `save_checkpoint_secs` are set to `None`, then the default checkpoint + saver isn't used. If both are provided, then only `save_checkpoint_secs` + is used. Default 600. save_summaries_steps: The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. If both `save_summaries_steps` and `save_summaries_secs` are set to `None`, then @@ -330,6 +333,11 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name become available. This should be kept relatively short to help detect incorrect code, but sometimes may need to be increased if the chief takes a while to start up. + save_checkpoint_steps: The frequency, in number of global steps, that a + checkpoint is saved using a default checkpoint saver. If both + `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then + the default checkpoint saver isn't used. If both are provided, then only + `save_checkpoint_secs` is used. Default not enabled. Returns: A `MonitoredSession` object. @@ -342,6 +350,15 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name elif save_summaries_steps == USE_DEFAULT: save_summaries_steps = None + if save_checkpoint_steps == USE_DEFAULT and \ + save_checkpoint_secs == USE_DEFAULT: + save_checkpoint_steps = None + save_checkpoint_secs = 600 + elif save_checkpoint_secs == USE_DEFAULT: + save_checkpoint_secs = None + elif save_checkpoint_steps == USE_DEFAULT: + save_checkpoint_steps = None + scaffold = scaffold or Scaffold() if not is_chief: session_creator = WorkerSessionCreator( @@ -374,9 +391,13 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name save_steps=save_summaries_steps, save_secs=save_summaries_secs, output_dir=checkpoint_dir)) - if save_checkpoint_secs and save_checkpoint_secs > 0: + if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( + save_checkpoint_steps and save_checkpoint_steps > 0): all_hooks.append(basic_session_run_hooks.CheckpointSaverHook( - checkpoint_dir, save_secs=save_checkpoint_secs, scaffold=scaffold)) + checkpoint_dir, + save_steps=save_checkpoint_steps, + save_secs=save_checkpoint_secs, + scaffold=scaffold)) if hooks: all_hooks.extend(hooks) diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 159b2d5c1605bdd95303efb25690f55a54a3625d..3806056f01a73d21faf3de4539c0dd1ada5f96f8 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -282,6 +282,42 @@ class MonitoredTrainingSessionTest(test.TestCase): is_chief=True, checkpoint_dir=logdir) as session: self.assertEqual(2, session.run(gstep)) + def test_save_checkpoint_steps(self): + logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps') + with ops.Graph().as_default(): + gstep = variables_lib.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + with monitored_session.MonitoredTrainingSession( + is_chief=True, + checkpoint_dir=logdir, + save_checkpoint_steps=100, + log_step_count_steps=10) as session: + for _ in range(100): + session.run(new_gstep) + # A restart will find the checkpoint and recover automatically. + with monitored_session.MonitoredTrainingSession( + is_chief=True, checkpoint_dir=logdir) as session: + self.assertEqual(100, session.run(gstep)) + + def test_save_checkpoint_secs(self): + logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs') + with ops.Graph().as_default(): + gstep = variables_lib.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + with monitored_session.MonitoredTrainingSession( + is_chief=True, + checkpoint_dir=logdir, + save_checkpoint_secs=0.1, + log_step_count_steps=10) as session: + session.run(new_gstep) + time.sleep(0.2) + for _ in range(10): + session.run(new_gstep) + # A restart will find the checkpoint and recover automatically. + with monitored_session.MonitoredTrainingSession( + is_chief=True, checkpoint_dir=logdir) as session: + self.assertEqual(11, session.run(gstep)) + def test_summaries_steps(self): logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps') with ops.Graph().as_default(): diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index b9ecb27df19d051c28ec1c3fe3cd9fd86717a5ed..61fc828a840c490b0f787119134a0941f60f947a 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -52,16 +52,19 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): they were created in and the scope of the variables they debias. They are also given a uniqifying-suffix. - Ex: + E.g.: + + ``` with tf.variable_scope('scope1'): with tf.variable_scope('scope2'): var = tf.get_variable('foo') - assign_moving_average(var, 0.0, 1.0) - assign_moving_average(var, 0.0, 0.9) + tf.assign_moving_average(var, 0.0, 1.0) + tf.assign_moving_average(var, 0.0, 0.9) - var.name: 'scope1/scope2/foo' - shadow var names: 'scope1/scope2/scope1/scope2/foo/biased' - 'scope1/scope2/scope1/scope2/foo/biased_1' + # var.name: 'scope1/scope2/foo' + # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased' + # 'scope1/scope2/scope1/scope2/foo/biased_1' + ``` Args: variable: A Variable. diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 6efdeb286657e761a4c46634b9408121765a447b..6717811bbb0f05723a5ad0fbcbfba75249d0d43b 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -376,7 +376,7 @@ class ExponentialMovingAverageTest(test.TestCase): with ops.device("/job:dev_v0"): v0 = variables.Variable(10.0, name="v0") with ops.device("/job:dev_v1"): - v1 = gen_state_ops._variable( + v1 = gen_state_ops.variable( shape=[1], dtype=dtypes.float32, name="v1", diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 678d6322aa5ecea0a603b6a9858f7619638eae30..bf79714f9682e60b97788b8b470821cfe9290886 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -40,19 +40,6 @@ from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export -def _get_variable_for(v): - """Returns the ResourceVariable responsible for v, or v if not necessary.""" - if context.in_eager_mode(): - return v - if v.op.type == "VarHandleOp": - for var in variables.trainable_variables(): - if (isinstance(var, resource_variable_ops.ResourceVariable) - and var.handle.op is v.op): - return var - raise ValueError("Got %s but could not locate source variable." % (str(v))) - return v - - def _deduplicate_indexed_slices(values, indices): """Sums `values` associated with any non-unique `indices`. @@ -73,8 +60,8 @@ def _deduplicate_indexed_slices(values, indices): def _var_key(var): - if context.in_eager_mode(): - return var._shared_name # pylint: disable=protected-access + if context.executing_eagerly(): + return var._unique_id # pylint: disable=protected-access return (var.op.graph, var.op.name) @@ -98,6 +85,9 @@ class _RefVariableProcessor(_OptimizableVariable): def __init__(self, v): self._v = v + def __str__(self): + return "<_RefVariableProcessor(%s)>" % self._v + def target(self): return self._v._ref() # pylint: disable=protected-access @@ -196,11 +186,15 @@ class _TensorProcessor(_OptimizableVariable): def _get_processor(v): """The processor of v.""" - if context.in_eager_mode(): + if context.executing_eagerly(): if isinstance(v, ops.Tensor): return _TensorProcessor(v) else: return _DenseResourceVariableProcessor(v) + if isinstance( + v, resource_variable_ops.ResourceVariable) and not v._in_graph_mode: # pylint: disable=protected-access + # True if and only if `v` was initialized eagerly. + return _DenseResourceVariableProcessor(v) if v.op.type == "VarHandleOp": return _DenseResourceVariableProcessor(v) if isinstance(v, variables.Variable): @@ -213,7 +207,11 @@ def _get_processor(v): @tf_export("train.Optimizer") -class Optimizer(checkpointable.Checkpointable): +class Optimizer( + # Optimizers inherit from CheckpointableBase rather than Checkpointable + # since they do most of their dependency management themselves (slot + # variables are special-cased, and non-slot variables are keyed to graphs). + checkpointable.CheckpointableBase): """Base class for optimizers. This class defines the API to add Ops to train a model. You never use this @@ -453,7 +451,7 @@ class Optimizer(checkpointable.Checkpointable): var_list = tape.watched_variables() grads = tape.gradient(loss_value, var_list, grad_loss) return list(zip(grads, var_list)) - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "`loss` passed to Optimizer.compute_gradients should " "be a function when eager execution is enabled.") @@ -542,7 +540,7 @@ class Optimizer(checkpointable.Checkpointable): raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, _, v in converted_grads_and_vars],)) with ops.init_scope(): - self._create_slots([_get_variable_for(v) for v in var_list]) + self._create_slots(var_list) update_ops = [] with ops.name_scope(name, self._name) as name: self._prepare() @@ -552,7 +550,12 @@ class Optimizer(checkpointable.Checkpointable): # We colocate all ops created in _apply_dense or _apply_sparse # on the same device as the variable. # TODO(apassos): figure out how to get the variable name here. - scope_name = var.op.name if context.in_graph_mode() else "" + if context.executing_eagerly() or isinstance( + var, + resource_variable_ops.ResourceVariable) and not var._in_graph_mode: # pylint: disable=protected-access + scope_name = "" + else: + scope_name = var.op.name with ops.name_scope("update_" + scope_name), ops.colocate_with(var): update_ops.append(processor.update_op(self, grad)) if global_step is None: @@ -570,7 +573,7 @@ class Optimizer(checkpointable.Checkpointable): else: apply_updates = state_ops.assign_add(global_step, 1, name=name) - if context.in_graph_mode(): + if not context.executing_eagerly(): if isinstance(apply_updates, ops.Tensor): apply_updates = apply_updates.op train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) @@ -620,7 +623,7 @@ class Optimizer(checkpointable.Checkpointable): Returns: A list of variables. """ - executing_eagerly = context.in_eager_mode() + executing_eagerly = context.executing_eagerly() current_graph = ops.get_default_graph() def _from_current_graph(variable): @@ -642,20 +645,54 @@ class Optimizer(checkpointable.Checkpointable): def _create_non_slot_variable(self, initial_value, name, colocate_with): """Add an extra variable, not associated with a slot.""" - if context.in_graph_mode(): - graph = colocate_with.graph - else: - graph = None + eager = context.executing_eagerly() + graph = None if eager else colocate_with.graph key = (name, graph) v = self._non_slot_dict.get(key, None) if v is None: + self._maybe_initialize_checkpointable() with ops.colocate_with(colocate_with): + if eager: + restored_initial_value = self._preload_simple_restoration( + name=name, shape=None) + if restored_initial_value is not None: + initial_value = restored_initial_value v = variable_scope.variable(initial_value, name=name, trainable=False) + # Restore this variable by name if necessary, but don't add a + # Checkpointable dependency. Optimizers return the current graph's + # non-slot variables from _checkpoint_dependencies explicitly rather + # than unconditionally adding dependencies (since there may be multiple + # non-slot variables with the same name in different graphs, trying to + # save all of them would result in errors). + self._handle_deferred_dependencies(name=name, checkpointable=v) self._non_slot_dict[key] = v return v + @property + def _checkpoint_dependencies(self): + """From Checkpointable. Gather graph-specific non-slot variables to save.""" + current_graph_non_slot_variables = [] + current_graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + for (name, _), variable_object in sorted(self._non_slot_dict.items(), + # Avoid comparing graphs + key=lambda item: item[0][0]): + if variable_object._graph_key == current_graph_key: # pylint: disable=protected-access + current_graph_non_slot_variables.append( + checkpointable.CheckpointableReference( + name=name, ref=variable_object)) + return (super(Optimizer, self)._checkpoint_dependencies + + current_graph_non_slot_variables) + + def _lookup_dependency(self, name): + """From Checkpointable. Find a non-slot variable in the current graph.""" + unconditional = super(Optimizer, self)._lookup_dependency(name) + if unconditional is not None: + return unconditional + graph = None if context.executing_eagerly() else ops.get_default_graph() + return self._get_non_slot_variable(name, graph=graph) + def _get_non_slot_variable(self, name, graph=None): return self._non_slot_dict.get((name, graph), None) @@ -987,9 +1024,8 @@ class Optimizer(checkpointable.Checkpointable): named_slots = self._slot_dict(slot_name) variable_key = _var_key(variable) slot_variable = named_slots.get(variable_key, None) - if (slot_variable is None - and context.in_eager_mode() - and slot_variable_position.is_simple_variable()): + if (slot_variable is None and context.executing_eagerly() and + slot_variable_position.is_simple_variable()): initializer = checkpointable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self._get_or_make_slot( diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i index 17ffcd6e0758c9c1bc8bab864b6b7a2a18bc9cbf..fb5e47efa0259d02df3ccf2e9b1430e027f8fcfb 100644 --- a/tensorflow/python/training/quantize_training.i +++ b/tensorflow/python/training/quantize_training.i @@ -56,6 +56,11 @@ PyObject* DoQuantizeTrainingOnGraphDefHelper( %insert("python") %{ def do_quantize_training_on_graphdef(input_graph, num_bits): + """A general quantization scheme is being developed in @{tf.contrib.quantize}. + + Consider using that instead, though since it is in the tf.contrib namespace, + it is not subject to backward compatibility guarantees. + """ from tensorflow.core.framework.graph_pb2 import GraphDef from tensorflow.python.framework import errors with errors.raise_exception_on_not_ok_status() as status: diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py index 07afba79abf4d636c9ec2d53bcf2641594a35733..d38c5499c73e1217effbc907077236cb6c8e0ae8 100644 --- a/tensorflow/python/training/queue_runner_impl.py +++ b/tensorflow/python/training/queue_runner_impl.py @@ -89,7 +89,7 @@ class QueueRunner(object): restoring from `queue_runner_def`. RuntimeError: If eager execution is enabled. """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError( "QueueRunners are not supported when eager execution is enabled. " "Instead, please use tf.data to get data into your model.") @@ -441,7 +441,7 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, use the `tf.data` API instead. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Queues are not compatible with eager execution.") if sess is None: sess = ops.get_default_session() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 3888e9bba42dc89055638ad0abe2b7e1a9f5b548..ba0d0384758f25cc2cc6264b9b73e47f15359721 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -50,6 +50,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.util import compat @@ -196,8 +197,8 @@ class BaseSaverBuilder(object): # Copy the restored tensor to the variable's device. with ops.device(self._var_device): restored_tensor = array_ops.identity(restored_tensor) - return resource_variable_ops.shape_safe_assign_variable_handle( - self.handle_op, self._var_shape, restored_tensor) + return resource_variable_ops.shape_safe_assign_variable_handle( + self.handle_op, self._var_shape, restored_tensor) def __init__(self, write_version=saver_pb2.SaverDef.V2): self._write_version = write_version @@ -310,8 +311,7 @@ class BaseSaverBuilder(object): Returns: A string tensor. """ - # pylint: disable=protected-access - return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards) + return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards) def _AddSaveOps(self, filename_tensor, saveables): """Add ops to save variables that are on the same shard. @@ -420,8 +420,7 @@ class BaseSaverBuilder(object): sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) # Return the sharded name for the save path. with ops.control_dependencies([x.op for x in sharded_saves]): - # pylint: disable=protected-access - return gen_io_ops._sharded_filespec(filename_tensor, num_shards_tensor) + return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor) def _AddRestoreOps(self, filename_tensor, @@ -577,10 +576,33 @@ class BaseSaverBuilder(object): names_to_saveables[name].append(var) else: names_to_saveables[name] = [var] + elif (isinstance(var, checkpointable.CheckpointableBase) + and not isinstance(var, variables.Variable)): + checkpointable_saveables = [ + (factory() if callable(factory) else factory) + for factory in var._gather_saveables_for_checkpoint().values()] + names_to_saveables.update( + BaseSaverBuilder.OpListToDict(checkpointable_saveables)) else: - if context.in_graph_mode(): + if context.executing_eagerly(): + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "Can only save/restore ResourceVariables when eager execution " + "is enabled, type: %s." % type(var)) + set_var = names_to_saveables.setdefault(var._shared_name, var) + if set_var is not var: + raise ValueError( + ("Two different ResourceVariable objects with the same " + "shared_name '%s' were passed to the Saver. This likely means " + "that they were created in different Graphs or isolation " + "contexts, and may not be checkpointed together.") % + (var._shared_name,)) + else: if convert_variable_to_tensor: - var = ops.internal_convert_to_tensor(var, as_ref=True) + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var._graph_element # pylint: disable=protected-access + else: + var = ops.internal_convert_to_tensor(var, as_ref=True) if not BaseSaverBuilder._IsVariable(var): raise TypeError("Variable to save is not a Variable: %s" % var) if var.op.type == "ReadVariableOp": @@ -591,18 +613,6 @@ class BaseSaverBuilder(object): raise ValueError("At least two variables have the same name: %s" % name) names_to_saveables[name] = var - else: - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError("Can only save/restore ResourceVariable eager " - "mode is enabled, type: %s." % type(var)) - set_var = names_to_saveables.setdefault(var._shared_name, var) - if set_var is not var: - raise ValueError( - ("Two different ResourceVariable objects with the same " - "shared_name '%s' were passed to the Saver. This likely means " - "that they were created in different Graphs or isolation " - "contexts, and may not be checkpointed together.") % ( - var._shared_name,)) # pylint: enable=protected-access return names_to_saveables @@ -664,13 +674,16 @@ class BaseSaverBuilder(object): # pylint: enable=protected-access else: # A variable or tensor. - if context.in_eager_mode(): + if context.executing_eagerly(): if not isinstance(op, resource_variable_ops.ResourceVariable): raise ValueError("Can only save/restore ResourceVariable eager " "mode is enabled, type: %s." % type(op)) saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name) else: - variable = ops.internal_convert_to_tensor(op, as_ref=True) + if isinstance(op, resource_variable_ops.ResourceVariable): + variable = op._graph_element # pylint: disable=protected-access + else: + variable = ops.internal_convert_to_tensor(op, as_ref=True) if not BaseSaverBuilder._IsVariable(variable): raise TypeError("names_to_saveables must be a dict mapping string " "names to Tensors/Variables. Not a variable: %s" % @@ -768,8 +781,10 @@ class BaseSaverBuilder(object): build_save=True, build_restore=True): """build() with option to only perform save and restore.""" - if context.in_graph_mode() and (not build_save or not build_restore): - raise ValueError("Graph mode needs to build save and restore together.") + if not context.executing_eagerly() and (not build_save or + not build_restore): + raise ValueError("save and restore operations need to be built together " + " when eager execution is not enabled.") saveables = self._ValidateAndSliceInputs(names_to_saveables) if max_to_keep is None: @@ -806,22 +821,22 @@ class BaseSaverBuilder(object): # such usage model makes sense. # # assert restore_op.name.endswith("restore_all"), restore_op.name - if context.in_graph_mode(): + if context.executing_eagerly(): + # Store the tensor values to the tensor_names. + save_tensor_name = save_tensor.numpy() if build_save else "" return saver_pb2.SaverDef( - filename_tensor_name=filename_tensor.name, - save_tensor_name=save_tensor.name, - restore_op_name=restore_op.name, + filename_tensor_name=filename_tensor.numpy(), + save_tensor_name=save_tensor_name, + restore_op_name="", max_to_keep=max_to_keep, sharded=sharded, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, version=self._write_version) else: - # Store the tensor values to the tensor_names. - save_tensor_name = save_tensor.numpy() if build_save else "" return saver_pb2.SaverDef( - filename_tensor_name=filename_tensor.numpy(), - save_tensor_name=save_tensor_name, - restore_op_name="", + filename_tensor_name=filename_tensor.name, + save_tensor_name=save_tensor.name, + restore_op_name=restore_op.name, max_to_keep=max_to_keep, sharded=sharded, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, @@ -1120,8 +1135,9 @@ class Saver(object): the proliferation of checkpoint files on disk: * `max_to_keep` indicates the maximum number of recent checkpoint files to - keep. As new files are created, older files are deleted. If None or 0, - all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent + keep. As new files are created, older files are deleted. If None or 0, + no checkpoints are deleted from the filesystem but only the last one is + kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent checkpoint files are kept.) * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent @@ -1270,7 +1286,7 @@ class Saver(object): raise ValueError( "If `var_list` is provided then build cannot be deferred. " "Either set defer_build=False or var_list=None.") - if context.in_eager_mode() and var_list is None: + if context.executing_eagerly() and var_list is None: raise RuntimeError( "When eager execution is enabled, `var_list` must specify a list or " "dict of variables to save") @@ -1289,7 +1305,12 @@ class Saver(object): self._write_version = write_version self._pad_step_number = pad_step_number self._filename = filename - if not defer_build and context.in_graph_mode(): + self._last_checkpoints = [] + self._checkpoints_to_be_deleted = [] + if context.executing_eagerly(): + self._next_checkpoint_time = ( + time.time() + self._keep_checkpoint_every_n_hours * 3600) + elif not defer_build: self.build() if self.saver_def: self._check_saver_def() @@ -1297,7 +1318,7 @@ class Saver(object): self._save_relative_paths = save_relative_paths def build(self): - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Use save/restore instead of build in eager mode.") self._build(self._filename, build_save=True, build_restore=True) @@ -1307,12 +1328,12 @@ class Saver(object): def _build(self, checkpoint_path, build_save, build_restore): """Builds saver_def.""" - if context.in_graph_mode(): + if not context.executing_eagerly(): if self._is_built: return self._is_built = True - if not self.saver_def or context.in_eager_mode(): + if not self.saver_def or context.executing_eagerly(): if self._builder is None: self._builder = BulkSaverBuilder(self._write_version) @@ -1349,17 +1370,17 @@ class Saver(object): self.saver_def.restore_op_name, self._name) self._check_saver_def() - # Updates next checkpoint time. - self._next_checkpoint_time = ( - time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) - self._last_checkpoints = [] - self._checkpoints_to_be_deleted = [] + if not context.executing_eagerly(): + # Updates next checkpoint time. + # Set in __init__ when executing eagerly. + self._next_checkpoint_time = ( + time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) def _check_saver_def(self): if not isinstance(self.saver_def, saver_pb2.SaverDef): raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" % self.saver_def) - if context.in_graph_mode(): + if not context.executing_eagerly(): if not self.saver_def.save_tensor_name: raise ValueError("saver_def must specify the save_tensor_name: %s" % str(self.saver_def)) @@ -1609,7 +1630,7 @@ class Saver(object): RuntimeError: If save and restore ops weren't built. """ # pylint: enable=line-too-long - if not self._is_built and context.in_graph_mode(): + if not self._is_built and not context.executing_eagerly(): raise RuntimeError( "`build()` should be called before save if defer_build==True") if latest_filename is None: @@ -1641,21 +1662,21 @@ class Saver(object): "'latest_filename' collides with 'save_path': '%s' and '%s'" % (latest_filename, save_path)) - if (context.in_graph_mode() and + if (not context.executing_eagerly() and not isinstance(sess, session.SessionInterface)): raise TypeError("'sess' must be a Session; %s" % sess) save_path_parent = os.path.dirname(save_path) if not self._is_empty: try: - if context.in_graph_mode(): - model_checkpoint_path = sess.run( - self.saver_def.save_tensor_name, - {self.saver_def.filename_tensor_name: checkpoint_file}) - else: + if context.executing_eagerly(): self._build_eager( checkpoint_file, build_save=True, build_restore=False) model_checkpoint_path = self.saver_def.save_tensor_name + else: + model_checkpoint_path = sess.run( + self.saver_def.save_tensor_name, + {self.saver_def.filename_tensor_name: checkpoint_file}) model_checkpoint_path = compat.as_str(model_checkpoint_path) if write_state: @@ -1677,7 +1698,7 @@ class Saver(object): if write_meta_graph: meta_graph_filename = self._MetaGraphFilename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) - if context.in_graph_mode(): + if not context.executing_eagerly(): with sess.graph.as_default(): self.export_meta_graph( meta_graph_filename, strip_default_attrs=strip_default_attrs) @@ -1750,11 +1771,11 @@ class Saver(object): if save_path is None: raise ValueError("Can't load save_path when it is None.") logging.info("Restoring parameters from %s", save_path) - if context.in_graph_mode(): + if context.executing_eagerly(): + self._build_eager(save_path, build_save=False, build_restore=True) + else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - else: - self._build_eager(save_path, build_save=False, build_restore=True) @staticmethod def _add_collection_def(meta_graph_def, key, export_scope=None): @@ -1894,7 +1915,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, execution is enabled. @end_compatibility """ # pylint: disable=g-doc-exception - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " "execution is enabled.") @@ -1949,7 +1970,7 @@ def export_meta_graph(filename=None, saver_def: `SaverDef` protocol buffer. collection_list: List of string keys to collect. as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto. - graph: The `Graph` to import into. If `None`, use the default graph. + graph: The `Graph` to export. If `None`, use the default graph. export_scope: Optional `string`. Name scope under which to extract the subgraph. The scope name will be striped from the node definitions for easy import later into new name scopes. If `None`, the whole graph @@ -1977,7 +1998,7 @@ def export_meta_graph(filename=None, @end_compatibility """ # pylint: enable=line-too-long - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " "execution is enabled.") diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index c5a6f49df599434ab3bc1a9fe3d85db6f824071e..7de778f298e0fb0d62d45abdd280b673f1068213 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -35,6 +35,7 @@ from google.protobuf import text_format from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import queue_runner_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import saver_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session @@ -53,6 +54,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables @@ -66,6 +68,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import adam +from tensorflow.python.training import checkpointable from tensorflow.python.training import gradient_descent from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module @@ -89,7 +92,7 @@ class SaverTest(test.TestCase): v2_init = v2.insert("k1", 30.0) # Initialize all variables - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate([variables.global_variables_initializer(), v2_init]) # Check that the parameter nodes have been initialized. @@ -117,7 +120,7 @@ class SaverTest(test.TestCase): v2 = saver_test_utils.CheckpointedOp(name="v2") # Assert that the variables are not initialized. - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual( len(variables.report_uninitialized_variables().eval()), 2) self.assertEqual(0, len(v2.keys().eval())) @@ -140,7 +143,7 @@ class SaverTest(test.TestCase): v2_init = v2_2.insert("k1000", 3000.0) # Check that the parameter nodes have been initialized. - if context.in_graph_mode(): + if not context.executing_eagerly(): init_all_op = [variables.global_variables_initializer(), v2_init] self.evaluate(init_all_op) # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty @@ -249,10 +252,10 @@ class SaverTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0", name="v") - if context.in_graph_mode(): - self.evaluate(variables.global_variables_initializer()) - else: + if context.executing_eagerly(): sess = None + else: + self.evaluate(variables.global_variables_initializer()) save = saver_module.Saver([v]) save.save(sess, save_path) @@ -260,6 +263,24 @@ class SaverTest(test.TestCase): save2.restore(sess, save_path) self.assertEquals(self.evaluate(v), [1]) + def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self): + with ops_lib.Graph().as_default() as g: + v = resource_variable_ops.ResourceVariable(1.0, name="v") + with ops_lib.name_scope("saver1"): + saver_module.Saver() + with ops_lib.name_scope("saver2"): + saver_module.Saver({"name": v}) + ops_in_saver1_scope_but_not_save_scope = [ + op for op in g.get_operations() + if (op.name.startswith("saver1/") and + not op.name.startswith("saver1/save/"))] + self.assertEqual(ops_in_saver1_scope_but_not_save_scope, []) + ops_in_saver2_scope_but_not_save_scope = [ + op for op in g.get_operations() + if (op.name.startswith("saver2/") and + not op.name.startswith("saver2/save/"))] + self.assertEqual(ops_in_saver2_scope_but_not_save_scope, []) + def testSaveCopyRestoreWithSaveRelativePaths(self): """Save, copy checkpoint dir and restore from copied dir. @@ -497,7 +518,7 @@ class SaverTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: var = resource_variable_ops.ResourceVariable(var_value, name=var_name) save = saver_module.Saver({var_name: var}) - if context.in_graph_mode(): + if not context.executing_eagerly(): self.evaluate(var.initializer) val = save.save(sess, save_path) self.assertEqual(save_path, val) @@ -657,11 +678,11 @@ class SaverTest(test.TestCase): { var._shared_name: var }, pad_step_number=pad_step_number) - if context.in_graph_mode(): + if context.executing_eagerly(): + sess = None + else: self.evaluate(var.initializer) sess = ops_lib.get_default_session() - else: - sess = None if use_tensor: global_step = constant_op.constant(global_step_int) val = save.save(sess, save_path, global_step=global_step) @@ -1039,6 +1060,77 @@ class MaxToKeepTest(test.TestCase): self.assertEqual(checkpoint_state.all_model_checkpoint_paths, all_model_checkpoint_paths) + def testMaxToKeepEager(self): + with context.eager_mode(): + save_dir = self._get_test_dir("max_to_keep_non_sharded") + + v = variable_scope.variable(10.0, name="v") + save = saver_module.Saver({"v": v}, max_to_keep=2) + self.evaluate(variables.global_variables_initializer()) + if not context.executing_eagerly(): + self.assertEqual([], save.last_checkpoints) + + s1 = save.save(None, os.path.join(save_dir, "s1")) + self.assertEqual([s1], save.last_checkpoints) + self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertCheckpointState( + model_checkpoint_path=s1, + all_model_checkpoint_paths=[s1], + save_dir=save_dir) + + s2 = save.save(None, os.path.join(save_dir, "s2")) + self.assertEqual([s1, s2], save.last_checkpoints) + self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertCheckpointState( + model_checkpoint_path=s2, + all_model_checkpoint_paths=[s1, s2], + save_dir=save_dir) + + s3 = save.save(None, os.path.join(save_dir, "s3")) + self.assertEqual([s2, s3], save.last_checkpoints) + self.assertFalse(saver_module.checkpoint_exists(s1)) + self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(saver_module.checkpoint_exists(s3)) + self.assertCheckpointState( + model_checkpoint_path=s3, + all_model_checkpoint_paths=[s2, s3], + save_dir=save_dir) + + # Create a second helper, identical to the first. + save2 = saver_module.Saver({"v": v}, max_to_keep=2) + save2.set_last_checkpoints(save.last_checkpoints) + + # Exercise the first helper. + + # Adding s2 again (old s2 is removed first, then new s2 appended) + s2 = save.save(None, os.path.join(save_dir, "s2")) + self.assertEqual([s3, s2], save.last_checkpoints) + self.assertFalse(saver_module.checkpoint_exists(s1)) + self.assertTrue(saver_module.checkpoint_exists(s3)) + self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertCheckpointState( + model_checkpoint_path=s2, + all_model_checkpoint_paths=[s3, s2], + save_dir=save_dir) + + # Adding s1 (s3 should now be deleted as oldest in list) + s1 = save.save(None, os.path.join(save_dir, "s1")) + self.assertEqual([s2, s1], save.last_checkpoints) + self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertCheckpointState( + model_checkpoint_path=s1, + all_model_checkpoint_paths=[s2, s1], + save_dir=save_dir) + + s2 = save2.save(None, os.path.join(save_dir, "s2")) + self.assertEqual([s3, s2], save2.last_checkpoints) + # Created by the first helper. + self.assertTrue(saver_module.checkpoint_exists(s1)) + # Deleted by the first helper. + self.assertFalse(saver_module.checkpoint_exists(s3)) + def testNonSharded(self): save_dir = self._get_test_dir("max_to_keep_non_sharded") @@ -1301,15 +1393,16 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): gfile.MakeDirs(test_dir) return test_dir + @test_util.run_in_graph_and_eager_modes() @test.mock.patch.object(saver_module, "time") def testNonSharded(self, mock_time): save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") with self.test_session() as sess: - v = variables.Variable([10.0], name="v") + v = variable_scope.variable([10.0], name="v") # Run the initializer NOW to avoid the 0.5s overhead of the first Run() # call, which throws the test timing off in fastbuild mode. - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) # Create a saver that will keep the last 2 checkpoints plus one every 0.7 # seconds. start_time = time.time() @@ -1387,7 +1480,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase): v0 = variable_op(-1.0, name="v0") v1 = variable_op(-1.0, name="v1") - if context.in_graph_mode(): + if not context.executing_eagerly(): with self.assertRaisesOpError("uninitialized"): self.evaluate(v0) with self.assertRaisesOpError("uninitialized"): @@ -1397,7 +1490,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase): save.restore(sess, save_path) # Check that the parameter nodes have been restored. - if context.in_graph_mode(): + if not context.executing_eagerly(): self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) @@ -1407,7 +1500,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase): v0 = variable_op(-1.0, name="restore_prefix/v0") v1 = variable_op(-1.0, name="restore_prefix/v1") - if context.in_graph_mode(): + if not context.executing_eagerly(): with self.assertRaisesOpError("uninitialized"): self.evaluate(v0) with self.assertRaisesOpError("uninitialized"): @@ -2039,6 +2132,113 @@ class MetaGraphTest(test.TestCase): self._testGraphExtensionRestore(test_dir) self._testRestoreFromTrainGraphWithControlContext(test_dir) + def _testGradientSerDes(self, graph_fn): + """Tests that gradients can be computed after exporting and importing. + + Builds a graph, exports it, and verifies that it can be imported and the + gradient can be built and run correctly. + + Args: + graph_fn: takes a single float Tensor argument as input, outputs a single + Tensor + """ + test_dir = self._get_test_dir("nested_control_flow") + filename = os.path.join(test_dir, "metafile") + saver_ckpt = os.path.join(test_dir, "saver.ckpt") + + # Create while loop using `outer_body_fn`. + with ops_lib.Graph().as_default(): + var = variables.Variable(0.0) + var_name = var.name + output = graph_fn(var) + output_name = output.name + init_op = variables.global_variables_initializer() + + # Generate a MetaGraphDef containing the while loop. + with session.Session() as sess: + sess.run(init_op) + sess.run(output) + saver = saver_module.Saver() + saver.save(sess, saver_ckpt) + saver.export_meta_graph(filename) + + # Build and run the gradients of the while loop. We use this below to + # verify that the gradients are correct with an imported MetaGraphDef. + grad = gradients_impl.gradients([output], [var]) + # Turn off constant folding to avoid breaking testNestedControlFlowSerDes. + # It appears that a missing control dependency in the gradient graph + # causes the fetch node to not be triggered. + no_constfold_config = config_pb2.ConfigProto() + no_constfold_config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with session.Session(config=no_constfold_config) as sess: + sess.run(init_op) + expected_grad_value = sess.run(grad) + + # Restore the MetaGraphDef into a new Graph. + with ops_lib.Graph().as_default(): + with session.Session() as sess: + saver = saver_module.import_meta_graph(filename) + saver.restore(sess, saver_ckpt) + + # Make sure we can still build gradients and get the same result. + var = ops_lib.get_default_graph().get_tensor_by_name(var_name) + output = ops_lib.get_default_graph().get_tensor_by_name(output_name) + grad = gradients_impl.gradients([output], [var]) + + init_op = variables.global_variables_initializer() + + with session.Session(config=no_constfold_config) as sess: + sess.run(init_op) + actual_grad_value = sess.run(grad) + self.assertEqual(expected_grad_value, actual_grad_value) + + def _testWhileLoopAndGradientSerDes(self, outer_body_fn): + # Build a while loop with `outer_body_fn`, export it, and verify that it can + # be imported and the gradient can be built and run correctly. + # pylint: disable=g-long-lambda + return self._testGradientSerDes( + lambda x: control_flow_ops.while_loop( + lambda i, y: i < 5, outer_body_fn, [0, x])[1]) + # pylint: enable=g-long-lambda + + def testNestedWhileLoopsSerDes(self): + # Test two simple nested while loops. + def body(i, x): + _, r = control_flow_ops.while_loop(lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0.0]) + return i + 1, x + r + self._testWhileLoopAndGradientSerDes(body) + + def testNestedControlFlowSerDes(self): + # Test while loop in a cond in a while loop. + # pylint: disable=g-long-lambda + def body(i, x): + cond_result = control_flow_ops.cond( + i > 0, + lambda: control_flow_ops.while_loop( + lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0.0])[1], + lambda: x) + return i + 1, cond_result + # pylint: enable=g-long-lambda + self._testWhileLoopAndGradientSerDes(body) + + def testNestedCondsSerDes(self): + # Test conds in a cond. + # pylint: disable=g-long-lambda + self._testGradientSerDes(lambda x: control_flow_ops.cond( + x > 0, + lambda: control_flow_ops.cond(x > 3, + lambda: array_ops.identity(x), + lambda: math_ops.multiply(x, 2.0)), + lambda: control_flow_ops.cond(x < -3, + lambda: constant_op.constant(1.0), + lambda: math_ops.multiply(x, -1.0)))) + # pylint: enable=g-long-lambda + def testStrippedOpListDef(self): with self.test_session(): # Creates a graph. @@ -2660,5 +2860,94 @@ class ScopedGraphTest(test.TestCase): self.assertEqual(2.0, var_dict2["variable2:0"].eval()) +class _OwnsAVariableSimple(checkpointable.CheckpointableBase): + """A Checkpointable object which can be saved using a tf.train.Saver.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + +class _MirroringSaveable( + saver_module.BaseSaverBuilder.ResourceVariableSaveable): + + def __init__(self, primary_variable, mirrored_variable, name): + self._primary_variable = primary_variable + self._mirrored_variable = mirrored_variable + super(_MirroringSaveable, self).__init__( + self._primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return control_flow_ops.group( + self._primary_variable.assign(tensor), + self._mirrored_variable.assign(tensor)) + + +class _OwnsMirroredVariables(checkpointable.CheckpointableBase): + """A Checkpointable object which returns a more complex SaveableObject.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + self.mirrored = variable_scope.get_variable( + name="mirrored", initializer=15., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + def _saveable_factory(name=self.non_dep_variable.name): + return _MirroringSaveable( + primary_variable=self.non_dep_variable, + mirrored_variable=self.mirrored, + name=name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + +@test_util.with_c_api +class CheckpointableCompatibilityTests(test.TestCase): + + # TODO(allenl): Track down python3 reference cycles in these tests. + @test_util.run_in_graph_and_eager_modes() + def testNotSaveableButIsCheckpointable(self): + v = _OwnsAVariableSimple() + saver = saver_module.Saver(var_list=[v]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + with self.test_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + + @test_util.run_in_graph_and_eager_modes() + def testMoreComplexSaveableReturned(self): + v = _OwnsMirroredVariables() + saver = saver_module.Saver(var_list=[v]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + with self.test_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py index 44b06b357ecbe4c8e330a2ccc49e83ddd4bf8c7d..2bbe5b6d845c304c4dc79fb3619c57211ca0489e 100644 --- a/tensorflow/python/training/saver_test_utils.py +++ b/tensorflow/python/training/saver_test_utils.py @@ -35,12 +35,12 @@ class CheckpointedOp(object): # pylint: disable=protected-access def __init__(self, name, table_ref=None): if table_ref is None: - self.table_ref = gen_lookup_ops._mutable_hash_table_v2( + self.table_ref = gen_lookup_ops.mutable_hash_table_v2( key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) else: self.table_ref = table_ref self._name = name - if context.in_graph_mode(): + if not context.executing_eagerly(): self._saveable = CheckpointedOp.CustomSaveable(self, name) ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS, self._saveable) @@ -51,16 +51,16 @@ class CheckpointedOp(object): @property def saveable(self): - if context.in_graph_mode(): - return self._saveable - else: + if context.executing_eagerly(): return CheckpointedOp.CustomSaveable(self, self.name) + else: + return self._saveable def insert(self, keys, values): - return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values) + return gen_lookup_ops.lookup_table_insert_v2(self.table_ref, keys, values) def lookup(self, keys, default): - return gen_lookup_ops._lookup_table_find_v2(self.table_ref, keys, default) + return gen_lookup_ops.lookup_table_find_v2(self.table_ref, keys, default) def keys(self): return self._export()[0] @@ -69,8 +69,8 @@ class CheckpointedOp(object): return self._export()[1] def _export(self): - return gen_lookup_ops._lookup_table_export_v2(self.table_ref, dtypes.string, - dtypes.float32) + return gen_lookup_ops.lookup_table_export_v2(self.table_ref, dtypes.string, + dtypes.float32) class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): """A custom saveable for CheckpointedOp.""" @@ -86,6 +86,6 @@ class CheckpointedOp(object): super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) def restore(self, restore_tensors, shapes): - return gen_lookup_ops._lookup_table_import_v2( + return gen_lookup_ops.lookup_table_import_v2( self.op.table_ref, restore_tensors[0], restore_tensors[1]) # pylint: enable=protected-access diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 75ef3d5976aba9f0cbe849d9f6984646d71a29ef..9ac52dd0715d7ed15e2e57ed286be973614b01e5 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -106,7 +106,10 @@ def create_slot(primary, val, name, colocate_with_primary=True): # and the same name has been previously used, the scope name will add '_N' # as suffix for unique identifications. validate_shape = val.get_shape().is_fully_defined() - prefix = primary.op.name if context.in_graph_mode() else primary._shared_name # pylint: disable=protected-access + if context.executing_eagerly(): + prefix = primary._shared_name # pylint: disable=protected-access + else: + prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: with ops.colocate_with(primary): @@ -139,7 +142,10 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name, # and the same name has been previously used, the scope name will add '_N' # as suffix for unique identifications. validate_shape = shape.is_fully_defined() - prefix = primary.op.name if context.in_graph_mode() else primary._shared_name # pylint: disable=protected-access + if context.executing_eagerly(): + prefix = primary._shared_name # pylint: disable=protected-access + else: + prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: with ops.colocate_with(primary): diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index d2ad34773e0615256c340826dcc312cc8a00dc23..7389e344c7d8eef8e26c4d24c0985ff66276deea 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -45,7 +45,7 @@ class Supervisor(object): """A training helper that checkpoints models and computes summaries. This class is deprecated. Please use - ${tf.train.MonitoredTrainingSession} instead. + @{tf.train.MonitoredTrainingSession} instead. The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, and a `SessionManager` that takes care of common needs of TensorFlow @@ -305,7 +305,7 @@ class Supervisor(object): `Supervisor`s are not supported when eager execution is enabled. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Supervisors are compatible with eager execution.") # Set default values of arguments. if graph is None: @@ -762,7 +762,7 @@ class Supervisor(object): execution is enabled, use the `tf.data` API. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Queues are not compatible with eager execution.") if queue_runners is None: queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS) diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 78c8ce9208efc2f2fa8b5c671d3379e7ca8c70f5..b759b156d78cf8d869b49375058cc7ed42e82b34 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -28,8 +28,10 @@ See the @{$python/train} guide. @@ProximalGradientDescentOptimizer @@ProximalAdagradOptimizer @@RMSPropOptimizer +@@custom_gradient @@gradients @@AggregationMethod +@@GradientTape @@stop_gradient @@hessians @@clip_by_value @@ -94,6 +96,8 @@ See the @{$python/train} guide. @@load_variable @@list_variables @@init_from_checkpoint +@@warm_start +@@VocabInfo """ # Optimizers. @@ -187,6 +191,8 @@ from tensorflow.python.training.training_util import get_global_step from tensorflow.python.training.training_util import assert_global_step from tensorflow.python.training.training_util import create_global_step from tensorflow.python.training.training_util import get_or_create_global_step +from tensorflow.python.training.warm_starting_util import VocabInfo +from tensorflow.python.training.warm_starting_util import warm_start from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef from tensorflow.python.pywrap_tensorflow import NewCheckpointReader from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 499f1feb2dbf8aee26314a43b0a000fb91a1c686..d05e1d2c830b2aa7008c9cba9f28eb6230d8bc82 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io @@ -31,7 +30,6 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export - # Picked a long key value to minimize the chance of collision with user defined # collection keys. GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache' @@ -64,7 +62,7 @@ def global_step(sess, global_step_tensor): Returns: The global step value. """ - if context.in_eager_mode(): + if context.executing_eagerly(): return int(global_step_tensor.numpy()) return int(sess.run(global_step_tensor)) @@ -123,7 +121,7 @@ def create_global_step(graph=None): raise ValueError('"global_step" already exists.') # Create in proper graph and base name_scope. with graph.as_default() as g, g.name_scope(None): - if context.in_eager_mode(): + if context.executing_eagerly(): with ops.device('cpu:0'): return variable_scope.get_variable( ops.GraphKeys.GLOBAL_STEP, @@ -170,8 +168,7 @@ def assert_global_step(global_step_tensor): """ if not (isinstance(global_step_tensor, variables.Variable) or isinstance(global_step_tensor, ops.Tensor) or - isinstance(global_step_tensor, - resource_variable_ops.ResourceVariable)): + resource_variable_ops.is_resource_variable(global_step_tensor)): raise TypeError( 'Existing "global_step" must be a Variable or Tensor: %s.' % global_step_tensor) diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py similarity index 67% rename from tensorflow/python/estimator/warm_starting_util.py rename to tensorflow/python/training/warm_starting_util.py index adb013f5c653c4967a743047fef4e805946e0f59..4d4fb394c1272d2bf510bb594d70b9aa2edb3df2 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -33,7 +33,7 @@ from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export -@tf_export("estimator.VocabInfo") +@tf_export("train.VocabInfo", "estimator.VocabInfo") class VocabInfo( collections.namedtuple("VocabInfo", [ "new_vocab", @@ -43,7 +43,7 @@ class VocabInfo( "old_vocab_size", "backup_initializer", ])): - """Vocabulary information for WarmStartSettings. + """Vocabulary information for warm-starting. See @{tf.estimator.WarmStartSettings$WarmStartSettings} for examples of using VocabInfo to warm-start. @@ -83,164 +83,6 @@ class VocabInfo( ) -@tf_export("estimator.WarmStartSettings") -class WarmStartSettings( - collections.namedtuple("WarmStartSettings", [ - "ckpt_to_initialize_from", - "vars_to_warm_start", - "var_name_to_vocab_info", - "var_name_to_prev_var_name", - ])): - """Settings for warm-starting in Estimators. - - Example Use with canned `DNNEstimator`: - - ``` - emb_vocab_file = tf.feature_column.embedding_column( - tf.feature_column.categorical_column_with_vocabulary_file( - "sc_vocab_file", "new_vocab.txt", vocab_size=100), - dimension=8) - emb_vocab_list = tf.feature_column.embedding_column( - tf.feature_column.categorical_column_with_vocabulary_list( - "sc_vocab_list", vocabulary_list=["a", "b"]), - dimension=8) - estimator = tf.estimator.DNNClassifier( - hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list], - warm_start_from=ws) - ``` - - where `ws` could be defined as: - - Warm-start all weights in the model (input layer and hidden weights). - Either the directory or a specific checkpoint can be provided (in the case - of the former, the latest checkpoint will be used): - - ``` - ws = WarmStartSettings(ckpt_to_initialize_from="/tmp") - ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") - ``` - - Warm-start only the embeddings (input layer): - - ``` - ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", - vars_to_warm_start=".*input_layer.*") - ``` - - Warm-start all weights but the embedding parameters corresponding to - `sc_vocab_file` have a different vocab from the one used in the current - model: - - ``` - vocab_info = ws_util.VocabInfo( - new_vocab=sc_vocab_file.vocabulary_file, - new_vocab_size=sc_vocab_file.vocabulary_size, - num_oov_buckets=sc_vocab_file.num_oov_buckets, - old_vocab="old_vocab.txt" - ) - ws = WarmStartSettings( - ckpt_to_initialize_from="/tmp", - var_name_to_vocab_info={ - "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info - }) - ``` - - Warm-start only `sc_vocab_file` embeddings (and no other variables), which - have a different vocab from the one used in the current model: - - ``` - vocab_info = ws_util.VocabInfo( - new_vocab=sc_vocab_file.vocabulary_file, - new_vocab_size=sc_vocab_file.vocabulary_size, - num_oov_buckets=sc_vocab_file.num_oov_buckets, - old_vocab="old_vocab.txt" - ) - ws = WarmStartSettings( - ckpt_to_initialize_from="/tmp", - vars_to_warm_start=None, - var_name_to_vocab_info={ - "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info - }) - ``` - - Warm-start all weights but the parameters corresponding to `sc_vocab_file` - have a different vocab from the one used in current checkpoint, and only - 100 of those entries were used: - - ``` - vocab_info = ws_util.VocabInfo( - new_vocab=sc_vocab_file.vocabulary_file, - new_vocab_size=sc_vocab_file.vocabulary_size, - num_oov_buckets=sc_vocab_file.num_oov_buckets, - old_vocab="old_vocab.txt", - old_vocab_size=100 - ) - ws = WarmStartSettings( - ckpt_to_initialize_from="/tmp", - var_name_to_vocab_info={ - "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info - }) - ``` - - Warm-start all weights but the parameters corresponding to `sc_vocab_file` - have a different vocab from the one used in current checkpoint and the - parameters corresponding to `sc_vocab_list` have a different name from the - current checkpoint: - - ``` - vocab_info = ws_util.VocabInfo( - new_vocab=sc_vocab_file.vocabulary_file, - new_vocab_size=sc_vocab_file.vocabulary_size, - num_oov_buckets=sc_vocab_file.num_oov_buckets, - old_vocab="old_vocab.txt", - old_vocab_size=100 - ) - ws = WarmStartSettings( - ckpt_to_initialize_from="/tmp", - var_name_to_vocab_info={ - "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info - }, - var_name_to_prev_var_name={ - "input_layer/sc_vocab_list_embedding/embedding_weights": - "old_tensor_name" - }) - ``` - - Attributes: - ckpt_to_initialize_from: [Required] A string specifying the directory with - checkpoint file(s) or path to checkpoint from which to warm-start the - model parameters. - vars_to_warm_start: [Optional] A regular expression that captures which - variables to warm-start (see tf.get_collection). Defaults to `'.*'`, - which warm-starts all variables. If `None` is explicitly given, only - variables specified in `var_name_to_vocab_info` will be warm-started. - var_name_to_vocab_info: [Optional] Dict of variable names (strings) to - VocabInfo. The variable names should be "full" variables, not the names - of the partitions. If not explicitly provided, the variable is assumed to - have no vocabulary. - var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to - name of the previously-trained variable in `ckpt_to_initialize_from`. If - not explicitly provided, the name of the variable is assumed to be same - between previous checkpoint and current model. - """ - - def __new__(cls, - ckpt_to_initialize_from, - vars_to_warm_start=".*", - var_name_to_vocab_info=None, - var_name_to_prev_var_name=None): - if not ckpt_to_initialize_from: - raise ValueError( - "`ckpt_to_initialize_from` MUST be set in WarmStartSettings") - return super(WarmStartSettings, cls).__new__( - cls, - ckpt_to_initialize_from, - vars_to_warm_start, - var_name_to_vocab_info or {}, - var_name_to_prev_var_name or {}, - ) - - def _is_variable(x): return (isinstance(x, variables_lib.Variable) or isinstance(x, resource_variable_ops.ResourceVariable)) @@ -375,8 +217,7 @@ def _warm_start_var_with_vocab(var, full_shape=slice_info.full_shape, var_offset=slice_info.var_offset) - # TODO(eddz): Support WarmStartSettings where class vocabularies need - # remapping too. + # TODO(eddz): Support cases where class vocabularies need remapping too. init = checkpoint_ops._load_and_remap_matrix_initializer( ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, @@ -396,32 +237,53 @@ def _warm_start_var_with_vocab(var, # pylint: enable=protected-access -def _warm_start(warm_start_settings): +@tf_export("train.warm_start") +def warm_start(ckpt_to_initialize_from, + vars_to_warm_start=".*", + var_name_to_vocab_info=None, + var_name_to_prev_var_name=None): """Warm-starts a model using the given settings. If you are using a tf.estimator.Estimator, this will automatically be called during training. Args: - warm_start_settings: An object of `WarmStartSettings`. + ckpt_to_initialize_from: [Required] A string specifying the directory with + checkpoint file(s) or path to checkpoint from which to warm-start the + model parameters. + vars_to_warm_start: [Optional] A regular expression that captures which + variables to warm-start (see tf.get_collection). Defaults to `'.*'`, + which warm-starts all variables. If `None` is explicitly given, only + variables specified in `var_name_to_vocab_info` will be warm-started. + var_name_to_vocab_info: [Optional] Dict of variable names (strings) to + VocabInfo. The variable names should be "full" variables, not the names + of the partitions. If not explicitly provided, the variable is assumed to + have no vocabulary. + var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to + name of the previously-trained variable in `ckpt_to_initialize_from`. If + not explicitly provided, the name of the variable is assumed to be same + between previous checkpoint and current model. Raises: ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo configuration for variable names that are not used. This is to ensure a stronger check for variable configuration than relying on users to examine the logs. """ - logging.info("Warm-starting from: %s", - (warm_start_settings.ckpt_to_initialize_from,)) + if var_name_to_vocab_info is None: + var_name_to_vocab_info = {} + if var_name_to_prev_var_name is None: + var_name_to_prev_var_name = {} + logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,)) # We have to deal with partitioned variables, since get_collection flattens # out the list. grouped_variables = {} - # Both warm_start_settings.vars_to_warm_start = '.*' and - # warm_start_settings.vars_to_warm_start = None will match everything here. + # Both vars_to_warm_start = '.*' and + # vars_to_warm_start = None will match everything here. for v in ops.get_collection( # TODO(eddz): Allow for different collections here (to support # warm-starting accumulators). ops.GraphKeys.TRAINABLE_VARIABLES, - scope=warm_start_settings.vars_to_warm_start): + scope=vars_to_warm_start): if not isinstance(v, list): var_name = _infer_var_name([v]) else: @@ -437,10 +299,10 @@ def _warm_start(warm_start_settings): vocab_info_used = set() for var_name, variable in six.iteritems(grouped_variables): - prev_var_name = warm_start_settings.var_name_to_prev_var_name.get(var_name) + prev_var_name = var_name_to_prev_var_name.get(var_name) if prev_var_name: prev_var_name_used.add(var_name) - vocab_info = warm_start_settings.var_name_to_vocab_info.get(var_name) + vocab_info = var_name_to_vocab_info.get(var_name) if vocab_info: vocab_info_used.add(var_name) logging.info( @@ -460,16 +322,16 @@ def _warm_start(warm_start_settings): variable, current_vocab_path=vocab_info.new_vocab, current_vocab_size=vocab_info.new_vocab_size, - prev_ckpt=warm_start_settings.ckpt_to_initialize_from, + prev_ckpt=ckpt_to_initialize_from, prev_vocab_path=vocab_info.old_vocab, previous_vocab_size=vocab_info.old_vocab_size, current_oov_buckets=vocab_info.num_oov_buckets, prev_tensor_name=prev_var_name, initializer=vocab_info.backup_initializer) else: - # For the special value of warm_start_settings.vars_to_warm_start = None, + # For the special value of vars_to_warm_start = None, # we only warm-start variables with explicitly specified vocabularies. - if warm_start_settings.vars_to_warm_start: + if vars_to_warm_start: logging.info("Warm-starting variable: {}; prev_var_name: {}".format( var_name, prev_var_name or "Unchanged")) # Because we use a default empty list in grouped_variables, single @@ -477,48 +339,22 @@ def _warm_start(warm_start_settings): # for init_from_checkpoint logic to work correctly. if len(variable) == 1: variable = variable[0] - _warm_start_var(variable, warm_start_settings.ckpt_to_initialize_from, - prev_var_name) + _warm_start_var(variable, ckpt_to_initialize_from, prev_var_name) prev_var_name_not_used = set( - warm_start_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used - vocab_info_not_used = set( - warm_start_settings.var_name_to_vocab_info.keys()) - vocab_info_used + var_name_to_prev_var_name.keys()) - prev_var_name_used + vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used if prev_var_name_not_used: raise ValueError( "You provided the following variables in " - "warm_start_settings.var_name_to_prev_var_name that were not used: " + "var_name_to_prev_var_name that were not used: " "{0}. Perhaps you misspelled them? Here is the list of viable " "variable names: {1}".format(prev_var_name_not_used, grouped_variables.keys())) if vocab_info_not_used: raise ValueError( "You provided the following variables in " - "warm_start_settings.var_name_to_vocab_info that were not used: {0}. " + "var_name_to_vocab_info that were not used: {0}. " " Perhaps you misspelled them? Here is the list of viable variable " "names: {1}".format(vocab_info_not_used, grouped_variables.keys())) - - -def _get_default_warm_start_settings(warm_start_from): - """Returns default WarmStartSettings. - - Args: - warm_start_from: Either a string representing the filepath of a checkpoint - to initialize from, or an instance of WarmStartSettings. - - Returns: - Either None or an instance of WarmStartSettings. - - Raises: - ValueError: If warm_start_from is not None but is neither a string nor an - instance of WarmStartSettings. - """ - if warm_start_from is None: - return None - if isinstance(warm_start_from, six.string_types): - return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) - elif isinstance(warm_start_from, WarmStartSettings): - return warm_start_from - else: - raise ValueError("warm_start_from must be a string or a WarmStartSettings") diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py similarity index 94% rename from tensorflow/python/estimator/warm_starting_util_test.py rename to tensorflow/python/training/warm_starting_util_test.py index 3985d9ebd04e6963339fcf9999f6367fe4dadc1a..6e445d8bd14cc13010541c1ab0f737f96a4b1e03 100644 --- a/tensorflow/python/estimator/warm_starting_util_test.py +++ b/tensorflow/python/training/warm_starting_util_test.py @@ -22,7 +22,6 @@ import os import numpy as np import six -from tensorflow.python.estimator import warm_starting_util as ws_util from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,6 +31,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import warm_starting_util as ws_util ones = init_ops.ones_initializer norms = init_ops.truncated_normal_initializer @@ -330,9 +330,7 @@ class WarmStartingUtilTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: cols_to_vars = self._create_linear_model([sc_int], partitioner) - ws_util._warm_start( - ws_util.WarmStartSettings( - self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")) + ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*") sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. self._assert_cols_to_vars(cols_to_vars, {sc_int: [prev_int_val]}, sess) @@ -361,9 +359,8 @@ class WarmStartingUtilTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: cols_to_vars = self._create_linear_model([sc_hash], partitioner) - ws_util._warm_start( - ws_util.WarmStartSettings( - self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")) + ws_util.warm_start( + self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*") sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. self._assert_cols_to_vars(cols_to_vars, {sc_hash: [prev_hash_val]}, @@ -398,9 +395,8 @@ class WarmStartingUtilTest(test.TestCase): cols_to_vars = self._create_linear_model([sc_vocab], partitioner) # Since old vocab is not explicitly set in WarmStartSettings, the old # vocab is assumed to be same as new vocab. - ws_util._warm_start( - ws_util.WarmStartSettings( - self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")) + ws_util.warm_start( + self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*") sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, @@ -435,11 +431,10 @@ class WarmStartingUtilTest(test.TestCase): cols_to_vars = self._create_linear_model([sc_vocab], partitioner) # Since old vocab is not explicitly set in WarmStartSettings, the old # vocab is assumed to be same as new vocab. - ws_util._warm_start( - ws_util.WarmStartSettings( - # Explicitly provide the file prefix instead of just the dir. - os.path.join(self.get_temp_dir(), "model-0"), - vars_to_warm_start=".*sc_vocab.*")) + ws_util.warm_start( + # Explicitly provide the file prefix instead of just the dir. + os.path.join(self.get_temp_dir(), "model-0"), + vars_to_warm_start=".*sc_vocab.*") sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, @@ -485,13 +480,12 @@ class WarmStartingUtilTest(test.TestCase): num_oov_buckets=sc_vocab.num_oov_buckets, old_vocab=old_vocab_path, old_vocab_size=old_vocab_size) - warm_start_settings = ws_util.WarmStartSettings( + ws_util.warm_start( ckpt_to_initialize_from=self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*", var_name_to_vocab_info={ "linear_model/sc_vocab/weights": vocab_info }) - ws_util._warm_start(warm_start_settings) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. 'banana' isn't in the # first two entries of the old vocabulary, so it's newly initialized. @@ -523,9 +517,8 @@ class WarmStartingUtilTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: cols_to_vars = self._create_linear_model([real_bucket], partitioner) - ws_util._warm_start( - ws_util.WarmStartSettings( - self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")) + ws_util.warm_start( + self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*") sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. self._assert_cols_to_vars(cols_to_vars, @@ -606,12 +599,11 @@ class WarmStartingUtilTest(test.TestCase): new_vocab_size=sc_vocab.vocabulary_size, num_oov_buckets=sc_vocab.num_oov_buckets, old_vocab=vocab_path) - ws_util._warm_start( - ws_util.WarmStartSettings( - self.get_temp_dir(), - var_name_to_vocab_info={ - "linear_model/sc_vocab/weights": vocab_info - })) + ws_util.warm_start( + self.get_temp_dir(), + var_name_to_vocab_info={ + "linear_model/sc_vocab/weights": vocab_info + }) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. self._assert_cols_to_vars(cols_to_vars, { @@ -668,7 +660,7 @@ class WarmStartingUtilTest(test.TestCase): new_vocab_size=sc_vocab.vocabulary_size, num_oov_buckets=sc_vocab.num_oov_buckets, old_vocab=prev_vocab_path) - ws_settings = ws_util.WarmStartSettings( + ws_util.warm_start( self.get_temp_dir(), vars_to_warm_start=".*(sc_keys|sc_vocab).*", var_name_to_vocab_info={ @@ -678,7 +670,6 @@ class WarmStartingUtilTest(test.TestCase): ws_util._infer_var_name(cols_to_vars[sc_keys]): "some_other_name" }) - ws_util._warm_start(ws_settings) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. Var corresponding to # sc_hash should not be warm-started. Var corresponding to sc_vocab @@ -732,7 +723,7 @@ class WarmStartingUtilTest(test.TestCase): new_vocab_size=sc_vocab.vocabulary_size, num_oov_buckets=sc_vocab.num_oov_buckets, old_vocab=prev_vocab_path) - ws_settings = ws_util.WarmStartSettings( + ws_util.warm_start( self.get_temp_dir(), vars_to_warm_start=".*(sc_keys|sc_vocab).*", var_name_to_vocab_info={ @@ -742,7 +733,6 @@ class WarmStartingUtilTest(test.TestCase): ws_util._infer_var_name(cols_to_vars[sc_keys]): "some_other_name" }) - ws_util._warm_start(ws_settings) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. Var corresponding to # sc_hash should not be warm-started. Var corresponding to sc_vocab @@ -796,7 +786,7 @@ class WarmStartingUtilTest(test.TestCase): new_vocab_size=sc_vocab.vocabulary_size, num_oov_buckets=sc_vocab.num_oov_buckets, old_vocab=prev_vocab_path) - ws_settings = ws_util.WarmStartSettings( + ws_util.warm_start( self.get_temp_dir(), # The special value of None here will ensure that only the variable # specified in var_name_to_vocab_info (sc_vocab embedding) is @@ -812,7 +802,6 @@ class WarmStartingUtilTest(test.TestCase): ws_util._infer_var_name(cols_to_vars[sc_keys]): "some_other_name" }) - ws_util._warm_start(ws_settings) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. Var corresponding to # sc_vocab should be correctly warm-started after vocab remapping, @@ -874,13 +863,12 @@ class WarmStartingUtilTest(test.TestCase): # use a truncated normal initializer. backup_initializer=init_ops.random_uniform_initializer( minval=0.42, maxval=0.42)) - ws_settings = ws_util.WarmStartSettings( + ws_util.warm_start( self.get_temp_dir(), var_name_to_vocab_info={ ws_util._infer_var_name(cols_to_vars[emb_vocab_column]): vocab_info }) - ws_util._warm_start(ws_settings) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. Var corresponding to # emb_vocab_column should be correctly warm-started after vocab @@ -947,13 +935,12 @@ class WarmStartingUtilTest(test.TestCase): # use a truncated normal initializer. backup_initializer=init_ops.random_uniform_initializer( minval=0.42, maxval=0.42)) - ws_settings = ws_util.WarmStartSettings( + ws_util.warm_start( self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*", var_name_to_vocab_info={ "linear_model/sc_vocab_embedding/embedding_weights": vocab_info }) - ws_util._warm_start(ws_settings) sess.run(variables.global_variables_initializer()) # Verify weights were correctly warm-started. Var corresponding to # emb_vocab should be correctly warm-started after vocab remapping. @@ -973,7 +960,6 @@ class WarmStartingUtilTest(test.TestCase): }, sess) def testErrorConditions(self): - self.assertRaises(ValueError, ws_util.WarmStartSettings, None) x = variable_scope.get_variable( "x", shape=[4, 1], @@ -983,9 +969,6 @@ class WarmStartingUtilTest(test.TestCase): # List of PartitionedVariable is invalid type when warm-starting with vocab. self.assertRaises(TypeError, ws_util._warm_start_var_with_vocab, [x], "/tmp", 5, "/tmp", "/tmp") - # Keys of type other than FeatureColumn. - self.assertRaises(TypeError, ws_util._warm_start, {"StringType": x}, - ws_util.WarmStartSettings("/tmp")) # Unused variable names raises ValueError. with ops.Graph().as_default(): @@ -997,18 +980,16 @@ class WarmStartingUtilTest(test.TestCase): partitioner=lambda shape, dtype: [2, 1]) self._write_checkpoint(sess) - self.assertRaises(ValueError, ws_util._warm_start, - ws_util.WarmStartSettings( - self.get_temp_dir(), - var_name_to_vocab_info={ - "y": ws_util.VocabInfo("", 1, 0, "") - })) - self.assertRaises(ValueError, ws_util._warm_start, - ws_util.WarmStartSettings( - self.get_temp_dir(), - var_name_to_prev_var_name={ - "y": "y2" - })) + self.assertRaises( + ValueError, + ws_util.warm_start, + self.get_temp_dir(), + var_name_to_vocab_info={"y": ws_util.VocabInfo("", 1, 0, "")}) + self.assertRaises( + ValueError, + ws_util.warm_start, + self.get_temp_dir(), + var_name_to_prev_var_name={"y": "y2"}) if __name__ == "__main__": diff --git a/tensorflow/python/user_ops/user_ops.py b/tensorflow/python/user_ops/user_ops.py index 17dbab706c9243c5f119dc82cc4428f03b90a18d..20ea3b0f621dc74bd3778d565f8897e47a881d42 100644 --- a/tensorflow/python/user_ops/user_ops.py +++ b/tensorflow/python/user_ops/user_ops.py @@ -23,8 +23,10 @@ from tensorflow.python.ops import gen_user_ops as _gen_user_ops # go/tf-wildcard-import from tensorflow.python.ops.gen_user_ops import * # pylint: disable=wildcard-import +from tensorflow.python.util.tf_export import tf_export +@tf_export('user_ops.my_fact') def my_fact(): """Example of overriding the generated code for an Op.""" - return _gen_user_ops._fact() # pylint: disable=protected-access + return _gen_user_ops.fact() diff --git a/tensorflow/python/util/decorator_utils.py b/tensorflow/python/util/decorator_utils.py index df259c7f7c29f9a4b674d3e980b33d6dcf323769..7b4363c0e40802779cf47c75c5a5e5a901da37e2 100644 --- a/tensorflow/python/util/decorator_utils.py +++ b/tensorflow/python/util/decorator_utils.py @@ -82,7 +82,7 @@ def add_notice_to_docstring( lines = _normalize_docstring(doc).splitlines() lines[0] += ' ' + suffix_str - notice = [''] + notice + [instructions] + notice = [''] + notice + ([instructions] if instructions else []) if len(lines) > 1: # Make sure that we keep our distance from the main body diff --git a/tensorflow/python/util/port.i b/tensorflow/python/util/port.i index cea4d8468afe8816d71da6581635b8a7ab0c2388..2f730732bee373a6e6ead97fe3320645f37ac220 100644 --- a/tensorflow/python/util/port.i +++ b/tensorflow/python/util/port.i @@ -23,5 +23,6 @@ limitations under the License. %unignore tensorflow; %unignore tensorflow::IsGoogleCudaEnabled; %unignore tensorflow::CudaSupportsHalfMatMulAndConv; +%unignore tensorflow::IsMklEnabled; %include "tensorflow/core/util/port.h" %unignoreall diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index c2fe6fc4494428693605a5a7463a9f590a2da39e..4ab8a72a83b466c38c50b1c76004e7a6fe942a04 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -46,8 +46,10 @@ def getargspec(object): # pylint: disable=redefined-builtin def getfullargspec(obj): # pylint: disable=redefined-builtin - """TFDecorator-aware replacement for inspect.getfullargspec and fallback to - inspect.getargspec in Python 2. + """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`. + + This wrapper uses `inspect.getfullargspec` if available and falls back to + `inspect.getargspec` in Python 2. Args: obj: A callable, possibly decorated. @@ -149,6 +151,11 @@ def getsource(object): # pylint: disable=redefined-builtin return _inspect.getsource(tf_decorator.unwrap(object)[1]) +def isbuiltin(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.isbuiltin.""" + return _inspect.isbuiltin(tf_decorator.unwrap(object)[1]) + + def isclass(object): # pylint: disable=redefined-builtin """TFDecorator-aware replacement for inspect.isclass.""" return _inspect.isclass(tf_decorator.unwrap(object)[1]) diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index 8903e1156b27b3a28543eb5ecfcc6eeb1a04f6ae..129408449ebb45ac3a322f163a13b705cbb31f0c 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -144,6 +144,19 @@ def test_decorated_function_with_defaults(a, b=2, c='Hello'): self.assertEqual( expected, tf_inspect.getsource(test_decorated_function_with_defaults)) + def testIsBuiltin(self): + self.assertEqual( + tf_inspect.isbuiltin(TestDecoratedClass), + inspect.isbuiltin(TestDecoratedClass)) + self.assertEqual( + tf_inspect.isbuiltin(test_decorated_function), + inspect.isbuiltin(test_decorated_function)) + self.assertEqual( + tf_inspect.isbuiltin(test_undecorated_function), + inspect.isbuiltin(test_undecorated_function)) + self.assertEqual(tf_inspect.isbuiltin(range), inspect.isbuiltin(range)) + self.assertEqual(tf_inspect.isbuiltin(max), inspect.isbuiltin(max)) + def testIsClass(self): self.assertTrue(tf_inspect.isclass(TestDecoratedClass)) self.assertFalse(tf_inspect.isclass(test_decorated_function)) diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 37733152e8ec6d7b026bf74e69e33bfe8f9f4e89..28e49afa023904abed076373685bb38f2537b7d4 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -47,7 +47,7 @@ def _add_should_use_warning(x, fatal_error=False): if x is None or x == []: # pylint: disable=g-explicit-bool-comparison return x - if context.in_eager_mode(): + if context.executing_eagerly(): # Typically not needed when executing eagerly (the main use case is for ops # which need to be incorporated into the graph), and even the no-op wrapper # creates reference cycles which require garbage collection. diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index da09d84921e2dd94942b3a62fe7366211c60aed1..31724cf6c9b97e45975b9e053459f7b8f5918dfa 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -79,6 +79,8 @@ string ComputationTypeString(ComputationType ty) { return "f32"; case ComputationType::kF64: return "f64"; + case ComputationType::kI32: + return "i32"; case ComputationType::kComplexF32: return "complex f32"; case ComputationType::kComplexF64: @@ -88,6 +90,10 @@ string ComputationTypeString(ComputationType ty) { } } +std::ostream& operator<<(std::ostream& os, ComputationType ty) { + return os << ComputationTypeString(ty); +} + } // namespace blas } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 072f08554688276a05d9be85718de8750bd874c2..c5f778a5c74519c0f35cea5d59aac3d0d4564c56 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -104,6 +104,8 @@ enum class ComputationType { // Converts a ComputationType to a string. string ComputationTypeString(ComputationType ty); +std::ostream &operator<<(std::ostream &os, ComputationType ty); + // Opaque identifier for an "algorithm" used by a blas routine. This functions // as a hint to the blas library. typedef int64 AlgorithmType; diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 44a3a745ad86dc24f632e4a36691fba06171c9fb..c563f8f931b0a5689268329386d1252f2a45bdd1 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -13,17 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Include cuBLAS headers early, and then set EIGEN_HAS_CUDA_FP16 -// if we have new enough CUDA (which we will only know after including -// cuda.h). This ensures that Eigen's Half.h does not attempt to make its own -// __half typedef if CUDA has already defined one (and conversely, that we do -// not include after Half.h has made its typedef). -#include "cuda/include/cuda.h" #include "cuda/include/cublas_v2.h" - -#if CUDA_VERSION >= 7050 -#define EIGEN_HAS_CUDA_FP16 -#endif +#include "cuda/include/cuda.h" #if CUDA_VERSION >= 8000 #define SE_CUDA_DATA_HALF CUDA_R_16F @@ -33,6 +24,34 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_blas.h" +// Both Eigen Half.h and CUDA cuda_fp16.h provide similar typedef for __half. As +// such, there are two ways to get the typedef for __half: +// +// (1) Includes cuda_fp16.h and defines EIGEN_HAS_CUDA_FP16. +// (2) Neither includes cuda_fp16.h nor defines EIGEN_HAS_CUDA_FP16. +// +// Due to issue b/73793421, when the first approach is used and NVCC is used to +// compile this file, NVCC will complain duplicated definition for +// EIGEN_HAS_CUDA_FP16. On the other hand, when the second approach is used and +// clang is used to compile this file, clang will not understand __half +// due to missing the definition and macro EIGEN_HAS_CUDA_FP16. +// +// Because this file may be compiled with CLANG but will never be compiled with +// NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we have +// to use the second approach because the data member in the __half defined +// by CUDA > 9.0 is `__x` while Eigen expects it to be `x`. +// +// TODO(b/73793421): Remove the following code block to switch to the second +// approach when the issue is fixed. +#if CUDA_VERSION < 9000 +#include "cuda/include/cuda_fp16.h" +#if CUDA_VERSION >= 7050 +#define EIGEN_HAS_CUDA_FP16 +#endif +#endif + +#include "third_party/eigen3/Eigen/Core" + #include #include @@ -2256,6 +2275,14 @@ bool CUDABlas::DoBlasGemmWithAlgorithm( DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + if (computation_type == blas::ComputationType::kF32) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, static_cast(alpha), a, lda, b, + ldb, static_cast(beta), c, ldc, computation_type, algorithm, + output_profile_result); + } + + CHECK_EQ(computation_type, blas::ComputationType::kF16); return DoBlasGemmWithAlgorithmImpl( stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, algorithm, output_profile_result); diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 61cf4ba7eac1f9482e3c1b179f35434a2a65d955..ab5e6590e0fcdb2f19a0a3a85e64e6b144a97363 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -274,7 +274,8 @@ CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format off #if CUDNN_VERSION >= 7000 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ - __macro(cudnnSetConvolutionMathType) + __macro(cudnnSetConvolutionMathType) \ + __macro(cudnnSetRNNMatrixMathType) // clang-format on CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) @@ -586,6 +587,19 @@ static bool TensorOpMathEnabled() { return is_enabled; } +// A helper function to decide whether to enable the TENSOR_OP_MATH math type +// for RNNs. +static bool RnnTensorOpMathEnabled() { + static bool is_enabled = [] { + bool is_disabled = false; + TF_CHECK_OK( + tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_RNN_TENSOR_OP_MATH", + /*default_val=*/false, &is_disabled)); + return !is_disabled; + }(); + return is_enabled; +} + // A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT // in batchnorm. This mode can be faster in some tasks because an optimized path // may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute @@ -1124,6 +1138,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { SetFailure(cudnn_params_desc_->Status()); return; } + if (data_type == CUDNN_DATA_HALF) { + set_use_tensor_op_math(true); + } } ~CudnnRnnDescriptor() override { if (rnn_desc_) { @@ -1132,6 +1149,20 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor"); } } + void set_use_tensor_op_math(bool use_tensor_op_math) { +#if CUDNN_VERSION >= 7000 + cudnnMathType_t math_type = + (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); + if (RnnTensorOpMathEnabled()) { + cudnnStatus_t status = + wrap::cudnnSetRNNMatrixMathType(parent_, rnn_desc_, math_type); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "could not set cudnn RNN math type: " + << ToString(status); + } + } +#endif + } cudnnRNNDescriptor_t handle() const { if (!ok()) return nullptr; return rnn_desc_; @@ -2281,7 +2312,6 @@ struct ConvDoFP32ComputationFP16Input { // A group of helper functions to return the internal compute type for // convolutions in cudnn. -// TODO(yangzihao): Add support for float64. template cudnnDataType_t GetConvComputeType() { return CUDNN_DATA_FLOAT; @@ -2296,6 +2326,11 @@ cudnnDataType_t GetConvComputeType() { } } +template <> +cudnnDataType_t GetConvComputeType() { + return CUDNN_DATA_DOUBLE; +} + } // namespace template @@ -2324,9 +2359,15 @@ bool CudnnSupport::DoConvolveImpl( LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); } // Alpha is the scaling factor for input. - float alpha = 1.0; + float falpha = 1.0; + double dalpha = 1.0; + void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast(&dalpha) + : static_cast(&falpha); // Beta is the scaling factor for output. - float beta = 0.0; + float fbeta = 0.0; + double dbeta = 0.0; + void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast(&dbeta) + : static_cast(&fbeta); const bool is_profiling = output_profile_result != nullptr; cudnnConvolutionFwdAlgo_t algo; @@ -2464,11 +2505,11 @@ bool CudnnSupport::DoConvolveImpl( } status = wrap::cudnnConvolutionForward( parent_, ToHandle(dnn_handle_), - /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(), + /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), /*algo=*/algo, /*workSpace=*/scratch.opaque(), - /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta, + /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); if (is_profiling) { @@ -2943,10 +2984,14 @@ bool CudnnSupport::DoConvolve( const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, - const BatchDescriptor& output_descriptor, - DeviceMemory* output_data) { - LOG(ERROR) << "double-based DNN not yet implemented"; - return false; + const BatchDescriptor& output_descriptor, DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result); } bool CudnnSupport::DoConvolve( @@ -3112,12 +3157,18 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { mutex_lock lock{dnn_handle_mutex_}; + cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_), + AsCUDAStreamValue(stream)); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); + } + float beta = 0.0f; ScopedTensorDescriptor input_tensor_desc( parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout())); ScopedTensorDescriptor output_tensor_desc( parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout())); - cudnnStatus_t status = wrap::cudnnTransformTensor( + status = wrap::cudnnTransformTensor( parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(), input_data.opaque(), &beta, output_tensor_desc.handle(), output_data->opaque()); @@ -3151,10 +3202,17 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); } + cudnnDataType_t cudnn_type = GetCudnnDataType(); // Alpha is the scaling factor for input. - float alpha = 1.0; + float falpha = 1.0; + double dalpha = 1.0; + void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast(&dalpha) + : static_cast(&falpha); // Beta is the scaling factor for output. - float beta = 0.0; + float fbeta = 0.0; + double dbeta = 0.0; + void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast(&dbeta) + : static_cast(&fbeta); // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass. BatchDescriptor output_descriptor; @@ -3163,7 +3221,6 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( backward_output_data = MaybeTransformLayout( stream, &output_descriptor, backward_output_data, &transform_scratch); - cudnnDataType_t cudnn_type = GetCudnnDataType(); ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type}; ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, cudnn_type}; ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor, @@ -3310,7 +3367,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( status = wrap::cudnnConvolutionBackwardData_v3( #endif parent_, ToHandle(dnn_handle_), - /*alpha=*/&alpha, + /*alpha=*/alpha, /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), /*diffDesc=*/out_back_nd.handle(), @@ -3319,7 +3376,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( /*algo=*/algo, /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), - /*beta=*/&beta, + /*beta=*/beta, /*gradDesc=*/in_back_nd.handle(), /*gradData=*/backward_input_data->opaque()); if (is_profiling) { @@ -3344,10 +3401,28 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( return true; } +bool CudnnSupport::DoConvolveBackwardData( + Stream* stream, const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const BatchDescriptor& output_descriptor, + DeviceMemory backward_output_data, + const ConvolutionDescriptor& convolution_descriptor, + const BatchDescriptor& input_descriptor, + DeviceMemory* backward_input_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result); +} + bool CudnnSupport::DoConvolveBackwardData( Stream* stream, const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, - const BatchDescriptor& output_descriptor_in, + const BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const ConvolutionDescriptor& convolution_descriptor, const BatchDescriptor& input_descriptor, @@ -3356,7 +3431,7 @@ bool CudnnSupport::DoConvolveBackwardData( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor_in, backward_output_data, + output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, algorithm_config, output_profile_result); @@ -3365,7 +3440,7 @@ bool CudnnSupport::DoConvolveBackwardData( bool CudnnSupport::DoConvolveBackwardData( Stream* stream, const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, - const BatchDescriptor& output_descriptor_in, + const BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const ConvolutionDescriptor& convolution_descriptor, const BatchDescriptor& input_descriptor, @@ -3374,7 +3449,7 @@ bool CudnnSupport::DoConvolveBackwardData( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor_in, backward_output_data, + output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, algorithm_config, output_profile_result); @@ -3398,10 +3473,17 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); } + cudnnDataType_t cudnn_type = GetCudnnDataType(); // Alpha is the scaling factor for input. - float alpha = 1.0; + float falpha = 1.0; + double dalpha = 1.0; + void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast(&dalpha) + : static_cast(&falpha); // Beta is the scaling factor for output. - float beta = 0.0; + float fbeta = 0.0; + double dbeta = 0.0; + void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast(&dbeta) + : static_cast(&fbeta); // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass. BatchDescriptor output_descriptor; @@ -3410,7 +3492,6 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( backward_output_data = MaybeTransformLayout( stream, &output_descriptor, backward_output_data, &transform_scratch); - cudnnDataType_t cudnn_type = GetCudnnDataType(); ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type}; ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type}; ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor, @@ -3557,7 +3638,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( #else status = wrap::cudnnConvolutionBackwardFilter_v3( #endif - parent_, ToHandle(dnn_handle_), /*alpha=*/&alpha, + parent_, ToHandle(dnn_handle_), /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*diffDesc=*/out_back_nd.handle(), @@ -3566,7 +3647,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( /*algo=*/algo, /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), - /*beta=*/&beta, + /*beta=*/beta, /*gradDesc=*/filter.handle(), /*gradData=*/backward_filter_data->opaque()); @@ -3592,10 +3673,28 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( return true; } +bool CudnnSupport::DoConvolveBackwardFilter( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemory* backward_filter_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) { + return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result); +} + bool CudnnSupport::DoConvolveBackwardFilter( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, - const dnn::BatchDescriptor& output_descriptor_in, + const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -3603,17 +3702,17 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl( - stream, input_descriptor, input_data, output_descriptor_in, - backward_output_data, convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, algorithm_config, - output_profile_result); + return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result); } bool CudnnSupport::DoConvolveBackwardFilter( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, - const dnn::BatchDescriptor& output_descriptor_in, + const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -3621,11 +3720,11 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl( - stream, input_descriptor, input_data, output_descriptor_in, - backward_output_data, convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, algorithm_config, - output_profile_result); + return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result); } template diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 40aa974dd967df50075da6f2bb34439cd238a113..48d56f71e3195a897b6216ab9f5709326d1b86d3 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -259,7 +259,10 @@ class CudnnSupport : public dnn::DnnSupport { const DeviceMemory& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory* output_data) override; + DeviceMemory* output_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, @@ -371,6 +374,18 @@ class CudnnSupport : public dnn::DnnSupport { return false; } + bool DoConvolveBackwardData( + Stream* stream, const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const dnn::BatchDescriptor& input_descriptor, + DeviceMemory* backward_input_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; + bool DoConvolveBackwardData( Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, @@ -395,6 +410,18 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) override; + bool DoConvolveBackwardFilter( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemory* backward_filter_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) override; + bool DoConvolveBackwardFilter( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index a017ff64d4c69b6952b442464877dc26a800ad37..58e1e58c593a3d938d97baff2356bce2c215a7a1 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -1503,6 +1503,19 @@ static port::StatusOr GetSimpleAttribute(CUdevice device, return true; } +/* static */ port::StatusOr CUDADriver::GetDeviceAttribute( + CUdevice_attribute attribute, CUdevice device) { + int val; + CUresult res = cuDeviceGetAttribute(&val, attribute, device); + if (res != CUDA_SUCCESS) { + return port::Status{ + port::error::INTERNAL, + port::Printf("failed to get device attribute %d for device %d: %s", + attribute, device, ToString(res).c_str())}; + } + return val; +} + /* static */ bool CUDADriver::IsEccEnabled(CUdevice device, bool *result) { int value = -1; CUresult res = diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h index 4002ba2021d1a2e2c36bd1786a3084ee8c08bb78..fa9172b3f008d3083309126bbfa4a1ab961030e1 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.h +++ b/tensorflow/stream_executor/cuda/cuda_driver.h @@ -400,12 +400,20 @@ class CUDADriver { // Returns a grab-bag of device properties in a caller-owned device_properties // structure for device_ordinal via cuDeviceGetProperties. - // This call is deprecated in the NVIDIA driver API. + // + // This call is deprecated in the NVIDIA driver API; its replacement is + // GetDeviceAttribute // // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1g65a5b4e25186bd257df80b98c98cffe6 static bool GetDeviceProperties(CUdevprop *device_properties, int device_ordinal); + // Gets a specific integer-valued property about the given device. + // + // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 + static port::StatusOr GetDeviceAttribute(CUdevice_attribute attribute, + CUdevice device); + // Returns whether ECC is enabled for the given CUdevice via // cuDeviceGetattribute with CU_DEVICE_ATTRIBUTE_ECC_ENABLED. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 4bbd531e14f18fc24d87b4fa655fe72e9f56b129..5ecaf46b8cae3c1e1f312816e7e5aec8ff8ce306 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -1103,6 +1103,18 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { builder.set_device_memory_size(device_memory_size); } + port::StatusOr mem_clock_khz = CUDADriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device_ordinal_); + port::StatusOr mem_bus_width_bits = CUDADriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device_ordinal_); + if (mem_clock_khz.ok() && mem_bus_width_bits.ok()) { + // Times 2 because HBM is DDR memory; it gets two data bits per each data + // lane. + builder.set_memory_bandwidth(2 * int64_t{mem_clock_khz.ValueOrDie()} * + 1000 * + int64_t{mem_bus_width_bits.ValueOrDie()} / 8); + } + { BlockDim block_dim_limit; FillBlockDimLimit(&block_dim_limit); diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc index a98143e34bbb42c3aee76c27e1648c49397a0e44..52f5319a3b16c771ce89843a963841b25df5467e 100644 --- a/tensorflow/stream_executor/device_description.cc +++ b/tensorflow/stream_executor/device_description.cc @@ -50,6 +50,7 @@ DeviceDescription::DeviceDescription() shared_memory_alloc_granularity_(1), device_address_bits_(kUninitializedUint64), device_memory_size_(kUninitializedUint64), + memory_bandwidth_(kUninitializedUint64), shared_memory_per_core_(kUninitializedUint64), shared_memory_per_block_(kUninitializedUint64), clock_rate_ghz_(-1.0), @@ -85,6 +86,8 @@ std::unique_ptr> DeviceDescription::ToMap() const { result["Device Address Bits"] = port::StrCat(device_address_bits()); result["Device Memory Size"] = port::HumanReadableNumBytes::ToString(device_memory_size()); + result["Memory Bandwidth"] = port::StrCat( + port::HumanReadableNumBytes::ToString(memory_bandwidth_), "/s"); result["Shared Memory Per Core"] = port::HumanReadableNumBytes::ToString(shared_memory_per_core_); diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h index f2b35bcb4345a37f72541979564cbbb7944595c2..fcf0928096ed1f1bdf0499efb92af2bc9cb0eaa2 100644 --- a/tensorflow/stream_executor/device_description.h +++ b/tensorflow/stream_executor/device_description.h @@ -140,6 +140,11 @@ class DeviceDescription { // Returns the device memory size in bytes. uint64 device_memory_size() const { return device_memory_size_; } + // Returns the device's memory bandwidth in bytes/sec. (This is for + // reads/writes to/from the device's own memory, not for transfers between the + // host and device.) + uint64 memory_bandwidth() const { return memory_bandwidth_; } + // Returns the device's core clock rate in GHz. float clock_rate_ghz() const { return clock_rate_ghz_; } @@ -212,6 +217,7 @@ class DeviceDescription { uint64 device_address_bits_; uint64 device_memory_size_; + uint64 memory_bandwidth_; // Shared memory limits on a given device. uint64 shared_memory_per_core_; @@ -305,6 +311,9 @@ class DeviceDescriptionBuilder { void set_device_memory_size(uint64 value) { device_description_->device_memory_size_ = value; } + void set_memory_bandwidth(uint64 value) { + device_description_->memory_bandwidth_ = value; + } void set_shared_memory_per_core(int64 value) { device_description_->shared_memory_per_core_ = value; diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index aa88fe770f3596e5da5e12705c3b706365382134..b41536e638873412a31a0cdbbd3ba3a818dd9cf2 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -1172,7 +1172,9 @@ class DnnSupport { const DeviceMemory& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory* output_data) = 0; + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result) = 0; // Enqueues a half-precision convolution operation onto the stream. // See DoConvolve above for argument details. @@ -1273,6 +1275,18 @@ class DnnSupport { bool with_winograd_nonfused, int cc_major, int cc_minor, std::vector* out_algorithms); + virtual bool DoConvolveBackwardData( + Stream* stream, const FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const BatchDescriptor& output_descriptor, + DeviceMemory backward_output_data, + const ConvolutionDescriptor& convolution_descriptor, + const BatchDescriptor& input_descriptor, + DeviceMemory* backward_input_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) = 0; + virtual bool DoConvolveBackwardData( Stream* stream, const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, @@ -1322,6 +1336,18 @@ class DnnSupport { bool with_winograd_nonfused, int cc_major, int cc_minor, std::vector* out_algorithms); + virtual bool DoConvolveBackwardFilter( + Stream* stream, const BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const BatchDescriptor& output_descriptor, + DeviceMemory backward_output_data, + const ConvolutionDescriptor& convolution_descriptor, + const FilterDescriptor& filter_descriptor, + DeviceMemory* backward_filter_data, + ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + ProfileResult* output_profile_result) = 0; + virtual bool DoConvolveBackwardFilter( Stream* stream, const BatchDescriptor& input_descriptor, const DeviceMemory& input_data, diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index f23224ae772b9c5915426feaef1155fc9711f075..f9f3737a06dad3f146ef9fc8e2ec50160b3a01b5 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -23,11 +23,37 @@ limitations under the License. namespace perftools { namespace gputools { +/* static */ mutex MultiPlatformManager::platforms_mutex_{LINKER_INITIALIZED}; + +/* static */ port::StatusOr MultiPlatformManager::LookupByNameLocked( + const string& target) { + PlatformMap* platform_map = GetPlatformMap(); + auto it = platform_map->find(port::Lowercase(target)); + if (it == platform_map->end()) { + return port::Status( + port::error::NOT_FOUND, + "could not find registered platform with name: \"" + target + "\""); + } + return it->second; +} + +/* static */ port::StatusOr MultiPlatformManager::LookupByIdLocked( + const Platform::Id& id) { + PlatformIdMap* platform_map = GetPlatformByIdMap(); + auto it = platform_map->find(id); + if (it == platform_map->end()) { + return port::Status( + port::error::NOT_FOUND, + port::Printf("could not find registered platform with id: 0x%p", id)); + } + return it->second; +} + /* static */ port::Status MultiPlatformManager::RegisterPlatform( std::unique_ptr platform) { CHECK(platform != nullptr); string key = port::Lowercase(platform->Name()); - mutex_lock lock(GetPlatformsMutex()); + mutex_lock lock(platforms_mutex_); if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) { return port::Status(port::error::INTERNAL, "platform is already registered with name: \"" + @@ -45,33 +71,63 @@ namespace gputools { /* static */ port::StatusOr MultiPlatformManager::PlatformWithName( const string& target) { - tf_shared_lock lock(GetPlatformsMutex()); - auto it = GetPlatformMap()->find(port::Lowercase(target)); + mutex_lock lock(platforms_mutex_); - if (it == GetPlatformMap()->end()) { - return port::Status( - port::error::NOT_FOUND, - "could not find registered platform with name: \"" + target + "\""); + SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); + if (!platform->Initialized()) { + SE_RETURN_IF_ERROR(platform->Initialize({})); } - return it->second; + return platform; } /* static */ port::StatusOr MultiPlatformManager::PlatformWithId( const Platform::Id& id) { - tf_shared_lock lock(GetPlatformsMutex()); - auto it = GetPlatformByIdMap()->find(id); - if (it == GetPlatformByIdMap()->end()) { + mutex_lock lock(platforms_mutex_); + + SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); + if (!platform->Initialized()) { + SE_RETURN_IF_ERROR(platform->Initialize({})); + } + + return platform; +} + +/* static */ port::StatusOr +MultiPlatformManager::InitializePlatformWithName( + const string& target, const std::map& options) { + mutex_lock lock(platforms_mutex_); + + SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); + if (platform->Initialized()) { + return port::Status(port::error::FAILED_PRECONDITION, + "platform \"" + target + "\" is already initialized"); + } + + SE_RETURN_IF_ERROR(platform->Initialize(options)); + + return platform; +} + +/* static */ port::StatusOr +MultiPlatformManager::InitializePlatformWithId( + const Platform::Id& id, const std::map& options) { + mutex_lock lock(platforms_mutex_); + + SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); + if (platform->Initialized()) { return port::Status( - port::error::NOT_FOUND, - port::Printf("could not find registered platform with id: 0x%p", id)); + port::error::FAILED_PRECONDITION, + port::Printf("platform with id 0x%p is already initialized", id)); } - return it->second; + SE_RETURN_IF_ERROR(platform->Initialize(options)); + + return platform; } /* static */ void MultiPlatformManager::ClearPlatformRegistry() { - mutex_lock lock(GetPlatformsMutex()); + mutex_lock lock(platforms_mutex_); GetPlatformMap()->clear(); GetPlatformByIdMap()->clear(); } diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h index ea6155b4826439b98262530e70e6463eb1fda237..438653ee20bdb1fd83cd9e75c4bcd35af277cc28 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -67,13 +67,13 @@ limitations under the License. #include #include #include -#include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/platform/thread_annotations.h" namespace perftools { namespace gputools { @@ -85,26 +85,43 @@ class MultiPlatformManager { // already registered. The associated listener, if not null, will be used to // trace events for ALL executors for that platform. // Takes ownership of listener. - static port::Status RegisterPlatform(std::unique_ptr platform); + static port::Status RegisterPlatform(std::unique_ptr platform) + LOCKS_EXCLUDED(platforms_mutex_); - // Retrieves the platform registered with the given platform name; e.g. - // "CUDA", "OpenCL", ... + // Retrieves the platform registered with the given platform name (e.g. + // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the + // Platform's Id() method). + // + // If the platform has not already been initialized, it will be initialized + // with a default set of parameters. // // If the requested platform is not registered, an error status is returned. // Ownership of the platform is NOT transferred to the caller -- // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static port::StatusOr PlatformWithName(const string& target); - - // Retrieves the platform registered with the given platform ID, which - // is an opaque (but comparable) value. + static port::StatusOr PlatformWithName(const string& target) + LOCKS_EXCLUDED(platforms_mutex_); + static port::StatusOr PlatformWithId(const Platform::Id& id) + LOCKS_EXCLUDED(platforms_mutex_); + + // Retrieves the platform registered with the given platform name (e.g. + // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the + // Platform's Id() method). + // + // The platform will be initialized with the given options. If the platform + // was already initialized, an error will be returned. // // If the requested platform is not registered, an error status is returned. // Ownership of the platform is NOT transferred to the caller -- // the MultiPlatformManager owns the platforms in a singleton-like fashion. - static port::StatusOr PlatformWithId(const Platform::Id& id); + static port::StatusOr InitializePlatformWithName( + const string& target, const std::map& options) + LOCKS_EXCLUDED(platforms_mutex_); + static port::StatusOr InitializePlatformWithId( + const Platform::Id& id, const std::map& options) + LOCKS_EXCLUDED(platforms_mutex_); // Clears the set of registered platforms, primarily used for testing. - static void ClearPlatformRegistry(); + static void ClearPlatformRegistry() LOCKS_EXCLUDED(platforms_mutex_); // Although the MultiPlatformManager "owns" its platforms, it holds them as // undecorated pointers to prevent races during program exit (between this @@ -122,17 +139,16 @@ class MultiPlatformManager { // Provides access to the available set of platforms under a lock. static port::Status WithPlatforms( - std::function callback) { - mutex_lock lock(GetPlatformsMutex()); + std::function callback) + LOCKS_EXCLUDED(platforms_mutex_) { + mutex_lock lock(platforms_mutex_); return callback(GetPlatformMap()); } private: - // mutex that guards the platform map. - static mutex& GetPlatformsMutex() { - static mutex* platforms_mutex = new mutex; - return *platforms_mutex; - } + using PlatformIdMap = std::map; + + static mutex platforms_mutex_; // TODO(b/22689637): Clean up these two maps; make sure they coexist nicely. // TODO(b/22689637): Move this (whatever the final/"official" map is) to @@ -147,12 +163,21 @@ class MultiPlatformManager { // Holds a Platform::Id-to-object mapping. // Unlike platforms_ above, this map does not own its contents. - static std::map* GetPlatformByIdMap() { - using PlatformIdMap = std::map; + static PlatformIdMap* GetPlatformByIdMap() { static PlatformIdMap* instance = new PlatformIdMap; return instance; } + // Looks up the platform object with the given name. Assumes the Platforms + // mutex is held. + static port::StatusOr LookupByNameLocked(const string& target) + EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_); + + // Looks up the platform object with the given id. Assumes the Platforms + // mutex is held. + static port::StatusOr LookupByIdLocked(const Platform::Id& id) + EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_); + SE_DISALLOW_COPY_AND_ASSIGN(MultiPlatformManager); }; diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc index 93f08d06dae862f24b5b533395af63139f344f77..4cdc22bd16a3ea66037696f6a9d70bcb86ef5ebb 100644 --- a/tensorflow/stream_executor/platform.cc +++ b/tensorflow/stream_executor/platform.cc @@ -85,6 +85,17 @@ StreamExecutorConfig::StreamExecutorConfig(int ordinal_in) Platform::~Platform() {} +bool Platform::Initialized() const { return true; } + +port::Status Platform::Initialize( + const std::map &platform_options) { + if (!platform_options.empty()) { + return port::Status(port::error::UNIMPLEMENTED, + "this platform does not support custom initialization"); + } + return port::Status::OK(); +} + port::Status Platform::ForceExecutorShutdown() { return port::Status(port::error::UNIMPLEMENTED, "executor shutdown is not supported on this platform"); diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h index f0a0e60e02f951018b39ef831cd2f7dd3256f87d..54f8aa86c269ff0d32648e1d4629179cafd5be76 100644 --- a/tensorflow/stream_executor/platform.h +++ b/tensorflow/stream_executor/platform.h @@ -111,6 +111,9 @@ class Platform { // Returns a key uniquely identifying this platform. virtual Id id() const = 0; + // Name of this platform. + virtual const string& Name() const = 0; + // Returns the number of devices accessible on this platform. // // Note that, though these devices are visible, if there is only one userspace @@ -118,8 +121,17 @@ class Platform { // device, a call to ExecutorForDevice may return an error status. virtual int VisibleDeviceCount() const = 0; - // Name of this platform. - virtual const string& Name() const = 0; + // Returns true iff the platform has been initialized. + virtual bool Initialized() const; + + // Initializes the platform with a custom set of options. The platform must be + // initialized before obtaining StreamExecutor objects. The interpretation of + // the platform_options argument is implementation specific. This method may + // return an error if unrecognized options are provided. If using + // MultiPlatformManager, this method will be called automatically by + // InitializePlatformWithId/InitializePlatformWithName. + virtual port::Status Initialize( + const std::map& platform_options); // Returns a device with the given ordinal on this platform with a default // plugin configuration or, if none can be found with the given ordinal or @@ -156,6 +168,8 @@ class Platform { // This is only useful on platforms which bind a device to a single process // that has obtained the device context. May return UNIMPLEMENTED on platforms // that have no reason to destroy device contexts. + // + // The platform must be reinitialized after this is called. virtual port::Status ForceExecutorShutdown(); // Registers a TraceListener to listen to all StreamExecutors for this diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index ba5001e273632c893b05eea64542f1b156e28c47..1e3afde2687657e417e9e2cb3f5e2aaf0600da7a 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/host_buffer.h" #include "tensorflow/stream_executor/lib/stacktrace.h" @@ -117,7 +118,9 @@ string ToVlogString(const DeviceMemoryBase *memory) { return ToVlogString(*memory); } -string ToVlogString(const Eigen::half &h) { return port::StrCat(h); } +string ToVlogString(const Eigen::half &h) { + return port::StrCat(static_cast(h)); +} string ToVlogString(int i) { return port::StrCat(i); } @@ -681,6 +684,37 @@ Stream &Stream::ThenFusedConvolveWithAlgorithm( return *this; } +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(output_descriptor), + PARAM(output), PARAM(algorithm_config)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output, scratch_allocator, + algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + Stream &Stream::ThenConvolveWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -890,6 +924,39 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( return *this; } +Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &input_descriptor, + DeviceMemory *backward_input_data, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), + PARAM(output_descriptor), PARAM(backward_output_data), + PARAM(convolution_descriptor), PARAM(input_descriptor), + PARAM(backward_input_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolveBackwardData( + this, filter_descriptor, filter_data, output_descriptor, + backward_output_data, convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, algorithm_config, + output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory &filter_data, @@ -1026,6 +1093,39 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( return *this; } +Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + DeviceMemory *backward_filter_data, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(output_descriptor), PARAM(backward_output_data), + PARAM(convolution_descriptor), PARAM(filter_descriptor), + PARAM(backward_filter_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolveBackwardFilter( + this, input_descriptor, input_data, output_descriptor, + backward_output_data, convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, algorithm_config, + output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -4923,12 +5023,6 @@ Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, return *this; } -Stream &Stream::ThenDoHostCallbackForTest(std::function callback) { - VLOG_CALL(PARAM(callback)); - - return ThenDoHostCallback(callback); -} - Stream &Stream::ThenDoHostCallback(std::function callback) { VLOG_CALL(PARAM(callback)); diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index a2fb2ea2375d0f245ae3bf3ccb04803d01663def..d7d11315699b85cae4d479b79bc8fc2717b2d8fb 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -358,6 +358,17 @@ class Stream { const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + Stream &ThenConvolveWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -476,6 +487,18 @@ class Stream { DeviceMemory *backward_input_data, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveBackwardDataWithAlgorithm( + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory &filter_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &input_descriptor, + DeviceMemory *backward_input_data, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + Stream &ThenConvolveBackwardDataWithAlgorithm( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory &filter_data, @@ -529,6 +552,18 @@ class Stream { DeviceMemory *backward_filter_data, ScratchAllocator *scratch_allocator); + Stream &ThenConvolveBackwardFilterWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory &input_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + DeviceMemory *backward_filter_data, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result); + Stream &ThenConvolveBackwardFilterWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory &input_data, @@ -1933,16 +1968,15 @@ class Stream { // Entrains onto the stream a callback to the host (from the device). // Host callbacks block/occupy the stream just as device functions // (execute one at a time, block later stream operations). + // // Behavior is undefined when synchronizing using OpenCL user events. // Behavior is undefined if host callbacks call device routines or insert // them into any stream. + // // On certain platforms, ThenDoHostCallback is expected to have significant // negative effects on performance. Stream &ThenDoHostCallback(std::function callback); - // Identical to ThenDoHostCallback; only exposed for testing purposes. - Stream &ThenDoHostCallbackForTest(std::function callback); - // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { CHECK(parent_ != nullptr); diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 818d67f7b5be1e8f2db66b24976a529b361a4990..fcc57d506e38205d8da605653ed67fb645102c35 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -22,6 +22,7 @@ load( load( "//third_party/mkl:build_defs.bzl", "if_mkl", + "if_mkl_lnx_x64" ) def register_extension_info(**kwargs): @@ -34,7 +35,7 @@ def src_to_test_name(src): return src.replace("/", "_").split(".")[0] def full_path(relative_paths): - return [PACKAGE_NAME + "/" + relative for relative in relative_paths] + return [native.package_name() + "/" + relative for relative in relative_paths] # List of proto files for android builds def tf_android_core_proto_sources(core_proto_sources_relative): @@ -202,7 +203,8 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): "-ftemplate-depth=900"]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) - + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",]) + + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) + + if_mkl_lnx_x64(["-fopenmp"]) + if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + if_ios_x86_64(["-msse4.1"]) @@ -265,7 +267,7 @@ def _rpath_linkopts(name): # deployed. Other shared object dependencies (e.g. shared between contrib/ # ops) are picked up as long as they are in either the same or a parent # directory in the tensorflow/ tree. - levels_to_root = PACKAGE_NAME.count("/") + name.count("/") + levels_to_root = native.package_name().count("/") + name.count("/") return select({ clean_dep("//tensorflow:darwin"): [ "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),), @@ -905,6 +907,14 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): if not cuda_deps: cuda_deps = [] + if 'linkstatic' not in kwargs or kwargs['linkstatic'] != 1: + enable_text_relocation_linkopt = select({ + clean_dep("//tensorflow:darwin"): [], + "//conditions:default": ['-Wl,-z,notext'],}) + if 'linkopts' in kwargs: + kwargs['linkopts'] += enable_text_relocation_linkopt + else: + kwargs['linkopts'] = enable_text_relocation_linkopt native.cc_library( deps=deps + if_cuda(cuda_deps + [ clean_dep("//tensorflow/core:cuda"), @@ -1158,22 +1168,6 @@ def transitive_hdrs(name, deps=[], **kwargs): # the libraries in deps. def cc_header_only_library(name, deps=[], includes=[], **kwargs): _transitive_hdrs(name=name + "_gather", deps=deps) - - # We could generalize the following, but rather than complicate things - # here, we'll do the minimal use case for now, and hope bazel comes up - # with a better solution before too long. We'd expect it to compute - # the right include path by itself, but it doesn't, possibly because - # _transitive_hdrs lost some information about the include path. - if "@nsync//:nsync_headers" in deps: - # Buiding tensorflow from @org_tensorflow finds this two up. - nsynch = "../../external/nsync/public" - # Building tensorflow from elsewhere finds it four up. - # Note that native.repository_name() is not yet available in TF's Kokoro. - if REPOSITORY_NAME != "@": - nsynch = "../../" + nsynch - includes = includes[:] - includes.append(nsynch) - native.cc_library(name=name, hdrs=[":" + name + "_gather"], includes=includes, @@ -1182,7 +1176,6 @@ def cc_header_only_library(name, deps=[], includes=[], **kwargs): def tf_custom_op_library_additional_deps(): return [ "@protobuf_archive//:protobuf_headers", - "@nsync//:nsync_headers", clean_dep("//third_party/eigen3"), clean_dep("//tensorflow/core:framework_headers_lib"), ] diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index e731127a63d792825e15a4b95379517117edebb0..d9b0260c9f254f0b609ecc9094789085bb6586d4 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -1,5 +1,6 @@ # Description: # Scripts used to generate TensorFlow Python API. + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -21,7 +22,7 @@ py_binary( srcs = ["create_python_api.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python", ], ) @@ -80,6 +81,7 @@ genrule( "api/keras/datasets/boston_housing/__init__.py", "api/keras/datasets/cifar10/__init__.py", "api/keras/datasets/cifar100/__init__.py", + "api/keras/datasets/fashion_mnist/__init__.py", "api/keras/datasets/imdb/__init__.py", "api/keras/datasets/mnist/__init__.py", "api/keras/datasets/reuters/__init__.py", @@ -102,6 +104,7 @@ genrule( "api/linalg/__init__.py", "api/logging/__init__.py", "api/losses/__init__.py", + "api/manip/__init__.py", "api/metrics/__init__.py", "api/nn/__init__.py", "api/nn/rnn_cell/__init__.py", @@ -124,6 +127,7 @@ genrule( "api/test/__init__.py", "api/train/__init__.py", "api/train/queue_runner/__init__.py", + "api/user_ops/__init__.py", ], cmd = "$(location create_python_api) $(OUTS)", tools = ["create_python_api"], @@ -133,7 +137,9 @@ py_library( name = "python_api", srcs = [":python_api_gen"], srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib:contrib_py", # keep + "//tensorflow/python", # keep ], ) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index 1557314939bd85c0467426216f90aa3891ca0ac0..183c4731b8176ece16a70bac421291fd76d748cb 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -23,15 +23,13 @@ import collections import os import sys -# This import is needed so that we can traverse over TensorFlow modules. -import tensorflow as tf # pylint: disable=unused-import from tensorflow.python.util import tf_decorator _API_CONSTANTS_ATTR = '_tf_api_constants' _API_NAMES_ATTR = '_tf_api_names' _API_DIR = '/api/' -_CONTRIB_IMPORT = 'from tensorflow import contrib' +_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api' _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API. This file is MACHINE GENERATED! Do not edit. @@ -40,6 +38,11 @@ Generated by: tensorflow/tools/api/generator/create_python_api.py script. """ +class SymbolExposedTwiceError(Exception): + """Raised when different symbols are exported with the same name.""" + pass + + def format_import(source_module_name, source_name, dest_name): """Formats import statement. @@ -64,6 +67,44 @@ def format_import(source_module_name, source_name, dest_name): return 'import %s as %s' % (source_name, dest_name) +class _ModuleImportsBuilder(object): + """Builds a map from module name to imports included in that module.""" + + def __init__(self): + self.module_imports = collections.defaultdict(list) + self._seen_api_names = set() + + def add_import( + self, dest_module_name, source_module_name, source_name, dest_name): + """Adds this import to module_imports. + + Args: + dest_module_name: (string) Module name to add import to. + source_module_name: (string) Module to import from. + source_name: (string) Name of the symbol to import. + dest_name: (string) Import the symbol using this name. + + Raises: + SymbolExposedTwiceError: Raised when an import with the same + dest_name has already been added to dest_module_name. + """ + import_str = format_import(source_module_name, source_name, dest_name) + if import_str in self.module_imports[dest_module_name]: + return + + # Check if we are trying to expose two different symbols with same name. + full_api_name = dest_name + if dest_module_name: + full_api_name = dest_module_name + '.' + full_api_name + if full_api_name in self._seen_api_names: + raise SymbolExposedTwiceError( + 'Trying to export multiple symbols with same name: %s.' % + full_api_name) + self._seen_api_names.add(full_api_name) + + self.module_imports[dest_module_name].append(import_str) + + def get_api_imports(): """Get a map from destination module to formatted imports. @@ -74,7 +115,9 @@ def get_api_imports(): (for e.g. 'from foo import bar') and constant assignments (for e.g. 'FOO = 123'). """ - module_imports = collections.defaultdict(list) + module_imports_builder = _ModuleImportsBuilder() + visited_symbols = set() + # Traverse over everything imported above. Specifically, # we want to traverse over TensorFlow Python modules. for module in sys.modules.values(): @@ -87,48 +130,56 @@ def get_api_imports(): for module_contents_name in dir(module): attr = getattr(module, module_contents_name) + if id(attr) in visited_symbols: + continue # If attr is _tf_api_constants attribute, then add the constants. if module_contents_name == _API_CONSTANTS_ATTR: for exports, value in attr: for export in exports: - names = ['tf'] + export.split('.') + names = export.split('.') dest_module = '.'.join(names[:-1]) - import_str = format_import(module.__name__, value, names[-1]) - module_imports[dest_module].append(import_str) + module_imports_builder.add_import( + dest_module, module.__name__, value, names[-1]) continue _, attr = tf_decorator.unwrap(attr) # If attr is a symbol with _tf_api_names attribute, then # add import for it. if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__: - # The same op might be accessible from multiple modules. - # We only want to consider location where function was defined. - if attr.__module__ != module.__name__: + # If the same symbol is available using multiple names, only create + # imports for it once. + if id(attr) in visited_symbols: continue + visited_symbols.add(id(attr)) for export in attr._tf_api_names: # pylint: disable=protected-access - names = ['tf'] + export.split('.') + names = export.split('.') dest_module = '.'.join(names[:-1]) - import_str = format_import( - module.__name__, module_contents_name, names[-1]) - module_imports[dest_module].append(import_str) + module_imports_builder.add_import( + dest_module, module.__name__, module_contents_name, names[-1]) # Import all required modules in their parent modules. - # For e.g. if we import 'tf.foo.bar.Value'. Then, we also - # import 'bar' in 'tf.foo'. - dest_modules = set(module_imports.keys()) - for dest_module in dest_modules: - dest_module_split = dest_module.split('.') - for dest_submodule_index in range(1, len(dest_module_split)): - dest_submodule = '.'.join(dest_module_split[:dest_submodule_index]) - submodule_import = format_import( - '', dest_module_split[dest_submodule_index], - dest_module_split[dest_submodule_index]) - if submodule_import not in module_imports[dest_submodule]: - module_imports[dest_submodule].append(submodule_import) - - return module_imports + # For e.g. if we import 'foo.bar.Value'. Then, we also + # import 'bar' in 'foo'. + imported_modules = set(module_imports_builder.module_imports.keys()) + for module in imported_modules: + if not module: + continue + module_split = module.split('.') + parent_module = '' # we import submodules in their parent_module + + for submodule_index in range(len(module_split)): + import_from = _OUTPUT_MODULE + if submodule_index > 0: + parent_module += ('.' + module_split[submodule_index-1] if parent_module + else module_split[submodule_index-1]) + import_from += '.' + parent_module + module_imports_builder.add_import( + parent_module, import_from, module_split[submodule_index], + module_split[submodule_index]) + + return module_imports_builder.module_imports def create_api_files(output_files): @@ -151,8 +202,8 @@ def create_api_files(output_files): # First get module directory under _API_DIR. module_dir = os.path.dirname( output_file[output_file.rfind(_API_DIR)+len(_API_DIR):]) - # Convert / to . and prefix with tf. - module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.') + # Convert / to . + module_name = module_dir.replace('/', '.').strip('.') module_name_to_file_path[module_name] = output_file # Create file for each expected output in genrule. @@ -162,16 +213,14 @@ def create_api_files(output_files): open(file_path, 'a').close() module_imports = get_api_imports() - module_imports['tf'].append(_CONTRIB_IMPORT) # Include all of contrib. # Add imports to output files. missing_output_files = [] for module, exports in module_imports.items(): # Make sure genrule output file list is in sync with API exports. if module not in module_name_to_file_path: - module_without_tf = module[len('tf.'):] module_file_path = '"api/%s/__init__.py"' % ( - module_without_tf.replace('.', '/')) + module.replace('.', '/')) missing_output_files.append(module_file_path) continue with open(module_name_to_file_path[module], 'w') as fp: diff --git a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7405202b892bba67a36d86cd43fb7a67ab3be947 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.GradientTape" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "gradient" + argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "watch" + argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "watched_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt index 42de5c0c80023ad5bd7f33a564780060998307c1..0900adaf762df1415c8db63c3879ca2fabc28d9f 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt @@ -64,7 +64,7 @@ tf_class { } member_method { name: "list_files" - argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "make_initializable_iterator" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt index e2fc8d6cb1d318cc50828f22e8e575cc28c7aaad..7b16ac90c925beb25e065d26e73ee2a54b06d9dc 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "list_files" - argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "make_initializable_iterator" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt index 9770389e5ef1e29a80ae1da2725d9862f6521ff9..9cf5f2ae2057ab4a16131527cf2ef2fa6ada28e5 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "apply" @@ -65,7 +65,7 @@ tf_class { } member_method { name: "list_files" - argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "make_initializable_iterator" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt index 7263230c1c7182bb812cb2e433aedd415bcd16c7..8c3d6691439e619c906996a3ddaea4317c4a9597 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt @@ -65,7 +65,7 @@ tf_class { } member_method { name: "list_files" - argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "make_initializable_iterator" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index 091b1be0c83480757445542acb97e139bd74ef03..759ff752b0ea6b710a2d20fd9ad665b3e6e6ea82 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "cluster_spec" mtype: "" } + member { + name: "distribute" + mtype: "" + } member { name: "evaluation_master" mtype: "" @@ -80,7 +84,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\'], " } member_method { name: "replace" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt index a16e3aedae96e7289e73c49ac7890550dd5ddb08..5301b94eb361251a1cb4d02a5d8168f7c8191045 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.estimator.VocabInfo" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "backup_initializer" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt index afdd6bb058353594415cd1abe726070f84ae46b6..43f5343359aff3b856a2b3708e4cda7cec29e146 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.estimator.WarmStartSettings" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "ckpt_to_initialize_from" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4fe92643bf9867765499d7bf475b9cdd1686aec5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.estimator.export.TensorServingInputReceiver" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "features" + mtype: "" + } + member { + name: "receiver_tensors" + mtype: "" + } + member { + name: "receiver_tensors_alternatives" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt index 4d0dddb3bc0305a28fab0c95c31e4869f5db0aa8..bd72f6cd79f7dffb9f0a7f8ae43751c4ecba939d 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt @@ -20,6 +20,10 @@ tf_module { name: "ServingInputReceiver" mtype: "" } + member { + name: "TensorServingInputReceiver" + mtype: "" + } member_method { name: "build_parsing_serving_input_receiver_fn" argspec: "args=[\'feature_spec\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index bda1c2bf85977e69b0969bc8b6056710d88ca910..3fc64dae888012169af3ea7695154b73f24d90c8 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -100,6 +100,10 @@ tf_module { name: "hsv_to_rgb" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "image_gradients" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "is_jpeg" argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -116,6 +120,10 @@ tf_module { name: "per_image_standardization" argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "psnr" + argspec: "args=[\'a\', \'b\', \'max_val\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "random_brightness" argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -188,6 +196,18 @@ tf_module { name: "sample_distorted_bounding_box" argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "sobel_edges" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ssim" + argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ssim_multiscale" + argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], " + } member_method { name: "total_variation" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt index 21a0f84d22fc2d06e551c9a709f3963e812333b8..eaf0036cacfadce335a84bcf61f47f9d360be7e2 100644 --- a/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt @@ -1,17 +1,9 @@ path: "tensorflow.initializers" tf_module { - member { - name: "absolute_import" - mtype: "" - } member { name: "constant" mtype: "" } - member { - name: "division" - mtype: "" - } member { name: "identity" mtype: "" @@ -24,10 +16,6 @@ tf_module { name: "orthogonal" mtype: "" } - member { - name: "print_function" - mtype: "" - } member { name: "random_normal" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index 04724e3a1af9702e2e0eb8fd8b460d4cc7997a97..7be2f4f61f6b9637f372591e49efc0c93c7a8c0a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -1,9 +1,10 @@ path: "tensorflow.keras.Model" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index c94bd2faa4f3828a7b80388ad6504f7d10b630a1..bf361cf8054571c0b056e1373acb838aaea87173 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -1,10 +1,11 @@ path: "tensorflow.keras.Sequential" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" @@ -74,10 +75,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "regularizers" - mtype: "" - } member { name: "scope_name" mtype: "" @@ -90,10 +87,6 @@ tf_class { name: "stateful" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" @@ -152,11 +145,11 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -172,11 +165,11 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'32\', \'1\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " } member_method { name: "evaluate_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit" @@ -184,7 +177,7 @@ tf_class { } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -244,7 +237,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " } member_method { name: "predict_classes" @@ -252,7 +245,7 @@ tf_class { } member_method { name: "predict_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " } member_method { name: "predict_on_batch" @@ -296,6 +289,6 @@ tf_class { } member_method { name: "train_on_batch" - argspec: "args=[\'self\', \'x\', \'y\', \'class_weight\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt index 791cfda23345fea7df1cfb107ae5dec06354bd48..a0e14356fa5e91bc81bd89f6eb8c07087956c392 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt @@ -1,3 +1,7 @@ path: "tensorflow.keras.datasets.fashion_mnist" tf_module { + member_method { + name: "load_data" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt index f4ab075959906cdf350ec5d49dc86f928b7eb7ae..db8f626b98b70fd99f38e696aa16c72e74e86e25 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Activation" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt index eb558cddafc3972127786353072767f0d53bf174..809b3a5430449176a0d7423ec7f4499ceb620890 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.ActivityRegularization" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt index 770a107b664d7ab0a8aedf292a34d4258a201859..68d41bb6cc258ca87d4664ac0fb9d5649f89ebaf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Add" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt index 0ce42b706ec20a8ea1cc83ec95cb64d9be2e5710..970b777e514194db4ac49fe58bea737b35436217 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.AlphaDropout" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt index d6c98fa225ce924bc8e20f8531516eaed4d32ffb..529c64ab293d596012aefd42e0695bd1eb7e44d1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt index 754fd310c6d8ddb994db0590342b29f8cb7abd71..7e7c330d74fe3b71ecd0eb87e34719e47ae70784 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt index 9b62880c7931d151fb98cc1dc3149dcbd4dd103d..ada8466d7473072b1878861ab36ec40b07fa1914 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt index b371ad148cee16dd243869d929e0c1c002794682..2a5c1cd530a7a532f6cdd3c184f4ee7eb88d23d3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Average" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt index 3e2aba55fd63326bb0e232fdce06f32884db7a0a..9a2cb29815d59f3761ea25e9ea36ff6489c85b88 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt index fb37308cce0124538648c3837e1e802794d7f1ae..f5e991ea42e5ee2723b64574d4598dc8463f1c8c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 813470ffc7c87727eb0b958e54806f530399806a..31732214a62524017e39776cdfb9ab629746e8ae 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt index e251ac18e511b58a49816126d9941b98e4f91088..422eddf10db6763e10405dba5537ca161d1b8994 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.BatchNormalization" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt index 699208a0b9b665b69f02edaa2b2d2aeed6a83b63..9053a37916314198842bc21b0608a9b69a64c264 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Bidirectional" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt index ff08def0a08e5201bc01d61be3f2d66d712c384b..3d536d2182fc4480a2ee5fba177543ca21fbd5ac 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Concatenate" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index 6db22ca0320519fd9c101456c9c9c0e26a9a11e0..6a7da1aef8db64ad11bb5a5ba357f33eeb99170b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt index 577f206e3510a9995d5d383ac440b4f68ea39fe5..801a0339720919f8b3f6beee0f045d58b2c0a371 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index 72924c32b43e5edb39938cc0cd909cffefa61be1..13352e264a5305190717bb973a3f2bce4d7f4fff 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt index 16be08d9b2bae8fe1faecf34c4d87ac9b9baf142..f400e4a15c362037e85ac375cee98bb5f6358669 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index 11e05f884d781166616a9c9a61dacbc8fdae6ae3..b3a9f573b8ba652d2544b21f36f65fe81a6ebb50 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt index 72b72d6b3b1e410dda0b0a529449f0135203fc1b..a9be09c0abd19aeb4df30116ef2befc3948bfbf4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt index ee93247f63ed700dc6058041bd0ea4ff5c879078..be1ef5eb928d16cc6bf78c289aa20d815c728b23 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index e5023287e5f38553f3553a37b5a908790072b5c7..30034f7eaf6d9073695353e5c8d9ead0cc8de7cc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt index ba38cb7121c9d312e7ba9d7147bdc67673d1ad2e..189b38054c004facfeeff8ad2ae87848b89040f2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 58724a1e1661609ef3c000c7ca1dfe9b3235acff..a76d85c629c1fe620dafd62a0f0e05e9009109e2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt index 98d52c430c659d0fc3e9299f7bede9190dad2fcf..782195d4ad5883d8c0ea6a657cc10258f2080a55 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt index 33b6ebe1af731f66f88a9493502f69049ab34b42..2cb7a39ea595e1ff699b96554cb135377d20a488 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Cropping1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt index 4b241ebb0f68c270a9448b02138d44f82211f418..80803306992bba3b601824a93cb3086ef3947369 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Cropping2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt index 1856a9ee21347ed6ca3dd592517eb644e205a5b7..678f40bbc23db15ff7c1138169478fb4412a449d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Cropping3D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt index a8c37af31f649d28ca2ab7614178f2dee58c13fc..fac826109b6a32305ece86c4990f08afe2236ce8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Dense" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt index 07d3f023e54105c606b198c05750ffa78ee5d0c8..285d544af2d69d564afdec748598b39b6b95670f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Dot" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt index e2e21b5f123f63fa38cb0e344be9a12fc091f20b..b77976974cccb96fc2373c093d2bdf279560c46f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Dropout" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt index 92b9760d53e35d3e5066a730bb5cbda45492cc64..b07714d3f2d158496e0482f8611e55ea0fb0fd51 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.ELU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt index 83c528b40117222ac2b3e85ad338459948d0aa8c..e67d4ddfc47077d62319ab097e5333a373cbfc80 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Embedding" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt index 73609752886c8c57a78f6bc02cc46d2c7ff6e996..b2a668e5a88d312656f48ddd0e9f7aa9f6306991 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Flatten" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt index b329f1c46bb07ab7684dec6aaf45a20b98c27ed9..1fd3febad26df16576dedca1df7560bf230c08ec 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.GRUCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index c741d4d6e6cf8da9712e68f86abe64e2828823da..f5f41d879dcb840551c00a7272bbcfbe51dbee89 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GRU" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activation" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt index 57596badf1881950270fa6d3c074afb65daaa8eb..f4f1a5d51c5d5689918af4facf907f79d9ca71ec 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.GaussianDropout" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt index 3829353cc3c195a750ad862707c5c8563e203fba..e502df5e177d422403d0643c18a9588afb9d9713 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.GaussianNoise" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index e53e78a977b32eaf2e31867044aedd39ab2dd34f..9c8d5bfcd8966384230e7d5cdcc1cac53a0eab9a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalAveragePooling1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index 48fcd1044e06b2fe61aadb6c3675ce82197ff003..8dd65f1f248daaf120780f19050c45d297b7902e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalAveragePooling2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index 66c06ed47289eb2d83d97778a7b13dab821722d2..5e30571cc730ee23767a044036b590460deec00b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalAveragePooling3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 4f2420f74ab3069952e4a44bf61e5e12b3e80ea3..ba90fa454696d1cb4e77d80a2dc77ff65def4714 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalAvgPool1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 7912a6d933b851521358e0246d04688da410b909..8823857758307c208527b144c0cc73b566f2f115 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalAvgPool2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index d5b2d2c274ad97071497045271c0a595f8e0e062..500ced852ba6b19502769ba9052f2e364af7e283 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalAvgPool3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index d88ff17eb6df7bbba7d3af4344fc8ddc367ae44c..cf2717ed46b56e639fb774c1e922648e1653ec0d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalMaxPool1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index c8cc5a0ddfdd54cbb47de922591a9842abf63396..a86ff1a46997f19b11e6ef03be432b45687a2df2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalMaxPool2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index 7956c5a340d963cfd5976e8af56da222848a164a..e01cc7c1b09ad6a40380613d54b771c6a1c89c1c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalMaxPool3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 0a7e16413dfbd80d448eb1bad5771915475d96b2..259c1fb37c787f5318570b7aca6935d2f0ed997f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalMaxPooling1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 6c8a58a996f5313ea48e395e7e443a7c21f198ee..0c41bf97f763f1e40e8fac714709ccac1483a00b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalMaxPooling2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index 7678ce8aab63fcfa76c0ac61346a723c1dfe1ee7..bec8817aa393ba2d8a6410408938402366cbb01d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.GlobalMaxPooling3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt index 1e9370b02f9c5d198547e27e60bc3acf9e866d89..17be86222901c0f5a9a18c0e5f1c5bcac6c06a17 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.InputLayer" tf_class { - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 3b171b137af699c9608494a17c5651b439fe4545..6d2a8c56196d9b3c80f570c7f1d3ac803253fff6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.LSTMCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt index 29d9cf78ab5ed3bdd1a488359b59cf7171e7e051..490b5b618c65e28f1ae2e01e8d35e7f3973cc180 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.LSTM" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activation" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt index ca0144929942f7024a4e8bac5552bf0547ceb56d..21a65b838af35e2f540eacab823513e7bf54b434 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Lambda" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt index c52ad727545c0bf4f199714d71180eac3f1bf62a..127b04738e70c11b2dc1071cf174cf5de23c5133 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.keras.layers.Layer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 8134fb738683b79764662d9ea7f721fe04751162..87e49f2ed5b5d73aee5e9aa2511485b1f3f4bcd9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.LeakyReLU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt index c5d452300947d7f74e7458e2a04bfdfabb1c1da2..1aa3aad3246b83931a47e69a4aa76fdf2b5aee22 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.LocallyConnected1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt index bcbed9241b525a953c8b499197facaefebe8cc44..5e9dc7d4774c651a186a4e320d0cfd088e87b6b3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.LocallyConnected2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt index 244e79b4ffe60ddd6aa56d2780d80dfd66c494a9..0d101e5b68cdb2cdf24ed472c724cfc885e3d95d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Masking" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt index 56cbf5df785ef0e2614ea7e9e6cfe1335e148eec..c85cd49ac8ce2c1fc0759671865b7174cd1c1480 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt index 33c2d30e86f9cdc3fb9f4f498bfc2c94497fe2dd..4f59e330c92f96101c65a9a24f66196e84587ccb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt index 94f91059b7a1e291c38fe0045accc6c03f226603..c0ea0eb0505d20e70d641f2a646a060d7dbfabda 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt index 247230a6d68b8ea93a30a2f5846d8baaa78cb13e..ca37ae51314516ae67c7725eb2ccd3d25154e2ac 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 8d61b67e7ce9564d31b0bd904a58540d19c89172..3ede2378347f5eddb0e8fae775a0200ea484d3f8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt index ad2e30802006e934730e5c75247e958329f7121c..d87e25a7ba8e7cce615431723b53a0106c2b5279 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt index ff0db15f190675d533c50c277eb1cb60e0b95e55..e4df7b48ae6b41400375920a48ef8577bb69376e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Maximum" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt index 1d3f33f04516345ee32f16befe0d7200d2cdad00..6bf7c77743c31b6d74df35d827e9d5bc9a25d303 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.Multiply" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt index c86bc49b22a8cc3e004a77f4a21594aacb2c665a..c14be132b7e406c99841576be8d8fa9ab99aa816 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.PReLU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt index 2043e1a1263f0f0745b7c6446cc670fd6b0f0000..72ffbceae01da900778dba1ec14e646aa17b39e5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Permute" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt index ad539a7c4c5362500baef0a9c89d054762bbb47d..d3e780c8b22ed580f61ffc3d9b2bad7278391402 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.RNN" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt index 4b0e98520a0dd86c085fa7345af445e1ae253d3b..a27980a9d17397e558a4b732e3dc332a0c1e8432 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.RepeatVector" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt index 34bc71af8a26ff6e4d7c81a3877751df5209906f..67f991276c6908ff54fd516e84533542a5f60528 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Reshape" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt index dd67b76523cc50409516e29f963f59d039455bfd..fccea5e8af5ab81e712669ff1b2567d8bde8607e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 5d898fb2bd86b39cb8fab755382bb96cce231fa6..d20663bdb0bc2eea323d35b1e3d4d27122f50472 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index bf62c095e7cc3fbeac95919a0f9fdc545efd3d25..889fa0a1b58bbd3babd293b7b1b45915a9ee3ca4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index c758d87993b3acba88a13c7bc9eaeee929a22652..c850f3fedc814b20f0f95cc3cf4fd5c973446b5b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -4,8 +4,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 6e3cde3e3eaba4f9985411d66a220f7cdd4ee7ad..526d88ccba60eb25c68432e5baa03fd3a878f718 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.SimpleRNNCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt index 6fafc77b947d0df11755e3136ed2e7a14c148081..7fddae34472411f49d42b4d65d12034d056ec818 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.SimpleRNN" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activation" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt index ee4b2fa39ed34a544ee800e9370e4f34c4a17041..5b9b62fc970238e49e6d4849285606d0a7908b23 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Softmax" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index e4727072e375b9fc4dc99a1536eaaf3df5415369..769da30999993fad05ae0f7c04e256e6cf01a774 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index c5ff7043115ccdd3bc4a1147790b20feda410f65..fca2e42a1519fcf3a9f0ec996c50b148b2df05fd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 476a7f362cf88e234e964f6f6645ee4ed0cbaff8..36e8de09a967c5940bf8078234f5980a78ec8009 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -3,8 +3,9 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 3dde1e576918409b106649789443f910775e2f6c..a96f16fae99af9c30959d228202055e9aebfaf58 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.StackedRNNCells" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index ef31c5443efa0c0e5a7a2e0a422d2a9c9c49baaf..e1cbd0e150ed890ae57c1725249d1340fc2cb663 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.ThresholdedReLU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt index 1e176d8d4b4eb010049f267be3d0683228a7782b..f0d35728fb1c42d563ff0598dd84da51a766a764 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt @@ -2,8 +2,9 @@ path: "tensorflow.keras.layers.TimeDistributed" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt index a81b83be49e0073f242efc6890e419b4fe172ab2..74efaea6ddb22ec2fe9d41558978c183b0e06671 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.UpSampling1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt index 5403279d45ec7b93bae7907b891c659a043e96d0..dc5bd5fd5319f9bbd601a3c4083ae566b47e1aaa 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.UpSampling2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt index 96c337caf28d43fabd0b90df016f4e8ab0c408db..e01ccfb74aead591f1018cdcbb1c888767ecdb20 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.UpSampling3D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt index ea3bb2f8f567c648cd8b3dfa6f179a108690b0f0..7e6f90f7623677244865ac285c134dc79f7b9b69 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.Wrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt index b81a4b1c50b22f13eacb521cfc8bc288bd40c81f..4d0d402dad442ccf52267f5ce40b05400afbfbc7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.ZeroPadding1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt index 1a26f2f3c9bbaa2aa567e76e1aafe14805ecff38..b353a529bcf8e543d334fee57fca26ebc83036a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.ZeroPadding2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 310277fe67433fd870ae3d907984f402576925b2..9fe1256e616dbca4f35101df160dc55bc68bfa8a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -1,8 +1,9 @@ path: "tensorflow.keras.layers.ZeroPadding3D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index 88eb237cec16210a6e290046bb479735e2158971..8ccf15f9ab0fcfa59907ff05a962a84d3d86ccb4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -1,9 +1,10 @@ path: "tensorflow.keras.models.Model" tf_class { is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index 34f10f01ad3fda4aa0ec076eff1a0221a2932959..be12b0bd2ec509ff394eaa3f43db0b54badd7fba 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -1,10 +1,11 @@ path: "tensorflow.keras.models.Sequential" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" @@ -74,10 +75,6 @@ tf_class { name: "output_shape" mtype: "" } - member { - name: "regularizers" - mtype: "" - } member { name: "scope_name" mtype: "" @@ -90,10 +87,6 @@ tf_class { name: "stateful" mtype: "" } - member { - name: "trainable" - mtype: "" - } member { name: "trainable_variables" mtype: "" @@ -152,11 +145,11 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -172,11 +165,11 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'32\', \'1\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " } member_method { name: "evaluate_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit" @@ -184,7 +177,7 @@ tf_class { } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -244,7 +237,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " } member_method { name: "predict_classes" @@ -252,7 +245,7 @@ tf_class { } member_method { name: "predict_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " } member_method { name: "predict_on_batch" @@ -296,6 +289,6 @@ tf_class { } member_method { name: "train_on_batch" - argspec: "args=[\'self\', \'x\', \'y\', \'class_weight\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt index 04174bff5f04fead68af68afeec80316867009a4..ec0f3d892d9d03a738d34a40afe701e788908a8e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'directory\', \'image_data_generator\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'None\', \'\', \'png\', \'False\', \'nearest\'], " + argspec: "args=[\'self\', \'directory\', \'image_data_generator\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'subset\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'None\', \'\', \'png\', \'False\', \'None\', \'nearest\'], " } member_method { name: "next" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt index 41f27d1f740457f4b7c4f74cb089a448a0fed845..f5bc04e44c198e5bc60f8361dd32e4ae00250468 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'featurewise_center\', \'samplewise_center\', \'featurewise_std_normalization\', \'samplewise_std_normalization\', \'zca_whitening\', \'zca_epsilon\', \'rotation_range\', \'width_shift_range\', \'height_shift_range\', \'shear_range\', \'zoom_range\', \'channel_shift_range\', \'fill_mode\', \'cval\', \'horizontal_flip\', \'vertical_flip\', \'rescale\', \'preprocessing_function\', \'data_format\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'1e-06\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'nearest\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'featurewise_center\', \'samplewise_center\', \'featurewise_std_normalization\', \'samplewise_std_normalization\', \'zca_whitening\', \'zca_epsilon\', \'rotation_range\', \'width_shift_range\', \'height_shift_range\', \'brightness_range\', \'shear_range\', \'zoom_range\', \'channel_shift_range\', \'fill_mode\', \'cval\', \'horizontal_flip\', \'vertical_flip\', \'rescale\', \'preprocessing_function\', \'data_format\', \'validation_split\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'1e-06\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'0.0\', \'0.0\', \'0.0\', \'nearest\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\', \'0.0\'], " } member_method { name: "fit" @@ -12,11 +12,11 @@ tf_class { } member_method { name: "flow" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\'], varargs=None, keywords=None, defaults=[\'None\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'subset\'], varargs=None, keywords=None, defaults=[\'None\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'None\'], " } member_method { name: "flow_from_directory" - argspec: "args=[\'self\', \'directory\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'False\', \'nearest\'], " + argspec: "args=[\'self\', \'directory\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'subset\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'False\', \'None\', \'nearest\'], " } member_method { name: "random_transform" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt index 4ef6e6e99e3b71d4a6e497cc577ef8b42cebab79..42196ddeee7aab144537eef250c07060923fa6a9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'x\', \'y\', \'image_data_generator\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'None\', \'None\', \'\', \'png\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'image_data_generator\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'subset\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'None\', \'None\', \'\', \'png\', \'None\'], " } member_method { name: "next" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt index d28fef696515e09990d63581de6127fd52c0a4ee..6b850dd6b784412d623f44200b4acc169bf25968 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt @@ -36,6 +36,10 @@ tf_module { name: "load_img" argspec: "args=[\'path\', \'grayscale\', \'target_size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'nearest\'], " } + member_method { + name: "random_brightness" + argspec: "args=[\'x\', \'brightness_range\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "random_channel_shift" argspec: "args=[\'x\', \'intensity\', \'channel_axis\'], varargs=None, keywords=None, defaults=[\'0\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9c3215b555c19bc5cf4b32b0d227a9e1b63ce1e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.preprocessing.sequence.TimeseriesGenerator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data\', \'targets\', \'length\', \'sampling_rate\', \'stride\', \'start_index\', \'end_index\', \'shuffle\', \'reverse\', \'batch_size\'], varargs=None, keywords=None, defaults=[\'1\', \'1\', \'0\', \'None\', \'False\', \'False\', \'128\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt index 1b01935cc53b450c3e7009f945f86c8e1c10bf8e..cf59f8a27269c1161919f7ca2a44c5717a836dd7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.preprocessing.sequence" tf_module { + member { + name: "TimeseriesGenerator" + mtype: "" + } member_method { name: "make_sampling_table" argspec: "args=[\'size\', \'sampling_factor\'], varargs=None, keywords=None, defaults=[\'1e-05\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt index d106429df0273929472aa58909f554bcffde9bca..50b54fc7e179bdfb8641d8de12934caa3fc44300 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt @@ -4,6 +4,10 @@ tf_module { name: "Tokenizer" mtype: "" } + member_method { + name: "hashing_trick" + argspec: "args=[\'text\', \'n\', \'hash_function\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], " + } member_method { name: "one_hot" argspec: "args=[\'text\', \'n\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], " diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt index de81206bc8b25046cd48c79ff8f154041c0e0cb0..1c4f550d7f05b8be33326cb39d7a5f3bf663f5e6 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt index 72d5496464210efd9e423996dfb274dd9564f761..d2db0952693f2989e6a9e8748a254eb4db483206 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt index 595e77ff9f8b64b6606fb075f3edf2281b4c3c1f..34d9a9df281c09a2e2030daf74a2ceb8066085bb 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt index 0c4aa2ff2612269727026141574726ad6df5cdbd..21ad0efecf88c42a3a679910ddfe095585a7933a 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.layers.BatchNormalization" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt index 5f576d0189309442dc4cea3d3617ab3144420165..ed38747c7671a267bb640ecb96a4c5fcc46c5edf 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt index 675a7c76e569d3163ecd2c547841b4c36078b21d..ff453c6059477c20528fc768d93c65d208cdfc4a 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt index eaabbf6aab172aea5c51f8071076890bb6b5bcf7..5583bd22dce18b0a0593b73bde509818b63b3f29 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt index 838e070d79d2d7cfbd631f1a5e9960412cfdae5a..63f0c32a7c8f7e530c76c64fa619102bc12f9ad9 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt index 4bd8cfc1a48cd839e2ffa54d0d0ca863060406d8..b77726252ccca30a7c6555fb569eb65b69e34998 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt index 57eccb03ffeb90652b019b5ce8a519797e4a3a3d..92db9f6dcd2f77c4253eb77df4a26fb632b2a766 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.layers.Dense" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt index a1ec00eeeaa98a6199e29b187b0760ddc92db09d..80fa846a24c9162d8521bdb4f098b9cd8e34aedb 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.layers.Dropout" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt index a06943d51a52f1951056136445b0d5786d801b5b..f63213b3dde40aa54b165c1c269c26fd2cd9e3b4 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.layers.Flatten" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt index 24fda0c87ed0aeabd0fd4a16bb2efab444f8cd8a..4e45b2d513bb72bb47433d72c310d6a34fbc0c01 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.layers.Layer" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt index 4c3d00e0e1ddfe95c56f9ebc7c5d609c79dd44d4..19ec33fce775caa634e71e2295ac945a6f70ade9 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt index f7e2017b0c9438130f1cfb2431eb73ca4d3103c5..76180c333a21c592a3b53bb445df9b12d3596552 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt index 84780926a38ff811a5ab35fadfac690a6dbbbbe2..ded75c8ff09efc6746ddd2284f53d2c021cc473c 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt index 05799ecfc9fdb9ff44620a67dcdbdc4426fddced..3dbfa5453f8e0ebb02429df9c4cbdf98de6b8ced 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt index c2aeb35c4648bcce22ca73c838a85803a6b9cedf..ab171df1d1650e19836018f3316e6919f6d36def 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index 44536787f09fc98bba8a4eb0bc562427cfe48b8b..9c71a24d0500e2091e0ae94cc4dd7ed6b788a54f 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 768565d3cacbd1313ee5a64c9b15f9ab70683772..9e19f96b7452616956fb7fd3ca62d8f4b25a2122 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 0d253e5dd233d6d2b6ad0070a463c283a8769dab..7540aa62861895a7c41840476d4edb79785a77a9 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index 97edf245f6fbed393a6fb8dbf1e83649e9ac4b4e..fc1ff386690f9c7acb11d4cc0770e394f78350ad 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 6ecc134d4df866ab5d59e238a8157064421579bd..751122cfff3bf9c55dd9fa264fdf2e1960940724 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 4b3ca1578ba52f30e3405ff198fb716496a462c6..4b6313f395fd8fd4ec2af78365117620263e7a55 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -4,6 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 9a6c73a079884b8ab92be1c9e89b2a9f34aad851..00e8c71140596ecea237ce05a09feff1fbb49001 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 27488f8e73f20456fae911511ecd2e41a60da351..3852f90dd6c4a254e20e789bdeb7796d61cef6bc 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.nn.rnn_cell.RNNCell" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 3310836ed26387718115c2454300b9edfe930451..8f3f0f7506ef49014b31cd4bc04f1cb1e0d696fc 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -3,6 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "activity_regularizer" diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index f8d08f1d39a8bfa7d78be106e59d88de75a57823..937044aece83e49549bf6aca938bf673203f392b 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -84,6 +84,10 @@ tf_module { name: "GRAPH_DEF_VERSION_MIN_PRODUCER" mtype: "" } + member { + name: "GradientTape" + mtype: "" + } member { name: "Graph" mtype: "" @@ -596,6 +600,10 @@ tf_module { name: "add_to_collection" argspec: "args=[\'name\', \'value\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "add_to_collections" + argspec: "args=[\'names\', \'value\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "all_variables" argspec: "args=[], varargs=None, keywords=None, defaults=None" @@ -892,6 +900,10 @@ tf_module { name: "cumsum" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " } + member_method { + name: "custom_gradient" + argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "decode_base64" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -964,6 +976,10 @@ tf_module { name: "einsum" argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None" } + member_method { + name: "enable_eager_execution" + argspec: "args=[\'config\', \'device_policy\', \'execution_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } member_method { name: "encode_base64" argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -980,6 +996,10 @@ tf_module { name: "erfc" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "executing_eagerly" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "exp" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1094,7 +1114,7 @@ tf_module { } member_method { name: "get_local_variable" - argspec: "args=[], varargs=args, keywords=kwargs, defaults=None" + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " } member_method { name: "get_seed" @@ -1600,6 +1620,10 @@ tf_module { name: "reduce_sum" argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } member_method { name: "register_tensor_conversion_function" argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], " @@ -1664,6 +1688,14 @@ tf_module { name: "scatter_div" argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } + member_method { + name: "scatter_max" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_min" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } member_method { name: "scatter_mul" argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -1996,6 +2028,14 @@ tf_module { name: "to_bfloat16" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToBFloat16\'], " } + member_method { + name: "to_complex128" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToComplex128\'], " + } + member_method { + name: "to_complex64" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToComplex64\'], " + } member_method { name: "to_double" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToDouble\'], " @@ -2060,6 +2100,10 @@ tf_module { name: "unsorted_segment_max" argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "unsorted_segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "unsorted_segment_min" argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt index c02e54adfbd9f33e661453767b517a5f0de90d57..16bfbf20d5227d6308248bebcb62f32a2df8ef41 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.AdadeltaOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt index 2b619908fc6aea3f4b8e6a57d0dcf85a9854d466..61cde9181c2367153b7b289b41bd932482bb92fd 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.AdagradDAOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt index 2005cf4677c06cf1f8b4207a444690fdd0c2306e..0a998c1afe4fff6e215360bc1cf8fc135754223c 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.AdagradOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt index 0a2bae1d9021b20707e03ae5786e71f388266c14..cc5954152577796ee7a5a6e1cedc873647d64f7c 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.AdamOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt index 847f9ad75998f1bdda8858650091c70fd0b4015b..1add3a902122341a706c38b19ea6ff5882c26445 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.FtrlOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt index 13a58e0608ed269415ba78d84a03f1bae128e80c..ef5bbd6ace29abb5c73516176fcc7594a58d493a 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.GradientDescentOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt index bfbc2357a346c7bfef0242a735ab14c5f4005b22..3d6e87f5eb44de9d6ce1bdd25a54b8df9020cc03 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.MomentumOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt index 437efa0a2bd04c308db6186e714a5d8785541fa5..e73861ff7cb2d90d8efac72cdd7de3b27395f29e 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt @@ -1,7 +1,6 @@ path: "tensorflow.train.Optimizer" tf_class { is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt index 72f224605f67e72dd78699b5f1a703cc3edd566b..301b35b199c87890a0aef4139eb06253592ce0c4 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.ProximalAdagradOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt index 316275b1fb1abd384e193994e35115a1c463f07d..8815befa936a85522011111a4a6270d22cbc25ae 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.ProximalGradientDescentOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt index af50a1986100d830f0809a3f4a0f01faa8821b3b..e9819683ba5ec1bcacb3cdbcb2d787e866a77b6f 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.RMSPropOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt index 6edc516c9392fa14f23ffc2a6481ec21216f06cf..3db96aff876b88b80b647570cf68b1ebc0b2da3b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt @@ -2,7 +2,6 @@ path: "tensorflow.train.SyncReplicasOptimizer" tf_class { is_instance: "" is_instance: "" - is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4ce7cb111163e103a1cebe30d5c6f3eeb4234693 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.train.VocabInfo" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "backup_initializer" + mtype: "" + } + member { + name: "new_vocab" + mtype: "" + } + member { + name: "new_vocab_size" + mtype: "" + } + member { + name: "num_oov_buckets" + mtype: "" + } + member { + name: "old_vocab" + mtype: "" + } + member { + name: "old_vocab_size" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index e49c719a334455d1f8f39fa67332be8bb81f2bc2..bec72e1e609c3e32ca8366396b9b1cb577feab9d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -224,6 +224,10 @@ tf_module { name: "SyncReplicasOptimizer" mtype: "" } + member { + name: "VocabInfo" + mtype: "" + } member { name: "WorkerSessionCreator" mtype: "" @@ -234,7 +238,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\'], " + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\'], " } member_method { name: "NewCheckpointReader" @@ -402,7 +406,7 @@ tf_module { } member_method { name: "sdca_optimizer" - argspec: "args=[\'sparse_example_indices\', \'sparse_feature_indices\', \'sparse_feature_values\', \'dense_features\', \'example_weights\', \'example_labels\', \'sparse_indices\', \'sparse_weights\', \'dense_weights\', \'example_state_data\', \'loss_type\', \'l1\', \'l2\', \'num_loss_partitions\', \'num_inner_iterations\', \'adaptative\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'sparse_example_indices\', \'sparse_feature_indices\', \'sparse_feature_values\', \'dense_features\', \'example_weights\', \'example_labels\', \'sparse_indices\', \'sparse_weights\', \'dense_weights\', \'example_state_data\', \'loss_type\', \'l1\', \'l2\', \'num_loss_partitions\', \'num_inner_iterations\', \'adaptative\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } member_method { name: "sdca_shrink_l1" @@ -436,6 +440,10 @@ tf_module { name: "update_checkpoint_state" argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "warm_start" + argspec: "args=[\'ckpt_to_initialize_from\', \'vars_to_warm_start\', \'var_name_to_vocab_info\', \'var_name_to_prev_var_name\'], varargs=None, keywords=None, defaults=[\'.*\', \'None\', \'None\'], " + } member_method { name: "write_graph" argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], " diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 608a34ab7b32bdc26cebbe43b383155406fb51b2..15bf1abb5f8f541c435be77b1a3c2f13382f2438 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -23,6 +23,7 @@ py_test( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow:experimental_tensorflow_py", "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", "//tensorflow/python:lib", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index c1e09cc531ed8e8995e3e73b87e96b72fba6c038..603b2a4327b94873b9908d5e0e114dcc4f7542dc 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,7 @@ import sys import unittest import tensorflow as tf +from tensorflow import experimental_api as api from google.protobuf import text_format @@ -46,6 +47,9 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor from tensorflow.tools.common import public_api from tensorflow.tools.common import traverse +if hasattr(tf, 'experimental_api'): + del tf.experimental_api + # FLAGS defined at the bottom: FLAGS = None # DEFINE_boolean, update_goldens, default False: @@ -54,7 +58,7 @@ _UPDATE_GOLDENS_HELP = """ have to be authorized by TensorFlow leads. """ -# DEFINE_boolean, verbose_diffs, default False: +# DEFINE_boolean, verbose_diffs, default True: _VERBOSE_DIFFS_HELP = """ If set to true, print line by line diffs on all libraries. If set to false, only print which libraries have differences. @@ -109,7 +113,8 @@ class ApiCompatibilityTest(test.TestCase): expected_dict, actual_dict, verbose=False, - update_goldens=False): + update_goldens=False, + additional_missing_object_message=''): """Diff given dicts of protobufs and report differences a readable way. Args: @@ -120,6 +125,8 @@ class ApiCompatibilityTest(test.TestCase): verbose: Whether to log the full diffs, or simply report which files were different. update_goldens: Whether to update goldens when there are diffs found. + additional_missing_object_message: Message to print when a symbol is + missing. """ diffs = [] verbose_diffs = [] @@ -138,7 +145,8 @@ class ApiCompatibilityTest(test.TestCase): verbose_diff_message = '' # First check if the key is not found in one or the other. if key in only_in_expected: - diff_message = 'Object %s expected but not found (removed).' % key + diff_message = 'Object %s expected but not found (removed). %s' % ( + key, additional_missing_object_message) verbose_diff_message = diff_message elif key in only_in_actual: diff_message = 'New object %s found (added).' % key @@ -165,7 +173,7 @@ class ApiCompatibilityTest(test.TestCase): logging.error('%d differences found between API and golden.', diff_count) messages = verbose_diffs if verbose else diffs for i in range(diff_count): - logging.error('Issue %d\t: %s', i + 1, messages[i]) + print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr) if update_goldens: # Write files if requested. @@ -229,13 +237,56 @@ class ApiCompatibilityTest(test.TestCase): verbose=FLAGS.verbose_diffs, update_goldens=FLAGS.update_goldens) + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testNewAPIBackwardsCompatibility(self): + # Extract all API stuff. + visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() + + public_api_visitor = public_api.PublicAPIVisitor(visitor) + public_api_visitor.do_not_descend_map['tf'].append('contrib') + public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental'] + # TODO(annarev): Make slide_dataset available in API. + public_api_visitor.private_map['tf'] = ['slide_dataset'] + traverse.traverse(api, public_api_visitor) + + proto_dict = visitor.GetProtos() + + # Read all golden files. + expression = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*')) + golden_file_list = file_io.get_matching_files(expression) + + def _ReadFileToProto(filename): + """Read a filename, create a protobuf from its contents.""" + ret_val = api_objects_pb2.TFAPIObject() + text_format.Merge(file_io.read_file_to_string(filename), ret_val) + return ret_val + + golden_proto_dict = { + _FileNameToKey(filename): _ReadFileToProto(filename) + for filename in golden_file_list + } + + # Diff them. Do not fail if called with update. + # If the test is run to update goldens, only report diffs but do not fail. + self._AssertProtoDictEquals( + golden_proto_dict, + proto_dict, + verbose=FLAGS.verbose_diffs, + update_goldens=False, + additional_missing_object_message= + 'Check if tf_export decorator/call is missing for this symbol.') + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP) parser.add_argument( - '--verbose_diffs', type=bool, default=False, help=_VERBOSE_DIFFS_HELP) + '--verbose_diffs', type=bool, default=True, help=_VERBOSE_DIFFS_HELP) FLAGS, unparsed = parser.parse_known_args() # Now update argv, so that unittest library does not get confused. diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index ecab6f8769ae2d0126f63580030ed6ff756015d0..15523028c726fefa13641a1369cf4274bcfb9973 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -48,33 +48,14 @@ limitations under the License. namespace tensorflow { namespace benchmark_model { -Status InitializeSession(int num_threads, const string& graph, - std::unique_ptr* session, - std::unique_ptr* graph_def) { - LOG(INFO) << "Loading TensorFlow."; +namespace { - tensorflow::SessionOptions options; - tensorflow::ConfigProto& config = options.config; - if (num_threads > 0) { - config.set_intra_op_parallelism_threads(num_threads); +Status InitializeVariables(Session* session, + const std::vector& init_ops) { + LOG(INFO) << "Initializing graph variables"; + for (const string& init_op : init_ops) { + TF_RETURN_IF_ERROR(session->Run({}, {}, {init_op}, nullptr)); } - LOG(INFO) << "Got config, " << config.device_count_size() << " devices"; - - session->reset(tensorflow::NewSession(options)); - graph_def->reset(new GraphDef()); - tensorflow::GraphDef tensorflow_graph; - Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get()); - if (!s.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << s; - return s; - } - - s = (*session)->Create(*(graph_def->get())); - if (!s.ok()) { - LOG(ERROR) << "Could not create TensorFlow Session: " << s; - return s; - } - return Status::OK(); } @@ -247,8 +228,56 @@ void RecordBenchmarkEntry(const string& output_prefix, TF_QCHECK_OK(node_reporter.Close()); } +void SleepSeconds(double sleep_seconds) { + if (sleep_seconds <= 0.0) { + return; + } +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else + // Convert the inference_delay string into a timespec. + timespec req; + req.tv_sec = static_cast(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + nanosleep(&req, nullptr); +#endif +} + +} // namespace + +Status InitializeSession(int num_threads, const string& graph, + std::unique_ptr* session, + std::unique_ptr* graph_def) { + LOG(INFO) << "Loading TensorFlow."; + + tensorflow::SessionOptions options; + tensorflow::ConfigProto& config = options.config; + if (num_threads > 0) { + config.set_intra_op_parallelism_threads(num_threads); + } + LOG(INFO) << "Got config, " << config.device_count_size() << " devices"; + + session->reset(tensorflow::NewSession(options)); + graph_def->reset(new GraphDef()); + tensorflow::GraphDef tensorflow_graph; + Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get()); + if (!s.ok()) { + LOG(ERROR) << "Could not create TensorFlow Graph: " << s; + return s; + } + + s = (*session)->Create(*(graph_def->get())); + if (!s.ok()) { + LOG(ERROR) << "Could not create TensorFlow Session: " << s; + return s; + } + + return Status::OK(); +} + Status RunBenchmark(const std::vector& inputs, - const std::vector& outputs, Session* session, + const std::vector& outputs, + const std::vector& targets, Session* session, StatSummarizer* stats, int64* inference_time_us) { std::vector > input_tensors; CreateTensorsFromInputInfo(inputs, &input_tensors); @@ -264,8 +293,8 @@ Status RunBenchmark(const std::vector& inputs, RunMetadata run_metadata; const int64 start_time = Env::Default()->NowMicros(); - s = session->Run(run_options, input_tensors, outputs, {}, &output_tensors, - &run_metadata); + s = session->Run(run_options, input_tensors, outputs, targets, + &output_tensors, &run_metadata); const int64 end_time = Env::Default()->NowMicros(); *inference_time_us = end_time - start_time; @@ -283,24 +312,10 @@ Status RunBenchmark(const std::vector& inputs, return s; } -void SleepSeconds(double sleep_seconds) { - if (sleep_seconds <= 0.0) { - return; - } -#ifdef PLATFORM_WINDOWS - Sleep(sleep_seconds * 1000); -#else - // Convert the inference_delay string into a timespec. - timespec req; - req.tv_sec = static_cast(sleep_seconds); - req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; - nanosleep(&req, nullptr); -#endif -} - Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, const std::vector& inputs, - const std::vector& outputs, Session* session, + const std::vector& outputs, + const std::vector& targets, Session* session, StatSummarizer* stats, int64* total_time_us, int64* actual_num_runs) { *total_time_us = 0; @@ -315,7 +330,8 @@ Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, const bool until_max_time = num_runs <= 0; for (int i = 0; until_max_time || i < num_runs; ++i) { int64 time; - Status run_status = RunBenchmark(inputs, outputs, session, stats, &time); + Status run_status = + RunBenchmark(inputs, outputs, targets, session, stats, &time); stat.UpdateStat(time); (*total_time_us) += time; ++(*actual_num_runs); @@ -345,11 +361,13 @@ Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, int Main(int argc, char** argv) { string graph = "/data/local/tmp/tensorflow_inception_graph.pb"; + string init_ops_string = ""; string input_layer_string = "input:0"; string input_layer_shape_string = "1,224,224,3"; string input_layer_type_string = "float"; string input_layer_values_string = ""; string output_layer_string = "output:0"; + string target_layer_string = ""; int max_num_runs = 1000; string max_time = "10.0"; string inference_delay = "-1.0"; @@ -371,12 +389,14 @@ int Main(int argc, char** argv) { std::vector flag_list = { Flag("graph", &graph, "graph file name"), + Flag("init_ops", &init_ops_string, "init ops"), Flag("input_layer", &input_layer_string, "input layer names"), Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), Flag("input_layer_type", &input_layer_type_string, "input layer type"), Flag("input_layer_values", &input_layer_values_string, "values to initialize the inputs with"), Flag("output_layer", &output_layer_string, "output layer name"), + Flag("target_layer", &target_layer_string, "target layer name"), Flag("max_num_runs", &max_num_runs, "number of runs max"), Flag("max_time", &max_time, "length to run max"), Flag("inference_delay", &inference_delay, @@ -410,6 +430,7 @@ int Main(int argc, char** argv) { return -1; } + std::vector init_ops = str_util::Split(init_ops_string, ','); std::vector input_layers = str_util::Split(input_layer_string, ','); std::vector input_layer_shapes = str_util::Split(input_layer_shape_string, ':'); @@ -418,6 +439,7 @@ int Main(int argc, char** argv) { std::vector input_layer_values = str_util::Split(input_layer_values_string, ':'); std::vector output_layers = str_util::Split(output_layer_string, ','); + std::vector target_layers = str_util::Split(target_layer_string, ','); if ((input_layers.size() != input_layer_shapes.size()) || (input_layers.size() != input_layer_types.size())) { LOG(ERROR) << "There must be the same number of items in --input_layer," @@ -441,10 +463,12 @@ int Main(int argc, char** argv) { } LOG(INFO) << "Graph: [" << graph << "]"; + LOG(INFO) << "Init ops:" << init_ops_string; LOG(INFO) << "Input layers: [" << input_layer_string << "]"; LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; LOG(INFO) << "Input types: [" << input_layer_type_string << "]"; LOG(INFO) << "Output layers: [" << output_layer_string << "]"; + LOG(INFO) << "Target layers: [" << target_layer_string << "]"; LOG(INFO) << "Num runs: [" << max_num_runs << "]"; LOG(INFO) << "Inter-inference delay (seconds): [" << inference_delay << "]"; LOG(INFO) << "Inter-benchmark delay (seconds): [" << inter_benchmark_delay @@ -470,6 +494,16 @@ int Main(int argc, char** argv) { return -1; } + if (!init_ops.empty()) { + Status initialize_variables_status = + InitializeVariables(session.get(), init_ops); + if (!initialize_variables_status.ok()) { + LOG(ERROR) << "Graph variables initialization failed with " + << initialize_variables_status; + return -1; + } + } + StatSummarizerOptions stats_options; stats_options.show_run_order = show_run_order; stats_options.run_order_limit = run_order_limit; @@ -520,9 +554,10 @@ int Main(int argc, char** argv) { int64 warmup_time_us = 0; int64 num_warmup_runs = 0; if (warmup_runs > 0) { - Status warmup_time_status = TimeMultipleRuns( - inter_inference_sleep_seconds, warmup_runs, -1.0, inputs, output_layers, - session.get(), nullptr, &warmup_time_us, &num_warmup_runs); + Status warmup_time_status = + TimeMultipleRuns(inter_inference_sleep_seconds, warmup_runs, -1.0, + inputs, output_layers, target_layers, session.get(), + nullptr, &warmup_time_us, &num_warmup_runs); if (!warmup_time_status.ok()) { LOG(ERROR) << "Timing failed with " << warmup_time_status; return -1; @@ -536,8 +571,8 @@ int Main(int argc, char** argv) { int64 no_stat_num_runs = 0; Status no_stat_time_status = TimeMultipleRuns( inter_inference_sleep_seconds, max_num_runs, max_benchmark_time_seconds, - inputs, output_layers, session.get(), nullptr, &no_stat_time_us, - &no_stat_num_runs); + inputs, output_layers, target_layers, session.get(), nullptr, + &no_stat_time_us, &no_stat_num_runs); const double no_stat_wall_time = no_stat_time_us / 1000000.0; if (!no_stat_time_status.ok()) { LOG(ERROR) << "Timing failed with " << no_stat_time_status; @@ -551,8 +586,8 @@ int Main(int argc, char** argv) { int64 stat_num_runs = 0; Status stat_time_status = TimeMultipleRuns( inter_inference_sleep_seconds, max_num_runs, max_benchmark_time_seconds, - inputs, output_layers, session.get(), stats.get(), &stat_time_us, - &stat_num_runs); + inputs, output_layers, target_layers, session.get(), stats.get(), + &stat_time_us, &stat_num_runs); if (!stat_time_status.ok()) { LOG(ERROR) << "Timing failed with " << stat_time_status; return -1; diff --git a/tensorflow/tools/benchmark/benchmark_model.h b/tensorflow/tools/benchmark/benchmark_model.h index dff62c5b5d518da8f9034295626e46db783f343d..dc5f0080374e70edad52965cc0a95f99751baa48 100644 --- a/tensorflow/tools/benchmark/benchmark_model.h +++ b/tensorflow/tools/benchmark/benchmark_model.h @@ -37,13 +37,15 @@ Status InitializeSession(int num_threads, const string& graph, // Does a single run of the model that's been loaded into the given session. Status RunBenchmark(const std::vector& inputs, - const std::vector& outputs, Session* session, + const std::vector& outputs, + const std::vector& targets, Session* session, StatSummarizer* stats, int64* inference_time_us); // Runs the model multiple time, keeping track of timing information. Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, const std::vector& inputs, - const std::vector& outputs, Session* session, + const std::vector& outputs, + const std::vector& targets, Session* session, StatSummarizer* stats, int64* total_time_us, int64* actual_num_runs); diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc index bb4eb5352039b01a6692621906eff005187cfa36..16ab2ff66e763a0ca5130f075f988bade9c8abd1 100644 --- a/tensorflow/tools/benchmark/benchmark_model_test.cc +++ b/tensorflow/tools/benchmark/benchmark_model_test.cc @@ -64,8 +64,8 @@ TEST(BenchmarkModelTest, InitializeAndRun) { int64 time; int64 num_runs = 0; TF_ASSERT_OK(benchmark_model::TimeMultipleRuns( - 0.0, 10, 0.0, {input}, {output_name}, session.get(), stats.get(), &time, - &num_runs)); + 0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(), + &time, &num_runs)); ASSERT_EQ(num_runs, 10); } diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake index ec90c83aacd068e8f9c16e5be8eb6e1cef098ea6..d5dea4f3e41841aed5aeac02fcca850dbfdfaeb3 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cmake +++ b/tensorflow/tools/ci_build/Dockerfile.cmake @@ -23,11 +23,12 @@ RUN /install/install_deb_packages.sh RUN apt-get update RUN apt-get install -y --no-install-recommends python-pip +RUN pip install --upgrade wheel RUN pip install --upgrade astor RUN pip install --upgrade gast RUN pip install --upgrade numpy RUN pip install --upgrade termcolor # Install golang -RUN add-apt-repository -y ppa:ubuntu-lxc/lxd-stable -RUN apt-get install -y golang +RUN apt-get install -t xenial-backports -y golang-1.9 +ENV PATH=${PATH}:/usr/lib/go-1.9/bin diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu new file mode 100644 index 0000000000000000000000000000000000000000..6f0798b1afc34bc08df6f3f8f467a329fcf0fe9b --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu @@ -0,0 +1,14 @@ +FROM launcher.gcr.io/google/rbe-debian8:r322167 +LABEL maintainer="Yu Yi " + +# Copy install scripts +COPY install/*.sh /install/ + +# Setup envvars +ENV CC /usr/local/bin/clang +ENV CXX /usr/local/bin/clang++ +ENV AR /usr/bin/ar + +# Run pip install script for RBE Debian8 container. +RUN /install/install_pip_packages_remote.sh +RUN /install/install_pip_packages.sh diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gpu new file mode 100644 index 0000000000000000000000000000000000000000..24ff4765a619701cd614414d2b06f7fa4ce7d8c0 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.gpu @@ -0,0 +1,26 @@ +FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 + +LABEL maintainer="Nick Lopez " + +# In the Ubuntu 16.04 images, cudnn is placed in system paths. Move them to +# /usr/local/cuda +RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include +RUN cp -P /usr/lib/x86_64-linux-gnu/libcudnn* /usr/local/cuda/lib64 + +# Copy and run the install scripts. +COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa && \ + add-apt-repository -y ppa:george-edison55/cmake-3.x +RUN /install/install_deb_packages.sh +RUN /install/install_pip_packages.sh +RUN /install/install_golang.sh + +# Install clang from pre-built package +RUN cd /tmp && \ + wget https://storage.googleapis.com/clang-builds-stable/clang-ubuntu16_04/clang_r323528.tar.gz && \ + echo "26752d9f5785df07193fac8316ba5d5ba3bec36d970c29a1577360848818ac74 clang_r323528.tar.gz" | sha256sum -c && \ + tar -C /usr/local -xf clang_r323528.tar.gz && \ + rm clang_r323528.tar.gz + diff --git a/tensorflow/tools/ci_build/builds/android.sh b/tensorflow/tools/ci_build/builds/android.sh index 564c5aa1480f5fd824dbc5c8bc85cec90664c512..d81793efe08f151c1b448a9da3cc971ca3137829 100755 --- a/tensorflow/tools/ci_build/builds/android.sh +++ b/tensorflow/tools/ci_build/builds/android.sh @@ -29,7 +29,8 @@ echo "========== TensorFlow Demo Build Test ==========" # Enable sandboxing so that zip archives don't get incorrectly packaged # in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334) # TODO(gunan): remove extra flags once sandboxing is enabled for all builds. -bazel --bazelrc=/dev/null build -c opt --fat_apk_cpu=x86_64 \ +bazel --bazelrc=/dev/null build \ + --compilation_mode=opt --cxxopt=-std=c++11 --fat_apk_cpu=x86_64 \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ //tensorflow/examples/android:tensorflow_demo diff --git a/tensorflow/tools/ci_build/builds/android_full.sh b/tensorflow/tools/ci_build/builds/android_full.sh index 9d449241e8413ddbd81c580cc4def808c0086cb9..41dc66dd5436a81eeeca197f6ef57cb2a1407ca0 100755 --- a/tensorflow/tools/ci_build/builds/android_full.sh +++ b/tensorflow/tools/ci_build/builds/android_full.sh @@ -40,7 +40,8 @@ rm -rf ${AAR_LIB_TMP} for CPU in ${CPUS//,/ } do echo "========== Building native libs for Android ${CPU} ==========" - bazel build -c opt --config=monolithic --cpu=${CPU} \ + bazel build --config=monolithic --cpu=${CPU} \ + --compilation_mode=opt --cxxopt=-std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ //tensorflow/core:android_tensorflow_lib \ @@ -62,7 +63,8 @@ done # in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334) # TODO(gunan): remove extra flags once sandboxing is enabled for all builds. echo "========== Building TensorFlow Android Jar and Demo ==========" -bazel --bazelrc=/dev/null build -c opt --config=monolithic --fat_apk_cpu=${CPUS} \ +bazel --bazelrc=/dev/null build --config=monolithic --fat_apk_cpu=${CPUS} \ + --compilation_mode=opt --cxxopt=-std=c++11 \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ //tensorflow/contrib/android:android_tensorflow_inference_java \ //tensorflow/contrib/android:android_tensorflow_inference_java.aar \ diff --git a/tensorflow/tools/ci_build/builds/test_tutorials.sh b/tensorflow/tools/ci_build/builds/test_tutorials.sh index 67e5af556405a5c659000a07a79a6bd9a1d1e542..db335f14ca4f88ade7a540ffab7ed9de67f1248e 100755 --- a/tensorflow/tools/ci_build/builds/test_tutorials.sh +++ b/tensorflow/tools/ci_build/builds/test_tutorials.sh @@ -277,17 +277,6 @@ test_ptb_word_lm() { fi } - -# ----------------------------------------------------------- -# translate_test -test_translate_test() { - LOG_FILE=$1 - - run_in_directory "${TEST_DIR}" "${LOG_FILE}" \ - "${TF_MODELS_DIR}/tutorials/rnn/translate/translate.py" --self_test=True -} - - # Run the tutorial tests test_runner "tutorial test-on-install" \ "${TUT_TESTS}" "${TF_BUILD_TUT_TEST_BLACKLIST}" "${LOGS_DIR}" diff --git a/tensorflow/tools/ci_build/remote/remote_docker_build.sh b/tensorflow/tools/ci_build/ci_rbe_docker_build.sh similarity index 58% rename from tensorflow/tools/ci_build/remote/remote_docker_build.sh rename to tensorflow/tools/ci_build/ci_rbe_docker_build.sh index e00a66aabaf1068c772aabce2391616518be44d4..cd811de6bdf9275b799a608381c76713a6c7a65b 100755 --- a/tensorflow/tools/ci_build/remote/remote_docker_build.sh +++ b/tensorflow/tools/ci_build/ci_rbe_docker_build.sh @@ -16,25 +16,19 @@ # Build TensorFlow Docker images for remote build # # Usage: -# remote_docker_build.sh -c # docker image for cpu build -# remote_docker_build.sh -g # docker image for gpu build - +# ci_rbe_docker_build.sh -c # docker image for cpu build +# ci_rbe_docker_build.sh -g # docker image for gpu build function main { - publish=true cpu_build=false gpu_build=false - publish=true + publish=false script_dir=$(dirname "$(readlink -f "$0")") cd $script_dir - trap cleanup_on_finish EXIT - set_script_flags $@ - build_base_image - build_tf_image if [ "$publish" = true ] ; then @@ -50,17 +44,14 @@ function set_script_flags { c) cpu_build=true ;; - f) - base_image_build_script=$OPTARG - ;; g) gpu_build=true ;; h) print_usage ;; - n) - publish=false + p) + publish=true ;; *) print_usage "ERROR: unknown option" @@ -76,7 +67,6 @@ function print_usage { echo "Usage: $(basename $0) -c | -g [options]" echo " -c build image for CPU build (base image debian8-clang)" echo " -g build image for GPU build (base image nvidia-clang)" - echo " -f the script which build the {debian8,nvidia}-clang base image" echo "[option] is one of" echo " -n not publish the locally-built image to GCR;" echo " the build process will publish image to GCR by default" @@ -87,54 +77,22 @@ function print_usage { exit 1 } - -# Build nvidia-cuba-clang base image for GPU image. -# For CPU the `clang-debian8` from Cloud Launcher will be used directly: -# https://console.cloud.google.com/launcher/details/google/clang-debian8?filter=category:developer-tools&q=clang -function build_base_image { - if [ "$gpu_build" = true ] ; then - base_image="nvidia-cuda" - # Run a 2-stage build for clang base image, see - # https://github.com/llvm-mirror/llvm/blob/master/docs/Docker.rst - $base_image_build_script \ - --source $base_image \ - --branch branches/google/stable \ - --docker-repository ${base_image}-clang --docker-tag "latest" \ - -p clang -i stage2-install-clang -i stage2-install-clang-headers \ - -- \ - -DLLVM_TARGETS_TO_BUILD=Native -DCMAKE_BUILD_TYPE=Release \ - -DBOOTSTRAP_CMAKE_BUILD_TYPE=Release \ - -DCLANG_ENABLE_BOOTSTRAP=ON \ - -DCLANG_BOOTSTRAP_TARGETS="install-clang;install-clang-headers" - fi -} - - function build_tf_image { if [ "$cpu_build" = true ] ; then - dockerfile="Dockerfile.cpu" - tf_image="tensorflow-remote" + dockerfile="Dockerfile.rbe.cpu" + tf_image="tensorflow-rbe-cpu" else - dockerfile="Dockerfile.gpu" - tf_image="tensorflow-remote-gpu" + dockerfile="Dockerfile.rbe.gpu" + tf_image="tensorflow-rbe-gpu" fi docker build -f $dockerfile -t $tf_image . } - function publish_tf_image { gcr_tf_image="gcr.io/tensorflow/${tf_image}" docker tag $tf_image $gcr_tf_image gcloud docker -- push $gcr_tf_image } - -function cleanup_on_finish { - cd $script_dir - rm -rf $llvm_docker_src - docker rmi -f ${base_image}-clang ${base_image}-clang-build -} - - main $@ diff --git a/tensorflow/tools/ci_build/copy_binary.py b/tensorflow/tools/ci_build/copy_binary.py index 90fd6a6e71f19649406234bc93025c15e4a5063c..420d390d2b9dc1ec25461b3502c63467a7eda16b 100755 --- a/tensorflow/tools/ci_build/copy_binary.py +++ b/tensorflow/tools/ci_build/copy_binary.py @@ -29,13 +29,9 @@ import argparse import os import re import shutil -import subprocess +import tempfile import zipfile -UNZIP_CMD = "/usr/bin/unzip" -ZIP_CMD = "/usr/bin/zip" -SED_CMD = "/bin/sed" - TF_NIGHTLY_REGEX = r"(.+)tf_nightly(|_gpu)-(\d\.\d\.\d.dev[\d]{0,8})-(.+)\.whl" BINARY_STRING_TEMPLATE = "%s-%s-%s.whl" @@ -43,7 +39,7 @@ BINARY_STRING_TEMPLATE = "%s-%s-%s.whl" def check_existence(filename): """Check the existence of file or dir.""" if not os.path.exists(filename): - raise RuntimeError("%s not found.") + raise RuntimeError("%s not found." % filename) def copy_binary(directory, origin_tag, new_tag, version, gpu=False): @@ -64,27 +60,36 @@ def copy_binary(directory, origin_tag, new_tag, version, gpu=False): package = "tf_nightly" origin_binary = BINARY_STRING_TEMPLATE % (package, version, origin_tag) new_binary = BINARY_STRING_TEMPLATE % (package, version, new_tag) - zip_ref = zipfile.ZipFile(directory + origin_binary, "r") - zip_ref.extractall() - zip_ref.close() - old_py_ver = re.search(r"(cp\d\d-cp\d\d)", origin_tag).group(1) - new_py_ver = re.search(r"(cp\d\d-cp\d\d)", new_tag).group(1) - subprocess.check_call( - "%s -i s/%s/%s/g %s-%s.dist-info/WHEEL" % (SED_CMD, old_py_ver, - new_py_ver, package, version), - shell=True) - zout = zipfile.ZipFile(directory + new_binary, "w", zipfile.ZIP_DEFLATED) - zip_these_files = [ - "%s-%s.dist-info" % (package, version), - "%s-%s.data" % (package, version) - ] - for dirname in zip_these_files: - for root, _, files in os.walk(dirname): - for filename in files: - zout.write(os.path.join(root, filename)) - zout.close() - for dirname in zip_these_files: - shutil.rmtree(dirname) + zip_ref = zipfile.ZipFile(os.path.join(directory, origin_binary), "r") + + try: + tmpdir = tempfile.mkdtemp() + os.chdir(tmpdir) + + zip_ref.extractall() + zip_ref.close() + old_py_ver = re.search(r"(cp\d\d-cp\d\d)", origin_tag).group(1) + new_py_ver = re.search(r"(cp\d\d-cp\d\d)", new_tag).group(1) + + wheel_file = os.path.join( + tmpdir, "%s-%s.dist-info" % (package, version), "WHEEL") + with open(wheel_file, "r") as f: + content = f.read() + with open(wheel_file, "w") as f: + f.write(content.replace(old_py_ver, new_py_ver)) + + zout = zipfile.ZipFile(directory + new_binary, "w", zipfile.ZIP_DEFLATED) + zip_these_files = [ + "%s-%s.dist-info" % (package, version), + "%s-%s.data" % (package, version), + ] + for dirname in zip_these_files: + for root, _, files in os.walk(dirname): + for filename in files: + zout.write(os.path.join(root, filename)) + zout.close() + finally: + shutil.rmtree(tmpdir) def main(): @@ -110,6 +115,7 @@ def main(): args = parser.parse_args() # Argument checking + args.filename = os.path.abspath(args.filename) check_existence(args.filename) regex_groups = re.search(TF_NIGHTLY_REGEX, args.filename) directory = regex_groups.group(1) diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh index cfeaebdbf57c01fef7cd81dae76217429336d0ff..d0816c92b7308a1079579e605ee9af491a0533fb 100755 --- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh +++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh @@ -54,3 +54,6 @@ for i in `seq 0 $((TF_GPU_COUNT-1))`; do fi done +echo "Cannot find a free GPU to run the test $* on, exiting with failure..." +exit 1 + diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index 1df6a84d7c6f86abfb965063625ac43a3f1a57fb..3e27a94cf2bf3110ac181d6ef5a57366be17255f 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="0.10.0" +BAZEL_VERSION="0.11.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh b/tensorflow/tools/ci_build/install/install_pip_packages_remote.sh similarity index 62% rename from tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh rename to tensorflow/tools/ci_build/install/install_pip_packages_remote.sh index 852486d1677ec597fe56111ffb0e470c333c1cd7..39a6d557d185d8564a79315fc738a054325aa0bc 100755 --- a/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages_remote.sh @@ -1,5 +1,5 @@ -#!/bin/bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -GSUTIL_BIN="/var/gcloud/google-cloud-sdk/bin/gsutil" -echo "Got teardown argument $1" +set -e -if "${GSUTIL_BIN}" rm "$1" -then - echo "Cleaned up new tfrecord file in GCS: '$1'" -else - echo "FAIL: Unable to clean up new tfrecord file in GCS: '$1'" - exit 1 +if [ ! -f /usr/bin/x86_64-linux-gnu-gcc ]; then + ln -s /usr/local/bin/clang /usr/bin/x86_64-linux-gnu-gcc fi + +pip2 install -U pip +pip3 install -U pip +pip2 install -U setuptools +pip3 install -U setuptools + +# The rest of the pip packages will be installed in +# `install_pip_packages.sh` diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh index 509ee38ec4fd584037f8e43726c01391430c1817..5c5a36139f50e85e70ce4bff5ca8054f7570b0f5 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh @@ -31,7 +31,7 @@ export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=$(which python2) yes "" | $PYTHON_BIN_PATH configure.py which bazel -bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac,-no_mac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium --config=opt \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh index 05547136704394ed9262f566a2bfb4160b73c7fd..338066131b5d4511ae9f0646a1269b182cf8e1fa 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh @@ -31,7 +31,7 @@ export CC_OPT_FLAGS='-mavx' export PYTHON_BIN_PATH=$(which python2) yes "" | $PYTHON_BIN_PATH configure.py which bazel -bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac,-no_mac \ --test_timeout 300,450,1200,3600 --config=opt \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh index 8f839ca110e5bbeba6fb7f0baaeab2fe6f126319..920a261ae3c8d68ec0b0d311fd361e3843eebd86 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh @@ -30,7 +30,7 @@ export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=$(which python3) yes "" | $PYTHON_BIN_PATH configure.py which bazel -bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac,-no_mac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh index e1b56b9a25f663737ffe0991882f6e5e753265ed..7d471b47034f04ea4c2d31d9cdd7cea48fb32745 100755 --- a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh @@ -31,5 +31,5 @@ export TF_NEED_OPENCL_SYCL=0 export TF_NEED_MKL=0 export COMPUTECPP_PATH="/usr/local" -export PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" +export PATH="$PATH:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" build_libtensorflow_tarball "-cpu-darwin-$(uname -m)" diff --git a/tensorflow/tools/ci_build/remote/Dockerfile.cpu b/tensorflow/tools/ci_build/remote/Dockerfile.cpu deleted file mode 100644 index 7b01d8320d26f38c92ad8f404da3188809a6d400..0000000000000000000000000000000000000000 --- a/tensorflow/tools/ci_build/remote/Dockerfile.cpu +++ /dev/null @@ -1,27 +0,0 @@ -FROM launcher.gcr.io/google/clang-debian8:latest - -RUN apt-get update && apt-get --no-install-recommends install -y \ - binutils \ - binutils-gold \ - curl \ - libstdc++-4.9-dev \ - python \ - python-dev \ - python-numpy \ - python-pip \ - unzip \ - zip && \ - rm -rf /var/lib/apt/lists/* - -RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ - python get-pip.py && \ - rm get-pip.py - -# Set up grpc -RUN pip install --upgrade enum34 futures mock numpy six backports.weakref portpicker && \ - pip install --pre 'protobuf>=3.0.0a3' && \ - pip install 'grpcio>=1.1.3' - -# TODO: Set up golang which is compatible with clang - -WORKDIR /botexec diff --git a/tensorflow/tools/ci_build/remote/Dockerfile.gpu b/tensorflow/tools/ci_build/remote/Dockerfile.gpu deleted file mode 100644 index 47ffd44163dd3e4b99f06689e1aa6f19f84cc2ca..0000000000000000000000000000000000000000 --- a/tensorflow/tools/ci_build/remote/Dockerfile.gpu +++ /dev/null @@ -1,27 +0,0 @@ -FROM nvidia-cuda-clang:latest - -RUN apt-get update && apt-get --no-install-recommends install -y \ - binutils \ - binutils-gold \ - curl \ - libstdc++-4.9-dev \ - python \ - python-dev \ - python-numpy \ - python-pip \ - unzip \ - zip && \ - rm -rf /var/lib/apt/lists/* - -RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ - python get-pip.py && \ - rm get-pip.py - -# Set up grpc -RUN pip install --upgrade \ - enum34 futures astor gast mock numpy six \ - backports.weakref termcolor && \ - pip install --pre 'protobuf>=3.0.0a3' && \ - pip install 'grpcio>=1.1.3' - -WORKDIR /botexec diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index 7b2d7e1a568b0235a5bdd55bb23e542772902576..d654b433e7ddcfc79dea010c43d8eb0bc33fdcb2 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -120,7 +120,9 @@ function run_configure_for_gpu_build { export TF_CUDA_VERSION=9.0 export CUDA_TOOLKIT_PATH="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0" export TF_CUDNN_VERSION=7.0 - export CUDNN_INSTALL_PATH="C:/tools/cuda" + if [ -z "$CUDNN_INSTALL_PATH" ]; then + export CUDNN_INSTALL_PATH="C:/tools/cuda" + fi export TF_CUDA_COMPUTE_CAPABILITIES="3.7" if [ -z "$TF_ENABLE_XLA" ]; then export TF_ENABLE_XLA=0 diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index 1c35d74af72ad0a72b0016356888c8cf77e20e56..7d4cc7ac3005f7ff9a79d18228e86d6b74e1e8b0 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -34,6 +34,9 @@ export BAZEL_SH=${BAZEL_SH:-"C:/tools/msys64/usr/bin/bash"} export PYTHON_BASE_PATH="${PYTHON_DIRECTORY:-Program Files/Anaconda3}" +# Set the path to find bazel. +export PATH="/c/tools/bazel/:$PATH" + # Set Python path for ./configure export PYTHON_BIN_PATH="C:/${PYTHON_BASE_PATH}/python.exe" export PYTHON_LIB_PATH="C:/${PYTHON_BASE_PATH}/lib/site-packages" diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index b87e4a9bec41264827d415a11dfa6f23aeda725d..4656afe0256d03540fed6912677c8e93f9cf9eb6 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -37,7 +37,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. -%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX -G"Visual Studio 14" :: Run msbuild in the resulting VS project files to build a pip package. %MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat index b537192a945b2a2d8c2df940b947c6c0f7d6fc06..97829892b10059f9d9663e103534891d1481abec 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat @@ -28,6 +28,9 @@ IF DEFINED TF_NIGHTLY (ECHO TF_NIGHTLY is set to %TF_NIGHTLY%) ELSE (SET TF_NIGH :: Set pip binary location. Do not override if it is set already. IF DEFINED PIP_EXE (ECHO PIP_EXE is set to %PIP_EXE%) ELSE (SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe") +:: Set ctest binary location. +IF DEFINED CTEST_EXE (ECHO CTEST_EXE is set to %CTEST_EXE%) ELSE (SET CTEST_EXE="C:\Program Files\cmake\bin\ctest.exe") + :: Run the CMAKE build to build the pip package. CALL %REPO_ROOT%\tensorflow\tools\ci_build\windows\gpu\cmake\run_build.bat if %errorlevel% neq 0 exit /b %errorlevel% @@ -47,4 +50,4 @@ if %errorlevel% neq 0 exit /b %errorlevel% :: Run all python tests if the installation succeeded. echo Running tests... -ctest -C Release --output-on-failure --jobs 1 +%CTEST_EXE% -C Release --output-on-failure --jobs 1 diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 6e90b286c99f894ddd25268afc69043759571c36..1f8833582af4c922115e637117e775e619439786 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -662,9 +662,9 @@ class TFAPIChangeSpec(APIChangeSpec): def _reverse_handler(file_edit_recorder, node): # TODO(aselle): Could check for a literal list of bools and try to convert # them to indices. - comment = ("ERROR: tf.reverse has had its argument semantics changed\n" - "significantly the converter cannot detect this reliably, so you" - "need to inspect this usage manually.\n") + comment = ("ERROR: tf.reverse has had its argument semantics changed " + "significantly the converter cannot detect this reliably, so " + "you need to inspect this usage manually.\n") file_edit_recorder.add( comment, node.lineno, diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md index c1b1f79bbd4b657768b9bbcab93efa3354774915..228d5ee35d1839c60b51a85bd606c1ba86e46886 100644 --- a/tensorflow/tools/dist_test/README.md +++ b/tensorflow/tools/dist_test/README.md @@ -17,6 +17,14 @@ cesnsu model: ./local_test.sh --model_name CENSUS_WIDENDEEP +You can test specify version of TensorFlow: + +```shell +./local_test.sh ${whl_file_url} +``` + +For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/install_linux#the_url_of_the_tensorflow_python_package) for Ubuntu. + **2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the test suite on it** diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh index 435f9d0dc9c55a3dcfc45e7e46f279b4679a9086..caae7fd5305af9846628eaf00348dd08df4e827f 100755 --- a/tensorflow/tools/dist_test/local_test.sh +++ b/tensorflow/tools/dist_test/local_test.sh @@ -16,12 +16,11 @@ # # Tests distributed TensorFlow on a locally running TF GRPC cluster. # -# This script peforms the following steps: -# 1) Build the docker-in-docker (dind) image capable of running docker and -# Kubernetes (k8s) cluster inside. +# This script performs the following steps: +# 1) Build the docker image capable of running distributed TensorFlow in docker. # 2) Run a container from the aforementioned image and start docker service # in it -# 3) Call a script to launch a k8s TensorFlow GRPC cluster inside the container +# 3) Call a script to launch a distributed TensorFlow GRPC cluster inside the container # and run the distributed test suite. # # Usage: local_test.sh @@ -64,15 +63,9 @@ die() { # Configurations DOCKER_IMG_NAME="tensorflow/tf-dist-test-local-cluster" -LOCAL_K8S_CACHE=${HOME}/kubernetes -# Helper function -get_container_id_by_image_name() { - # Get the id of a container by image name - # Usage: get_docker_container_id_by_image_name - - docker ps | grep $1 | awk '{print $1}' -} +# Use TensorFlow v1.5.0 for Python 2.7 and CPU only as we set num_gpus to 0 in the below +DEFAULT_WHL_FILE_LOCATION="https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0-cp27-none-linux_x86_64.whl" # Parse input arguments LEAVE_CONTAINER_RUNNING=0 @@ -84,7 +77,8 @@ SYNC_REPLICAS_FLAG="" WHL_FILE_LOCATION=${1} if [[ -z "${WHL_FILE_LOCATION}" ]]; then - die "whl file location is not specified" + WHL_FILE_LOCATION=${DEFAULT_WHL_FILE_LOCATION} + echo "use default whl file location" fi while true; do @@ -121,7 +115,7 @@ DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # Get utility functions source ${DIR}/scripts/utils.sh -# Build docker-in-docker image for local k8s cluster. +# Build docker image for local distributed TensorFlow cluster. NO_CACHE_FLAG="" if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] && [[ "${TF_DIST_DOCKER_NO_CACHE}" != "0" ]]; then diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py index a2d12442c44553a287637029843021b7541fa3fa..d6e7f317dd0b52203e354676425dbbbcd53e1973 100644 --- a/tensorflow/tools/dist_test/python/mnist_replica.py +++ b/tensorflow/tools/dist_test/python/mnist_replica.py @@ -56,7 +56,7 @@ flags.DEFINE_integer("task_index", None, flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine." "If you don't use GPU, please set it to '0'") flags.DEFINE_integer("replicas_to_aggregate", None, - "Number of replicas to aggregate before parameter update" + "Number of replicas to aggregate before parameter update " "is applied (For sync_replicas mode only; default: " "num_workers)") flags.DEFINE_integer("hidden_units", 100, diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index d16761c3675942838fd2be0ea6e0b7463a3bf249..11f476d12c086f70335d9a69d7f3b86b525b5623 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -57,7 +57,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.8.0 +ENV BAZEL_VERSION 0.11.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ @@ -70,7 +70,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.6 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.7 --depth=1 https://github.com/tensorflow/tensorflow.git . # TODO(craigcitro): Don't install the pip package, since it makes it # more difficult to experiment with local changes. Instead, just add diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl index 3690e7dfe57a4682276a90b10cb84c9a329b3f5e..037d13116efc5ddf76c31eb87d7f81d31c3591f5 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl @@ -3,7 +3,7 @@ FROM tensorflow/tensorflow:latest-devel LABEL maintainer="Clayne Robison" # These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.6 +ARG TF_BRANCH=r1.7 ARG WHL_DIR=/whl RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 4ef37881bc91aaa58bab031c69b4a96c2a9d8ec1..1fcb6428b21b4ca495bef2b3249b6463e9ef0a10 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -66,7 +66,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.8.0 +ENV BAZEL_VERSION 0.11.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ @@ -79,7 +79,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.6 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.7 --depth=1 https://github.com/tensorflow/tensorflow.git . # Configure the build for our CUDA configuration. ENV CI_BUILD_PYTHON python diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index b6682cd68163ec870ed815b45ac4fdd9233f88c6..625321e1235202f78a2d5e1a5b2d9d05e1e3f9ba 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -1,11 +1,18 @@ -FROM nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04 +FROM nvidia/cuda:9.0-base-ubuntu16.04 LABEL maintainer="Craig Citro " # Pick up some TF dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ + cuda-command-line-tools-9-0 \ + cuda-cublas-9-0 \ + cuda-cufft-9-0 \ + cuda-curand-9-0 \ + cuda-cusolver-9-0 \ + cuda-cusparse-9-0 \ curl \ + libcudnn7=7.0.5.15-1+cuda9.0 \ libfreetype6-dev \ libpng12-dev \ libzmq3-dev \ diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 34dd419f15676babfa9a36c2c0960b01248b6f69..d22a465376f4f58164514e62d302524a43b0dd01 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -211,6 +211,7 @@ def _get_default_do_not_descend_map(): 'tf': ['cli', 'lib', 'wrappers'], 'tf.contrib': [ 'compiler', + 'distribute', 'grid_rnn', # Block contrib.keras to de-clutter the docs 'keras', diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index e758229535e7b10994a39cbafb37e116fd2a465c..d2a63ecc4960117eb64fcc4f94bf882d4a3f91dd 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -34,7 +34,11 @@ from tensorflow.python.util import tf_inspect # A regular expression capturing a python identifier. -IDENTIFIER_RE = '[a-zA-Z_][a-zA-Z0-9_]*' +IDENTIFIER_RE = r'[a-zA-Z_]\w*' + + +class TFDocsError(Exception): + pass class _Errors(object): @@ -118,6 +122,8 @@ SYMBOL_REFERENCE_RE = re.compile( """, flags=re.VERBOSE) +AUTO_REFERENCE_RE = re.compile(r'`([a-zA-Z0-9_.]+?)`') + class ReferenceResolver(object): """Class for replacing @{...} references with Markdown links. @@ -240,10 +246,25 @@ class ReferenceResolver(object): Returns: `string`, with "@{symbol}" references replaced by Markdown links. """ - def one_ref(match): - return self._one_ref(match, relative_path_to_root) - return re.sub(SYMBOL_REFERENCE_RE, one_ref, string) + def strict_one_ref(match): + try: + return self._one_ref(match, relative_path_to_root) + except TFDocsError as e: + self.add_error(e.message) + return 'BAD_LINK' + + string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, string) + + def sloppy_one_ref(match): + try: + return self._one_ref(match, relative_path_to_root) + except TFDocsError: + return match.group(0) + + string = re.sub(AUTO_REFERENCE_RE, sloppy_one_ref, string) + + return string def python_link(self, link_text, ref_full_name, relative_path_to_root, code_ref=True): @@ -307,14 +328,14 @@ class ReferenceResolver(object): Raises: RuntimeError: If `ref_full_name` is not documented. + TFDocsError: If the @{} syntax cannot be decoded. """ master_name = self._duplicate_of.get(ref_full_name, ref_full_name) # Check whether this link exists if master_name not in self._all_names: - message = 'Cannot make link to "%s": Not in index.' % master_name - self.add_error(message) - return 'BROKEN_LINK' + raise TFDocsError( + 'Cannot make link to "%s": Not in index.' % master_name) # If this is a member of a class, link to the class page with an anchor. ref_path = None @@ -369,8 +390,8 @@ class ReferenceResolver(object): code_ref=not manual_link_text) # Error! - self.add_error('Did not understand "%s"' % match.group(0)) - return 'BROKEN_LINK' + raise TFDocsError('Did not understand "%s"' % match.group(0), + 'BROKEN_LINK') def _doc_link(self, string, link_text, manual_link_text, relative_path_to_root): @@ -395,11 +416,10 @@ class ReferenceResolver(object): return self._doc_missing(string, hash_tag, link_text, manual_link_text, relative_path_to_root) - def _doc_missing(self, string, unused_hash_tag, link_text, + def _doc_missing(self, string, unused_hash_tag, unused_link_text, unused_manual_link_text, unused_relative_path_to_root): """Generate an error for unrecognized @{$...} references.""" - self.add_error('Unknown Document "%s"' % string) - return link_text + raise TFDocsError('Unknown Document "%s"' % string) def _cc_link(self, string, link_text, unused_manual_link_text, relative_path_to_root): @@ -416,8 +436,8 @@ class ReferenceResolver(object): elif string == 'tensorflow::ops::Const': ret = 'namespace/tensorflow/ops.md#const' else: - self.add_error('C++ reference not understood: "%s"' % string) - return 'TODO_C++:%s' % string + raise TFDocsError('C++ reference not understood: "%s"' % string) + # relative_path_to_root gets you to api_docs/python, we go from there # to api_docs/cc, and then add ret. cc_relative_path = os.path.normpath(os.path.join( diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py index 3630dbd740e981971bdc9ff45b756b45095d437d..cbcdbf5b807a585865e2e3f19291e55388d55cb1 100755 --- a/tensorflow/tools/git/gen_git_source.py +++ b/tensorflow/tools/git/gen_git_source.py @@ -114,6 +114,13 @@ def configure(src_base_path, gen_path, debug=False): for target, src in link_map.items(): if src is None: open(os.path.join(gen_path, target), "w").write("") + elif not os.path.exists(src): + # Git repo is configured in a way we don't support such as having + # packed refs. Even though in a git repo, tf.__git_version__ will not + # be accurate. + # TODO(mikecase): Support grabbing git info when using packed refs. + open(os.path.join(gen_path, target), "w").write("") + spec["git"] = False else: try: # In python 3.5, symlink function exists even on Windows. But requires diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index ad3668fa02e102607c9a03ac312451a147affdda..6e21aa28461819fb9f65642716536e37ada8f9bf 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -91,7 +91,6 @@ cc_library( srcs = [ "add_default_attributes.cc", "backports.cc", - "fake_quantize_training.cc", "flatten_atrous.cc", "fold_batch_norms.cc", "fold_constants_lib.cc", @@ -105,7 +104,6 @@ cc_library( "remove_attribute.cc", "remove_control_dependencies.cc", "remove_device.cc", - "remove_ema.cc", "remove_nodes.cc", "rename_attribute.cc", "rename_op.cc", @@ -134,8 +132,8 @@ cc_library( "//tensorflow/core:tensorflow", "//tensorflow/contrib/rnn:gru_ops_op_lib", "//tensorflow/contrib/rnn:lstm_ops_op_lib", + "//tensorflow/core/kernels:quantization_utils", ] + if_not_windows([ - "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/kernels:remote_fused_graph_rewriter_transform", "//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform", ]), @@ -148,7 +146,6 @@ tf_cc_test( srcs = [ "add_default_attributes_test.cc", "backports_test.cc", - "fake_quantize_training_test.cc", "flatten_atrous_test.cc", "fold_batch_norms_test.cc", "fold_constants_test.cc", @@ -161,7 +158,6 @@ tf_cc_test( "quantize_weights_test.cc", "remove_attribute_test.cc", "remove_device_test.cc", - "remove_ema_test.cc", "remove_nodes_test.cc", "rename_attribute_test.cc", "rename_op_test.cc", @@ -182,6 +178,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/kernels:quantization_utils", "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/util/tensor_bundle", ], diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc deleted file mode 100644 index 61aecc6e16d817d245421f18fa39c70aa45b2bef..0000000000000000000000000000000000000000 --- a/tensorflow/tools/graph_transforms/fake_quantize_training.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "tensorflow/core/graph/quantize_training.h" -#include "tensorflow/tools/graph_transforms/transform_utils.h" - -namespace tensorflow { -namespace graph_transforms { - -// EXPERIMENTAL: This can change without warning. -// Rewrites the GraphDef for quantized training. -// Rewrites the forward pass to include the precision loss with quantization so -// the model can learn to deal with such loss and achieve better accuracy when -// it is quantized later for inference. -// Quantization range information is collected in FakeQuantizeWithMinMaxVars -// ops. -// -// TODO(suharshs): Provide instructions on converting the resulting graph for -// inference. -// TODO(suharshs): Implement this using the GTT rather than calling the old -// prototype function. -Status FakeQuantizeTraining(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def) { - // TODO(suharshs): Make num_bits a parameter. - const int32 num_bits = 8; - // TODO(suharshs): Make quantization op a parameter? - const string quant_op_type = "FakeQuantWithMinMaxVars"; - - return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type, - output_graph_def); -} - -REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining); - -} // namespace graph_transforms -} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc deleted file mode 100644 index 5e4ab209e97808c3f42ecf73fb763ef9d7ab1cfe..0000000000000000000000000000000000000000 --- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/tools/graph_transforms/transform_utils.h" - -namespace tensorflow { -namespace graph_transforms { - -// Declare here, so we don't need a public header. -Status FakeQuantizeTraining(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); - -class FakeQuantizeTrainingTest : public ::testing::Test {}; - -// For now, since the fake_quantize_training transform just calls the -// quantize_training rewrite from tensorflow/core/graph/quantize_training.h, -// we just test that the graph has been changed by the transform. -// TODO(suharshs): Once we implement the fake_quantize_training transform -// using the GTT, write proper tests of the transform here. -TEST_F(FakeQuantizeTrainingTest, TransformOccurred) { - auto root = tensorflow::Scope::DisabledShapeInferenceScope(); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - - Tensor a_data(DT_FLOAT, TensorShape()); - test::FillIota(&a_data, 1.0f); - Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); - - Tensor b_data(DT_FLOAT, TensorShape()); - test::FillIota(&b_data, 1.0f); - Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); - - Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const); - GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - - GraphDef result; - TransformFuncContext context; - TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result)); - - // Test that the transformation resulted in a graph with more nodes. - EXPECT_GT(result.node_size(), graph_def.node_size()); -} - -} // namespace graph_transforms -} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index d89afe85c72883323cec3c14342fd60adebd024d..d86f65325be1c3f5151ab8d0a0c3c64afa3dcf0f 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -182,6 +182,36 @@ Status FuseBatchNormWithConv(const NodeMatch& match, return Status::OK(); } +Status FuseBatchNormWithBatchToSpace(const NodeMatch& match, + std::vector* new_nodes) { + // Calculate the scale and offset values to apply. + std::vector scale_values; + std::vector offset_values; + TF_RETURN_IF_ERROR( + GetScaleAndOffsetValues(match, &scale_values, &offset_values)); + + // Fuse conv weights, and set the final output node name as batch_norm_node. + const NodeDef& batch_norm_node = match.node; + const NodeMatch& batch_to_space_node_match = match.inputs[0]; + const NodeMatch& conv_node_match = batch_to_space_node_match.inputs[0]; + const NodeDef& batch_to_space_node = batch_to_space_node_match.node; + const NodeDef& conv_node = conv_node_match.node; + + string biasadd_name = conv_node.name() + "/biasadd"; + TF_RETURN_IF_ERROR( + FuseScaleOffsetToConvWeights(scale_values, offset_values, conv_node_match, + biasadd_name , new_nodes)); + + NodeDef new_batch_to_space_node = batch_to_space_node; + // reuse batch_norm node name + new_batch_to_space_node.set_name(batch_norm_node.name()); + new_batch_to_space_node.set_input(0, biasadd_name); + new_nodes->push_back(batch_to_space_node_match.inputs[1].node); + new_nodes->push_back(batch_to_space_node_match.inputs[2].node); + new_nodes->push_back(new_batch_to_space_node); + return Status::OK(); +} + Status FuseBatchNormWithConvConcat(const NodeMatch& match, std::vector* new_nodes) { // Calculate the scale and offset values to apply. @@ -284,6 +314,43 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, current_graph_def = replaced_graph_def; } while (did_graph_change); + do { + did_graph_change = false; + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, // clang-format off + {"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node + { + {"BatchToSpaceND", // batch_to_space_node + { + {"Conv2D", // conv_node + { + {"*"}, // input_node + {"Const"}, // weights_node + } + }, + {"Const"}, // block_shape + {"Const"}, // crops + } + }, + {"Const"}, // mean_node + {"Const"}, // variance_node + {"Const"}, // beta_node + {"Const"}, // gamma_node + } + }, // clang-format on + [&did_graph_change](const NodeMatch& match, + const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + TF_RETURN_IF_ERROR(FuseBatchNormWithBatchToSpace(match, new_nodes)); + did_graph_change = true; + return Status::OK(); + }, + {}, &replaced_graph_def)); + current_graph_def = replaced_graph_def; + } while (did_graph_change); + do { did_graph_change = false; GraphDef replaced_graph_def; diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index b30ba9ac8b92db68eb3374c51a7f31b69cd1e3cf..7651a03fe51012678d6d6fc495fd82e497aa512b 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/image_ops.h" #include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -298,6 +299,96 @@ class FoldOldBatchNormsTest : public ::testing::Test { } }; +void TestFoldFusedBatchNormsWithBatchToSpace() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({2, 1, 3, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor block_shape_data(DT_INT32, TensorShape({2})); + test::FillValues(&block_shape_data, {1, 2}); + Output block_shape_op = + Const(root.WithOpName("block_shape_op"), Input::Initializer(block_shape_data)); + + Tensor crops_data(DT_INT32, TensorShape({2, 2})); + test::FillValues(&crops_data, {0, 0, 0, 1}); + Output crops_op = + Const(root.WithOpName("crops_op"), Input::Initializer(crops_data)); + + Output batch_to_space_op = BatchToSpaceND(root.WithOpName("batch_to_space_op"), + conv_op, block_shape_op, crops_data); + + Tensor mean_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&mean_data, {10.0f, 20.0f}); + Output mean_op = + Const(root.WithOpName("mean_op"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&variance_data, {0.25f, 0.5f}); + Output variance_op = Const(root.WithOpName("variance_op"), + Input::Initializer(variance_data)); + + Tensor beta_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&beta_data, {0.1f, 0.6f}); + Output beta_op = + Const(root.WithOpName("beta_op"), Input::Initializer(beta_data)); + + Tensor gamma_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&gamma_data, {1.0f, 2.0f}); + Output gamma_op = + Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data)); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + NodeDef batch_norm_node; + batch_norm_node.set_op("FusedBatchNorm"); + batch_norm_node.set_name("output"); + AddNodeInput("batch_to_space_op", &batch_norm_node); + AddNodeInput("gamma_op", &batch_norm_node); + AddNodeInput("beta_op", &batch_norm_node); + AddNodeInput("mean_op", &batch_norm_node); + AddNodeInput("variance_op", &batch_norm_node); + SetNodeAttr("T", DT_FLOAT, &batch_norm_node); + SetNodeAttr("epsilon", 0.00001f, &batch_norm_node); + SetNodeAttr("is_training", false, &batch_norm_node); + *(original_graph_def.mutable_node()->Add()) = batch_norm_node; + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("FusedBatchNormWithBatchToSpace", node.op()); + } +} + TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) { TestFoldOldBatchNorms(); } @@ -307,7 +398,7 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) { } TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) { - // Test axis is not 3, so all weigths and offsets are fused to each of inputs + // Test axis is not 3, so all weights and offsets are fused to each of inputs // of conv2d. TestFoldFusedBatchNormsWithConcat(/*split=*/true); // Test axis = 3, BatchNorm weights and offsets will be split before fused @@ -315,5 +406,9 @@ TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) { TestFoldFusedBatchNormsWithConcat(/*split=*/false); } +TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithBatchToSpace) { + TestFoldFusedBatchNormsWithBatchToSpace(); +} + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_ema.cc b/tensorflow/tools/graph_transforms/remove_ema.cc deleted file mode 100644 index 22e26267025c3fed4f44ffbc09d55d8d355cc448..0000000000000000000000000000000000000000 --- a/tensorflow/tools/graph_transforms/remove_ema.cc +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "tensorflow/tools/graph_transforms/transform_utils.h" - -namespace tensorflow { -namespace graph_transforms { - -// EXPERIMENTAL: This can change without warning. -// Given a graph that has gone through the FakeQuantizeTraining transform and -// has been frozen afterwards, RemoveEMA simplifies the FakeQuantize estimated -// moving average subgraphs to make it compatible with the QuantizeNodes -// transform. -Status RemoveEMA(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def) { - TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( - input_graph_def, // clang-format off - {"FakeQuantWithMinMaxVars", - { - {"*"}, - {"Assign", - { - {"Const"}, - {"Merge", - { - {"Switch", - { - {"Min", - { - {"*"}, - {"Range", - { - {"*"}, - {"*"}, - {"*"}, - } - } - } - }, - {"IsVariableInitialized"} - } - }, - {"Sub", - { - {"Const"}, - {"Mul", - { - {"Sub"}, - {"Sub", - { - {"Const"}, - {"Const"} - } - } - } - } - } - } - } - } - } - }, - {"Assign", - { - {"Const"}, - {"Merge", - { - {"Switch", - { - {"Max"}, - {"IsVariableInitialized"} - } - }, - {"Sub", - { - {"Const"}, - {"Mul", - { - {"Sub"}, - {"Sub", - { - {"Const"}, - {"Const"} - } - } - } - } - } - } - } - } - } - }, - } - }, // clang-format on - [](const NodeMatch& match, const std::set& input_nodes, - const std::set& output_nodes, - std::vector* new_nodes) { - const NodeDef& fake_quant_node = match.node; - const NodeDef& input_node = match.inputs[0].node; - const NodeDef& min_var_node = match.inputs[1].inputs[0].node; - const NodeDef& max_var_node = match.inputs[2].inputs[0].node; - - // Make a new FakeQuantizeWithMinMaxVars operation that uses constants - // for its min/max arguments rather than an entire EMA subgraph. - NodeDef new_fake_quant_node; - new_fake_quant_node.set_op(fake_quant_node.op()); - new_fake_quant_node.set_name(fake_quant_node.name()); - AddNodeInput(input_node.name(), &new_fake_quant_node); - AddNodeInput(min_var_node.name(), &new_fake_quant_node); - AddNodeInput(max_var_node.name(), &new_fake_quant_node); - CopyNodeAttr(fake_quant_node, "narrow_range", "narrow_range", - &new_fake_quant_node); - CopyNodeAttr(fake_quant_node, "num_bits", "num_bits", - &new_fake_quant_node); - - new_nodes->push_back(new_fake_quant_node); - new_nodes->push_back(input_node); - new_nodes->push_back(min_var_node); - new_nodes->push_back(max_var_node); - - return Status::OK(); - }, - {}, output_graph_def)); - return Status::OK(); -} - -REGISTER_GRAPH_TRANSFORM("remove_ema", RemoveEMA); - -} // namespace graph_transforms -} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_ema_test.cc b/tensorflow/tools/graph_transforms/remove_ema_test.cc deleted file mode 100644 index 27db90e2729487f89324622f7a63aca1c5a58fe7..0000000000000000000000000000000000000000 --- a/tensorflow/tools/graph_transforms/remove_ema_test.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/tools/graph_transforms/transform_utils.h" - -namespace tensorflow { -namespace graph_transforms { - -// Declare transformations here, so we don't need a public header. -Status FakeQuantizeTraining(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); - -Status RemoveEMA(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); - -Status QuantizeNodes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); - -class RemoveEMATest : public ::testing::Test {}; - -TEST_F(RemoveEMATest, FakeQuant_RemoveEMA_QuantizeTraining) { - // Build a small graph. - auto root = tensorflow::Scope::NewRootScope(); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - - Tensor a_data(DT_FLOAT, TensorShape({1, 1})); - test::FillIota(&a_data, 1.0f); - Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); - - Tensor b_data(DT_FLOAT, TensorShape({1, 1})); - test::FillIota(&b_data, 1.0f); - Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); - - Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const); - GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - - // (1) FakeQuantize the graph. - GraphDef fake_quantized_graph_def; - TransformFuncContext context; - TF_ASSERT_OK( - FakeQuantizeTraining(graph_def, context, &fake_quantized_graph_def)); - - // Test that the transformation resulted in a graph with more nodes. - EXPECT_GT(fake_quantized_graph_def.node_size(), graph_def.node_size()); - - // (2) Run the graph to initialize the newly added variables. - std::unique_ptr session(NewSession(SessionOptions())); - TF_ASSERT_OK(session->Create(fake_quantized_graph_def)); - std::vector outputs; - TF_ASSERT_OK(session->Run({}, {"matmul"}, {}, &outputs)); - - // (3) Freeze the graph. Create a "frozen graph" that matches what we would - // expect if we actually froze the above graph. - // TODO(suharshs): Use a c++ freeze graph alternative, when one is available. - GraphDef frozen_graph_def; - for (const NodeDef& node : fake_quantized_graph_def.node()) { - if (node.op() == "Variable" || node.op() == "VariableV2") { - NodeDef const_node; - const_node.set_op("Const"); - const_node.set_name(node.name()); - SetNodeAttr("dtype", DT_FLOAT, &const_node); - Tensor tensor(DT_FLOAT, {}); - tensor.flat()(0) = 1.0f; - SetNodeTensorAttr("value", tensor, &const_node); - *(frozen_graph_def.mutable_node()->Add()) = const_node; - } else { - *(frozen_graph_def.mutable_node()->Add()) = node; - } - } - - // Test that freezing the graph resulted in a graph with the same number of - // nodes. - EXPECT_EQ(frozen_graph_def.node_size(), fake_quantized_graph_def.node_size()); - - // (4) RemoveEMA on the graph to make it compatible with QuantizeNodes. - GraphDef removed_ema_graph_def; - TF_ASSERT_OK(RemoveEMA(frozen_graph_def, context, &removed_ema_graph_def)); - - // Test that the transformation resulted in a graph with less nodes. - EXPECT_LT(removed_ema_graph_def.node_size(), frozen_graph_def.node_size()); - - // (5) QuantizeNodes and inspect the final graph. - // TODO(suharshs): Add a more thorough inspection of the structure of - // the output graph. - GraphDef quantized_graph_def; - TF_ASSERT_OK( - QuantizeNodes(removed_ema_graph_def, context, &quantized_graph_def)); - - // Test that the transformation resulted in a graph with more nodes. - EXPECT_GT(quantized_graph_def.node_size(), removed_ema_graph_def.node_size()); - - // Make sure that the FakeQuantizeWithMinMaxVars op has been removed. - for (const NodeDef& node : quantized_graph_def.node()) { - EXPECT_NE(node.op(), "FakeQuantWithMinMaxVars"); - } -} - -} // namespace graph_transforms -} // namespace tensorflow diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD.bazel b/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD.bazel deleted file mode 100755 index 439d86c5d2c10d15f68247c0df42ce488c10d6be..0000000000000000000000000000000000000000 --- a/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD.bazel +++ /dev/null @@ -1,56 +0,0 @@ -package(default_visibility = ["//visibility:public"]) - -load("@rbe_integration_test//skylark:integration_tests.bzl", "sut_component", "integration_test") -load("@rbe_integration_test//skylark:toolchains.bzl", "toolchain_container_images") - -sut_component( - name = "gcs", - docker_image = toolchain_container_images()["tensorflow"], - setups = [{ - "program": "setup.sh", - "args": [ - "gs://tensorflow-test-bucket/tf-gcs-test", - ], - "output_properties": ["gcs_path"], - "timeout_seconds": 100, - }], - teardowns = [{ - "program": "teardown.sh", - "args": ["{gcs_path}"], - "timeout_seconds": 100, - }], -) - -py_binary( - name = "gcs_smoke", - srcs = ["gcs_smoke.py"], -) - -sh_binary( - name = "test_wrapper", - srcs = ["test_wrapper.sh"], - data = [ - "gcs_smoke", - ], -) - -integration_test( - name = "gcs_smoke_test", - sut_deps = { - ":gcs": "gcs", - }, - tags = [ - "manual", - "notap", - ], - test = { - "program": ":test_wrapper", - "args": [ - "--gcs_bucket_url={gcs#gcs_path}", - "--num_examples=20", - ], - "timeout_seconds": 250, - }, - test_docker_image = toolchain_container_images()["tensorflow"], - test_type = "MultiMachine", -) diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py b/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py deleted file mode 100755 index 8438c2156cb09b4d8c9442d9a5f4de67e59272f2..0000000000000000000000000000000000000000 --- a/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Smoke test for reading records from GCS to TensorFlow.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys -import time - -import numpy as np -import tensorflow as tf -from tensorflow.core.example import example_pb2 -from tensorflow.python.lib.io import file_io - -flags = tf.app.flags -flags.DEFINE_string("gcs_bucket_url", "", - "The URL to the GCS bucket in which the temporary " - "tfrecord file is to be written and read, e.g., " - "gs://my-gcs-bucket/test-directory") -flags.DEFINE_integer("num_examples", 10, "Number of examples to generate") - -FLAGS = flags.FLAGS - - -def create_examples(num_examples, input_mean): - """Create ExampleProto's containing data.""" - ids = np.arange(num_examples).reshape([num_examples, 1]) - inputs = np.random.randn(num_examples, 1) + input_mean - target = inputs - input_mean - examples = [] - for row in range(num_examples): - ex = example_pb2.Example() - ex.features.feature["id"].bytes_list.value.append(str(ids[row, 0])) - ex.features.feature["target"].float_list.value.append(target[row, 0]) - ex.features.feature["inputs"].float_list.value.append(inputs[row, 0]) - examples.append(ex) - return examples - - -def create_dir_test(): - """Verifies file_io directory handling methods.""" - - # Test directory creation. - starttime_ms = int(round(time.time() * 1000)) - dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms) - print("Creating dir %s" % dir_name) - file_io.create_dir(dir_name) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Created directory in: %d milliseconds" % elapsed_ms) - - # Check that the directory exists. - dir_exists = file_io.is_directory(dir_name) - assert dir_exists - print("%s directory exists: %s" % (dir_name, dir_exists)) - - # Test recursive directory creation. - starttime_ms = int(round(time.time() * 1000)) - recursive_dir_name = "%s/%s/%s" % (dir_name, - "nested_dir1", - "nested_dir2") - print("Creating recursive dir %s" % recursive_dir_name) - file_io.recursive_create_dir(recursive_dir_name) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Created directory recursively in: %d milliseconds" % elapsed_ms) - - # Check that the directory exists. - recursive_dir_exists = file_io.is_directory(recursive_dir_name) - assert recursive_dir_exists - print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists)) - - # Create some contents in the just created directory and list the contents. - num_files = 10 - files_to_create = ["file_%d.txt" % n for n in range(num_files)] - for file_num in files_to_create: - file_name = "%s/%s" % (dir_name, file_num) - print("Creating file %s." % file_name) - file_io.write_string_to_file(file_name, "test file.") - - print("Listing directory %s." % dir_name) - starttime_ms = int(round(time.time() * 1000)) - directory_contents = file_io.list_directory(dir_name) - print(directory_contents) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms)) - assert set(directory_contents) == set(files_to_create + ["nested_dir1/"]) - - # Test directory renaming. - dir_to_rename = "%s/old_dir" % dir_name - new_dir_name = "%s/new_dir" % dir_name - file_io.create_dir(dir_to_rename) - assert file_io.is_directory(dir_to_rename) - assert not file_io.is_directory(new_dir_name) - - starttime_ms = int(round(time.time() * 1000)) - print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name)) - file_io.rename(dir_to_rename, new_dir_name) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Renamed directory %s to %s in %s milliseconds" % ( - dir_to_rename, new_dir_name, elapsed_ms)) - assert not file_io.is_directory(dir_to_rename) - assert file_io.is_directory(new_dir_name) - - # Test Delete directory recursively. - print("Deleting directory recursively %s." % dir_name) - starttime_ms = int(round(time.time() * 1000)) - file_io.delete_recursively(dir_name) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - dir_exists = file_io.is_directory(dir_name) - assert not dir_exists - print("Deleted directory recursively %s in %s milliseconds" % ( - dir_name, elapsed_ms)) - - -def create_object_test(): - """Verifies file_io's object manipulation methods .""" - starttime_ms = int(round(time.time() * 1000)) - dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms) - print("Creating dir %s." % dir_name) - file_io.create_dir(dir_name) - - num_files = 5 - # Create files of 2 different patterns in this directory. - files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n) - for n in range(num_files)] - files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n) - for n in range(num_files)] - - starttime_ms = int(round(time.time() * 1000)) - files_to_create = files_pattern_1 + files_pattern_2 - for file_name in files_to_create: - print("Creating file %s." % file_name) - file_io.write_string_to_file(file_name, "test file creation.") - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Created %d files in %s milliseconds" % - (len(files_to_create), elapsed_ms)) - - # Listing files of pattern1. - list_files_pattern = "%s/test_file*.txt" % dir_name - print("Getting files matching pattern %s." % list_files_pattern) - starttime_ms = int(round(time.time() * 1000)) - files_list = file_io.get_matching_files(list_files_pattern) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Listed files in %s milliseconds" % elapsed_ms) - print(files_list) - assert set(files_list) == set(files_pattern_1) - - # Listing files of pattern2. - list_files_pattern = "%s/testfile*.txt" % dir_name - print("Getting files matching pattern %s." % list_files_pattern) - starttime_ms = int(round(time.time() * 1000)) - files_list = file_io.get_matching_files(list_files_pattern) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("Listed files in %s milliseconds" % elapsed_ms) - print(files_list) - assert set(files_list) == set(files_pattern_2) - - # Test renaming file. - file_to_rename = "%s/oldname.txt" % dir_name - file_new_name = "%s/newname.txt" % dir_name - file_io.write_string_to_file(file_to_rename, "test file.") - assert file_io.file_exists(file_to_rename) - assert not file_io.file_exists(file_new_name) - - print("Will try renaming file %s to %s" % (file_to_rename, file_new_name)) - starttime_ms = int(round(time.time() * 1000)) - file_io.rename(file_to_rename, file_new_name) - elapsed_ms = int(round(time.time() * 1000)) - starttime_ms - print("File %s renamed to %s in %s milliseconds" % ( - file_to_rename, file_new_name, elapsed_ms)) - assert not file_io.file_exists(file_to_rename) - assert file_io.file_exists(file_new_name) - - # Delete directory. - print("Deleting directory %s." % dir_name) - file_io.delete_recursively(dir_name) - - -def main(argv): - del argv # Unused. - # Sanity check on the GCS bucket URL. - if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"): - print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url) - sys.exit(1) - - # Verify that writing to the records file in GCS works. - print("\n=== Testing writing and reading of GCS record file... ===") - example_data = create_examples(FLAGS.num_examples, 5) - with tf.python_io.TFRecordWriter(FLAGS.gcs_bucket_url) as hf: - for e in example_data: - hf.write(e.SerializeToString()) - - print("Data written to: %s" % FLAGS.gcs_bucket_url) - - # Verify that reading from the tfrecord file works and that - # tf_record_iterator works. - record_iter = tf.python_io.tf_record_iterator(FLAGS.gcs_bucket_url) - read_count = 0 - for _ in record_iter: - read_count += 1 - print("Read %d records using tf_record_iterator" % read_count) - - if read_count != FLAGS.num_examples: - print("FAIL: The number of records read from tf_record_iterator (%d) " - "differs from the expected number (%d)" % (read_count, - FLAGS.num_examples)) - sys.exit(1) - - # Verify that running the read op in a session works. - print("\n=== Testing TFRecordReader.read op in a session... ===") - with tf.Graph().as_default() as _: - filename_queue = tf.train.string_input_producer([FLAGS.gcs_bucket_url], - num_epochs=1) - reader = tf.TFRecordReader() - _, serialized_example = reader.read(filename_queue) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - sess.run(tf.local_variables_initializer()) - tf.train.start_queue_runners() - index = 0 - for _ in range(FLAGS.num_examples): - print("Read record: %d" % index) - sess.run(serialized_example) - index += 1 - - # Reading one more record should trigger an exception. - try: - sess.run(serialized_example) - print("FAIL: Failed to catch the expected OutOfRangeError while " - "reading one more record than is available") - sys.exit(1) - except tf.errors.OutOfRangeError: - print("Successfully caught the expected OutOfRangeError while " - "reading one more record than is available") - - create_dir_test() - create_object_test() - -if __name__ == "__main__": - tf.app.run(main) diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh deleted file mode 100755 index ef29dee3462c21d6318a6fb7e7e658961f0d88dd..0000000000000000000000000000000000000000 --- a/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh +++ /dev/null @@ -1,21 +0,0 @@ -# This is a python2 only test. -#!/bin/bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# Test Tensorflow package installation. -/usr/local/bin/pip install --user tf-nightly - -# Test Tensorflow interaction with GCS. -python tensorflow/tools/integration_test/gcs_smoke_test/gcs_smoke.py "$@" diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 3fbdb5cacd1fd0039deaae5ac330b6c2ca006a68..0ede8c63704ac4a474eb0d19e17cf5f365abca77 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -138,7 +138,6 @@ genrule( "@zlib_archive//:zlib.h", ] + if_mkl([ "//third_party/mkl:LICENSE", - "@mkl//:LICENSE", ]), outs = ["include/tensorflow/c/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", @@ -176,7 +175,6 @@ genrule( "@zlib_archive//:zlib.h", ] + if_mkl([ "//third_party/mkl:LICENSE", - "@mkl//:LICENSE", ]), outs = ["include/tensorflow/jni/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index fb6eaa4faa28b4f6b17e1774907c0c9ff58d6ada..95cdf0bf3cdc76d5d10205dc4f97680cdfd8f8fe 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -108,6 +108,7 @@ filegroup( "@highwayhash//:LICENSE", "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", + "@kafka//:LICENSE", "@libxsmm_archive//:LICENSE", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", @@ -125,7 +126,6 @@ filegroup( "@org_python_pypi_backports_weakref//:LICENSE", ] + if_mkl([ "//third_party/mkl:LICENSE", - "@mkl//:LICENSE", ]) + if_not_windows([ "@nccl_archive//:LICENSE.txt", ]) + tf_additional_license_deps(), @@ -156,17 +156,19 @@ sh_binary( "//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", + "//tensorflow/contrib/lite/python:interpreter_test_data", + "//tensorflow/contrib/lite/python:tf_lite_py_pip", "//tensorflow/contrib/lite/toco:toco", "//tensorflow/contrib/lite/toco/python:toco_wrapper", "//tensorflow/contrib/lite/toco/python:toco_from_protos", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", - "//tensorflow/contrib/py2tf:py2tf", - "//tensorflow/contrib/py2tf/converters:converters", - "//tensorflow/contrib/py2tf/converters:test_lib", - "//tensorflow/contrib/py2tf/impl:impl", - "//tensorflow/contrib/py2tf/pyct:pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis", + "//tensorflow/contrib/autograph:autograph", + "//tensorflow/contrib/autograph/converters:converters", + "//tensorflow/contrib/autograph/converters:test_lib", + "//tensorflow/contrib/autograph/impl:impl", + "//tensorflow/contrib/autograph/pyct:pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis", "//tensorflow/contrib/receptive_field:receptive_field_pip", "//tensorflow/contrib/session_bundle:session_bundle_pip", "//tensorflow/contrib/signal:signal_py", @@ -187,6 +189,7 @@ sh_binary( "//tensorflow/python:util_example_parser_configuration", "//tensorflow/python/debug:debug_pip", "//tensorflow/python/eager:eager_pip", + "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files", "//tensorflow/python/saved_model:saved_model", "//tensorflow/python/tools:tools_pip", "//tensorflow/python:test_ops", diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 73d759eb130633094b402c821cc32eb76c076a44..e2518f6cbf0beb0943e5b7289796459d14992bfc 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -58,6 +58,10 @@ BLACKLIST = [ # contrib "//tensorflow/contrib/session_bundle:session_bundle_half_plus_two", "//tensorflow/contrib/keras:testing_utils", + "//tensorflow/contrib/lite/python:interpreter", + "//tensorflow/contrib/lite/python:interpreter_test", + "//tensorflow/contrib/lite/python:interpreter.py", + "//tensorflow/contrib/lite/python:interpreter_test.py", "//tensorflow/contrib/ffmpeg:test_data", "//tensorflow/contrib/factorization/examples:mnist", "//tensorflow/contrib/factorization/examples:mnist.py", @@ -71,6 +75,7 @@ BLACKLIST = [ "//tensorflow/contrib/timeseries/examples:data/period_trend.csv", # pylint:disable=line-too-long "//tensorflow/contrib/timeseries/python/timeseries:test_utils", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", # pylint:disable=line-too-long + "//tensorflow/contrib/image:sparse_image_warp_test_data", ] diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 4b6f123daa7b528173234a2bffd30ead2aa9fc0e..365e8d6b08d654138debd7acad5cf4aac5d07d55 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,7 +29,7 @@ from setuptools.dist import Distribution # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.6.0-rc1' +_VERSION = '1.7.0-rc1' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', @@ -39,7 +39,7 @@ REQUIRED_PACKAGES = [ 'numpy >= 1.13.3', 'six >= 1.10.0', 'protobuf >= 3.4.0', - 'tensorboard >= 1.6.0, < 1.7.0', + 'tensorboard >= 1.7.0, < 1.8.0', 'termcolor >= 1.1.0', ] @@ -62,7 +62,7 @@ else: if 'tf_nightly' in project_name: for i, pkg in enumerate(REQUIRED_PACKAGES): if 'tensorboard' in pkg: - REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.7.0a0, < 1.8.0a0' + REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.8.0a0, < 1.9.0a0' break # weakref.finalize and enum were introduced in Python 3.4 @@ -72,7 +72,7 @@ if sys.version_info < (3, 4): # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ - 'freeze_graph = tensorflow.python.tools.freeze_graph:main', + 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main', 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', @@ -200,8 +200,7 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*.h', 'tensorflow/stream_executor')) + list(find_files('*.h', 'google/protobuf_archive/src')) + list(find_files('*', 'third_party/eigen3')) + - list(find_files('*', 'external/eigen_archive')) + - list(find_files('*.h', 'external/nsync/public'))) + list(find_files('*', 'external/eigen_archive'))) setup( name=project_name, diff --git a/tensorflow/tools/test/performance.bzl b/tensorflow/tools/test/performance.bzl index cee53dd5b61e50126948e3652865a32f45eab092..3486871080c78dc7a1cc201ea2a4d45ebc342758 100644 --- a/tensorflow/tools/test/performance.bzl +++ b/tensorflow/tools/test/performance.bzl @@ -31,7 +31,7 @@ def tf_cc_logged_benchmark( size = "large", srcs = ["//tensorflow/tools/test:run_and_gather_logs"], args = [ - "--name=//%s:%s" % (PACKAGE_NAME, name), + "--name=//%s:%s" % (native.package_name(), name), "--test_name=" + target, "--test_args=--benchmarks=%s" % benchmarks, "--benchmark_type=%s" % benchmark_type, diff --git a/tensorflow/tools/test/upload_test_benchmarks.py b/tensorflow/tools/test/upload_test_benchmarks.py index 77cc9f75f7725438918f681833d58e9ecb4a2f70..9c45359ee1b037ffb01820f874b88b6cabc6d14b 100644 --- a/tensorflow/tools/test/upload_test_benchmarks.py +++ b/tensorflow/tools/test/upload_test_benchmarks.py @@ -87,7 +87,9 @@ import json import os import shutil +from six import text_type from google.cloud import datastore +from six import text_type def is_real_file(dirpath, fname): @@ -150,7 +152,7 @@ def upload_benchmark_data(client, data): """ test_result = json.loads(data) - test_name = unicode(test_result["name"]) + test_name = text_type(test_result["name"]) start_time = datetime.datetime.utcfromtimestamp( float(test_result["startTime"])) batch = [] @@ -162,7 +164,7 @@ def upload_benchmark_data(client, data): t_val.update({ "test": test_name, "start": start_time, - "info": unicode(data) + "info": text_type(data) }) batch.append(t_val) @@ -170,7 +172,7 @@ def upload_benchmark_data(client, data): # the attribute to be fetched and displayed. The full entry information is # also stored as a non-indexed JSON blob. for ent in test_result["entries"].get("entry", []): - ent_name = unicode(ent["name"]) + ent_name = text_type(ent["name"]) e_key = client.key("Entry") e_val = datastore.Entity(e_key, exclude_from_indexes=["info"]) e_val.update({ @@ -178,7 +180,7 @@ def upload_benchmark_data(client, data): "start": start_time, "entry": ent_name, "timing": ent["wallTime"], - "info": unicode(json.dumps(ent)) + "info": text_type(json.dumps(ent)) }) batch.append(e_val) diff --git a/tensorflow/version_check.bzl b/tensorflow/version_check.bzl new file mode 100644 index 0000000000000000000000000000000000000000..79e721dab422c1449214acbe5fc1643edc3a9db0 --- /dev/null +++ b/tensorflow/version_check.bzl @@ -0,0 +1,48 @@ +""" Helpers to check minimum version of bazel.""" + +def _extract_version_number(bazel_version): + """Extracts the semantic version number from a version string + + Args: + bazel_version: the version string that begins with the semantic version + e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash. + + Returns: + The semantic version string, like "1.2.3". + """ + for i in range(len(bazel_version)): + c = bazel_version[i] + if not (c.isdigit() or c == "."): + return bazel_version[:i] + return bazel_version + +# Parse the bazel version string from `native.bazel_version`. +# e.g. +# "0.10.0rc1 abc123d" => (0, 10, 0) +# "0.3.0" => (0, 3, 0) +def _parse_bazel_version(bazel_version): + """Parses a version string into a 3-tuple of ints + + int tuples can be compared directly using binary operators (<, >). + + Args: + bazel_version: the Bazel version string + + Returns: + An int 3-tuple of a (major, minor, patch) version. + """ + + version = _extract_version_number(bazel_version) + return tuple([int(n) for n in version.split(".")]) + +def check_bazel_version_at_least(minimum_bazel_version): + if "bazel_version" not in dir(native): + fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version) + elif not native.bazel_version: + print("\nCurrent Bazel is not a release version, cannot check for compatibility.") + print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version) + return + + if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version): + fail("\nCurrent Bazel version is {}, expected at least {}\n".format( + native.bazel_version, minimum_bazel_version)) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index d5c61baa8b77cbb25a899ba6b34ad9d3b32cc760..6ac98de43a1c648515277c0ff41ace5fdba5647b 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -10,65 +10,23 @@ load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") load("//third_party:repo.bzl", "tf_http_archive") +load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain") load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -def _extract_version_number(bazel_version): - """Extracts the semantic version number from a version string - - Args: - bazel_version: the version string that begins with the semantic version - e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash. - - Returns: - The semantic version string, like "1.2.3". - """ - for i in range(len(bazel_version)): - c = bazel_version[i] - if not (c.isdigit() or c == "."): - return bazel_version[:i] - return bazel_version - -# Parse the bazel version string from `native.bazel_version`. -# e.g. -# "0.10.0rc1 abc123d" => (0, 10, 0) -# "0.3.0" => (0, 3, 0) -def _parse_bazel_version(bazel_version): - """Parses a version string into a 3-tuple of ints - - int tuples can be compared directly using binary operators (<, >). - - Args: - bazel_version: the Bazel version string - - Returns: - An int 3-tuple of a (major, minor, patch) version. - """ - - version = _extract_version_number(bazel_version) - return tuple([int(n) for n in version.split(".")]) - -def check_bazel_version_at_least(minimum_bazel_version): - if "bazel_version" not in dir(native): - fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version) - elif not native.bazel_version: - print("\nCurrent Bazel is not a release version, cannot check for compatibility.") - print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version) - return - - if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version): - fail("\nCurrent Bazel version is {}, expected at least {}\n".format( - native.bazel_version, minimum_bazel_version)) + +# Sanitize a dependency so that it works correctly from code that includes +# TensorFlow as a submodule. +def clean_dep(dep): + return str(Label(dep)) # If TensorFlow is linked as a submodule. # path_prefix is no longer used. # tf_repo_name is thought to be under consideration. def tf_workspace(path_prefix="", tf_repo_name=""): - # We must check the bazel version before trying to parse any other BUILD - # files, in case the parsing of those build files depends on the bazel - # version we require here. - check_bazel_version_at_least("0.5.4") + # Note that we check the minimum bazel version in WORKSPACE. clang6_configure(name="local_config_clang6") + cc_download_clang_toolchain(name="local_config_download_clang") cuda_configure(name="local_config_cuda") tensorrt_configure(name="local_config_tensorrt") git_configure(name="local_config_git") @@ -79,17 +37,37 @@ def tf_workspace(path_prefix="", tf_repo_name=""): arm_compiler_configure( name="local_config_arm_compiler", remote_config_repo="../arm_compiler", - build_file = str(Label("//third_party/toolchains/cpus/arm:BUILD"))) + build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD")) mkl_repository( - name = "mkl", + name = "mkl_linux", urls = [ - "https://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171007.tgz", - "https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171007.tgz", + "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz", + "https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz", ], - sha256 = "6b07cb7e5451db67c2e31e785ae458b18f7f363c60a61685488f69e9ae7199d4", - strip_prefix = "mklml_lnx_2018.0.1.20171007", - build_file = str(Label("//third_party/mkl:mkl.BUILD")), + sha256 = "feacc3d82565c1231470359b42c696236fae873704e0b013436afba5fd4fd30f", + strip_prefix = "mklml_lnx_2018.0.1.20171227", + build_file = clean_dep("//third_party/mkl:mkl.BUILD") + ) + mkl_repository( + name = "mkl_windows", + urls = [ + "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip", + "https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip" + ], + sha256 = "24bae8d7b22b431a654acadea43f2243c46ae6b1e5a73a4a936825f31d284ee4", + strip_prefix = "mklml_win_2018.0.1.20171227", + build_file = clean_dep("//third_party/mkl:mkl.BUILD") + ) + mkl_repository( + name = "mkl_darwin", + urls = [ + "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz", + "https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz" + ], + sha256 = "0e954ec6fd3dc5e37f64c4043f6b5613dd687558da3df1028b3b7c29ff5cf77f", + strip_prefix = "mklml_mac_2018.0.1.20171227", + build_file = clean_dep("//third_party/mkl:mkl.BUILD") ) if path_prefix: @@ -99,12 +77,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "mkl_dnn", urls = [ - "https://mirror.bazel.build/github.com/01org/mkl-dnn/archive/e0bfcaa7fcb2b1e1558f5f0676933c1db807a729.tar.gz", - "https://github.com/01org/mkl-dnn/archive/e0bfcaa7fcb2b1e1558f5f0676933c1db807a729.tar.gz", + "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/v0.12.tar.gz", + "https://github.com/intel/mkl-dnn/archive/v0.12.tar.gz", ], - sha256 = "02e244f63dd95402691a361392504c143eede9a89043426f174836638a9cbf09", - strip_prefix = "mkl-dnn-e0bfcaa7fcb2b1e1558f5f0676933c1db807a729", - build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")), + sha256 = "86fa2a8c12a56e3b725945acedeaa82492746be02545aba6d710f097e013e19e", + strip_prefix = "mkl-dnn-0.12", + build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), ) tf_http_archive( @@ -115,7 +93,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "5996380e3e8b981f55d1c8d58e709c00dbb4806ba367be75d0925a68cc2f6478", strip_prefix = "abseil-cpp-720c017e30339fd1786ce4aac68bc8559736e53f", - build_file = str(Label("//third_party:com_google_absl.BUILD")), + build_file = clean_dep("//third_party:com_google_absl.BUILD"), ) tf_http_archive( @@ -126,8 +104,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "0cadb31a35b514bf2dfd6b5d38205da94ef326ec6908fc3fd7c269948467214f", strip_prefix = "eigen-eigen-2355b229ea4c", - build_file = str(Label("//third_party:eigen.BUILD")), - patch_file = str(Label("//third_party:eigen_fix_cuda_compilation.patch")) + build_file = clean_dep("//third_party:eigen.BUILD"), + patch_file = clean_dep("//third_party:eigen_fix_cuda_compilation.patch") ) tf_http_archive( @@ -140,7 +118,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # remove the whitelist entry in third_party/repo.bzl. # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", ], - build_file = str(Label("//:arm_compiler.BUILD")), + build_file = clean_dep("//:arm_compiler.BUILD"), ) tf_http_archive( @@ -151,7 +129,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", strip_prefix = "libxsmm-1.8.1", - build_file = str(Label("//third_party:libxsmm.BUILD")), + build_file = clean_dep("//third_party:libxsmm.BUILD"), ) tf_http_archive( @@ -164,7 +142,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755", strip_prefix = "or-tools-253f7955c6a1fd805408fba2e42ac6d45b312d15/src", - build_file = str(Label("//third_party:ortools.BUILD")), + build_file = clean_dep("//third_party:ortools.BUILD"), ) tf_http_archive( @@ -196,7 +174,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45", - build_file = str(Label("//third_party:farmhash.BUILD")), + build_file = clean_dep("//third_party:farmhash.BUILD"), ) tf_http_archive( @@ -207,7 +185,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "0f30a15b1566d93f146c8d149878a06e91d9bb7ec2cfd76906df62a82be4aac9", strip_prefix = "highwayhash-dfcb97ca4fe9277bf9dc1802dd979b071896453b", - build_file = str(Label("//third_party:highwayhash.BUILD")), + build_file = clean_dep("//third_party:highwayhash.BUILD"), ) tf_http_archive( @@ -218,7 +196,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324", strip_prefix = "nasm-2.12.02", - build_file = str(Label("//third_party:nasm.BUILD")), + build_file = clean_dep("//third_party:nasm.BUILD"), ) tf_http_archive( @@ -226,11 +204,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""): urls = [ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", - "http://www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", ], sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77", strip_prefix = "libjpeg-turbo-1.5.1", - build_file = str(Label("//third_party/jpeg:jpeg.BUILD")), + build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"), ) tf_http_archive( @@ -241,7 +218,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "716c59c7dfc808a4c368f8ada526932be72b2fcea11dd85dc9d88b1df1dfe9c2", strip_prefix = "libpng-1.2.53", - build_file = str(Label("//third_party:png.BUILD")), + build_file = clean_dep("//third_party:png.BUILD"), ) tf_http_archive( @@ -252,7 +229,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4", strip_prefix = "sqlite-amalgamation-3200000", - build_file = str(Label("//third_party:sqlite.BUILD")), + build_file = clean_dep("//third_party:sqlite.BUILD"), ) tf_http_archive( @@ -263,7 +240,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", strip_prefix = "giflib-5.1.4", - build_file = str(Label("//third_party:gif.BUILD")), + build_file = clean_dep("//third_party:gif.BUILD"), ) tf_http_archive( @@ -274,7 +251,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", strip_prefix = "six-1.10.0", - build_file = str(Label("//third_party:six.BUILD")), + build_file = clean_dep("//third_party:six.BUILD"), ) tf_http_archive( @@ -285,7 +262,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d", strip_prefix = "astor-0.6.2", - build_file = str(Label("//third_party:astor.BUILD")), + build_file = clean_dep("//third_party:astor.BUILD"), ) tf_http_archive( @@ -296,7 +273,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930", strip_prefix = "gast-0.2.0", - build_file = str(Label("//third_party:gast.BUILD")), + build_file = clean_dep("//third_party:gast.BUILD"), ) tf_http_archive( @@ -307,7 +284,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b", strip_prefix = "termcolor-1.1.0", - build_file = str(Label("//third_party:termcolor.BUILD")), + build_file = clean_dep("//third_party:termcolor.BUILD"), ) tf_http_archive( @@ -328,7 +305,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", strip_prefix = "backports.weakref-1.0rc1/src", - build_file = str(Label("//third_party:backports_weakref.BUILD")), + build_file = clean_dep("//third_party:backports_weakref.BUILD"), ) tf_http_archive( @@ -339,7 +316,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", strip_prefix = "codegen-1.0", - build_file = str(Label("//third_party:codegen.BUILD")), + build_file = clean_dep("//third_party:codegen.BUILD"), ) filegroup_external( @@ -389,11 +366,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "nsync", urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz", - "https://github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", + "https://github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", ], - sha256 = "51f81ff4202bbb820cdbedc061bd2eb6765f2b5c06489e7a8694bedac329e8f8", - strip_prefix = "nsync-8502189abfa44c249c01c2cad64e6ed660a9a668", + sha256 = "6284454c5cd8b1dae2eeb8cf5eb63004de930b5427ed5f6b1aa793513df6b361", + strip_prefix = "nsync-0559ce013feac8db639ee1bf776aca0325d28777", ) tf_http_archive( @@ -424,7 +401,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", ], strip_prefix = "pcre-8.39", - build_file = str(Label("//third_party:pcre.BUILD")), + build_file = clean_dep("//third_party:pcre.BUILD"), ) tf_http_archive( @@ -436,7 +413,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", ], strip_prefix = "swig-3.0.8", - build_file = str(Label("//third_party:swig.BUILD")), + build_file = clean_dep("//third_party:swig.BUILD"), ) tf_http_archive( @@ -447,17 +424,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://curl.haxx.se/download/curl-7.49.1.tar.gz", ], strip_prefix = "curl-7.49.1", - build_file = str(Label("//third_party:curl.BUILD")), + build_file = clean_dep("//third_party:curl.BUILD"), ) tf_http_archive( name = "grpc", urls = [ - "https://mirror.bazel.build/github.com/grpc/grpc/archive/730b778632e79cc3c96ad237f282d687ee325ce7.tar.gz", - "https://github.com/grpc/grpc/archive/730b778632e79cc3c96ad237f282d687ee325ce7.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/575bda39755b98d1f7099406bb57a6e3b2074874.tar.gz", + "https://github.com/grpc/grpc/archive/575bda39755b98d1f7099406bb57a6e3b2074874.tar.gz", ], - sha256 = "8c91a8d12e1e868cf51f7340b75507a8aa017a7e1b56f46ed6816aeb803dc9bd", - strip_prefix = "grpc-730b778632e79cc3c96ad237f282d687ee325ce7", + sha256 = "f08a5c8e265191b39cc74915b1bc1fd380d86cd0176c92b7cce30b6ac50514ad", + strip_prefix = "grpc-575bda39755b98d1f7099406bb57a6e3b2074874", ) tf_http_archive( @@ -468,7 +445,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", ], strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", - build_file = str(Label("//third_party:linenoise.BUILD")), + build_file = clean_dep("//third_party:linenoise.BUILD"), ) # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror. @@ -476,12 +453,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/fc8ba497cd1a1af4ecae19a5b64bdbd71e065e14.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/fc8ba497cd1a1af4ecae19a5b64bdbd71e065e14.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/1c3cdea2f181d8e14ee184466c5fb237f1b4cda8.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/1c3cdea2f181d8e14ee184466c5fb237f1b4cda8.tar.gz", ], - sha256 = "f5721d9cc18a9109c9e9f847f48e69b710b961cee83e6691227e310cb3b5da58", - strip_prefix = "llvm-fc8ba497cd1a1af4ecae19a5b64bdbd71e065e14", - build_file = str(Label("//third_party/llvm:llvm.BUILD")), + sha256 = "1efbb9b05af88368be984d2f6526061d4a857181ef10f8841889a3a46869bb01", + strip_prefix = "llvm-1c3cdea2f181d8e14ee184466c5fb237f1b4cda8", + build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) tf_http_archive( @@ -492,7 +469,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb", - build_file = str(Label("//third_party:lmdb.BUILD")), + build_file = clean_dep("//third_party:lmdb.BUILD"), ) tf_http_archive( @@ -503,7 +480,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70", - build_file = str(Label("//third_party:jsoncpp.BUILD")), + build_file = clean_dep("//third_party:jsoncpp.BUILD"), ) tf_http_archive( @@ -524,7 +501,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "36658cb768a54c1d4dec43c3116c27ed893e88b02ecfcb44f2166f9c0b7f2a0d", strip_prefix = "zlib-1.2.8", - build_file = str(Label("//third_party:zlib.BUILD")), + build_file = clean_dep("//third_party:zlib.BUILD"), ) tf_http_archive( @@ -534,7 +511,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", ], sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", - build_file = str(Label("//third_party/fft2d:fft2d.BUILD")), + build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"), ) tf_http_archive( @@ -545,7 +522,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", strip_prefix = "snappy-1.1.4", - build_file = str(Label("//third_party:snappy.BUILD")), + build_file = clean_dep("//third_party:snappy.BUILD"), ) tf_http_archive( @@ -556,7 +533,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", - build_file = str(Label("//third_party:nccl.BUILD")), + build_file = clean_dep("//third_party:nccl.BUILD"), ) tf_http_archive( @@ -567,8 +544,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e", strip_prefix = "librdkafka-0.11.1", - build_file = str(Label("//third_party:kafka/BUILD")), - patch_file = str(Label("//third_party/kafka:config.patch")), + build_file = clean_dep("//third_party:kafka/BUILD"), + patch_file = clean_dep("//third_party/kafka:config.patch"), ) tf_http_archive( @@ -579,7 +556,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c", strip_prefix = "aws-sdk-cpp-1.3.15", - build_file = str(Label("//third_party:aws.BUILD")), + build_file = clean_dep("//third_party:aws.BUILD"), ) java_import_external( @@ -615,7 +592,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", strip_prefix = "jemalloc-4.4.0", - build_file = str(Label("//third_party:jemalloc.BUILD")), + build_file = clean_dep("//third_party:jemalloc.BUILD"), ) java_import_external( @@ -624,7 +601,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): jar_urls = [ "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", "http://repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", - "http://maven.ibiblio.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", ], licenses = ["notice"], # New BSD License testonly_ = True, @@ -644,11 +620,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ) java_import_external( - name = "javax_validation", - jar_sha256 = "e459f313ebc6db2483f8ceaad39af07086361b474fa92e40f442e8de5d9895dc", + name = "org_checkerframework_qual", + jar_sha256 = "a17501717ef7c8dda4dba73ded50c0d7cde440fd721acfeacbf19786ceac1ed6", jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/javax/validation/validation-api/1.0.0.GA/validation-api-1.0.0.GA.jar", - "http://repo1.maven.org/maven2/javax/validation/validation-api/1.0.0.GA/validation-api-1.0.0.GA.jar", + "http://mirror.bazel.build/repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar", + "http://repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar", ], licenses = ["notice"], # Apache 2.0 ) @@ -661,21 +637,18 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", - build_file = str(Label("//third_party:pprof.BUILD")), + build_file = clean_dep("//third_party:pprof.BUILD"), ) tf_http_archive( name = "cub_archive", urls = [ - "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", - "https://github.com/NVlabs/cub/archive/1.7.4.zip", + "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip", + "https://github.com/NVlabs/cub/archive/1.8.0.zip", ], - sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31", - strip_prefix = "cub-1.7.4", - build_file = str(Label("//third_party:cub.BUILD")), - # TODO: remove the patch when upstream fix is accepted and released. - # PR with a fix: https://github.com/NVlabs/cub/pull/125 - patch_file = str(Label("//third_party/cub:fix_compilation_in_clang.patch")), + sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3", + strip_prefix = "cub-1.8.0", + build_file = clean_dep("//third_party:cub.BUILD"), ) tf_http_archive( @@ -686,7 +659,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", ], strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17", - build_file = str(Label("//third_party:cython.BUILD")), + build_file = clean_dep("//third_party:cython.BUILD"), delete = ["BUILD.bazel"], ) @@ -700,16 +673,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556", ) - tf_http_archive( - name = "rbe_integration_test", - urls = [ - "http://mirror.bazel.build/github.com/google/rbe-integration-test/archive/78a6194c7dda200b9522cf07707e3bc695804d1e.tar.gz", - "https://github.com/google/rbe-integration-test/archive/78a6194c7dda200b9522cf07707e3bc695804d1e.tar.gz", - ], - sha256 = "66d93b3919a165d486c31f5290d312abe9fda2685242f812c110653c124e1db4", - strip_prefix = "rbe-integration-test-78a6194c7dda200b9522cf07707e3bc695804d1e", - ) - tf_http_archive( name = "arm_neon_2_x86_sse", sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5", @@ -718,7 +681,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz", "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz", ], - build_file = str(Label("//third_party:arm_neon_2_x86_sse.BUILD")), + build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"), ) tf_http_archive( @@ -729,7 +692,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://mirror.bazel.build/github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", ], - build_file = str(Label("//third_party/flatbuffers:flatbuffers.BUILD")), + build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"), ) tf_http_archive( @@ -739,7 +702,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", ], - build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"), ) tf_http_archive( @@ -749,7 +712,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip", "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip" ], - build_file = str(Label("//third_party:tflite_smartreply.BUILD")), + build_file = clean_dep("//third_party:tflite_smartreply.BUILD"), ) ############################################################################## @@ -813,7 +776,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # Needed by Protobuf native.bind( name = "python_headers", - actual = str(Label("//util/python:python_headers")), + actual = clean_dep("//util/python:python_headers"), ) # Needed by Protobuf diff --git a/third_party/cub/BUILD b/third_party/clang_toolchain/BUILD similarity index 100% rename from third_party/cub/BUILD rename to third_party/clang_toolchain/BUILD diff --git a/third_party/clang_toolchain/cc_configure_clang.bzl b/third_party/clang_toolchain/cc_configure_clang.bzl new file mode 100644 index 0000000000000000000000000000000000000000..1181110ea9674e56264509fe5bb043a587888200 --- /dev/null +++ b/third_party/clang_toolchain/cc_configure_clang.bzl @@ -0,0 +1,27 @@ +""" Downloads clang and configures the crosstool using bazel's autoconf.""" + +load("@bazel_tools//tools/cpp:cc_configure.bzl", "cc_autoconf_impl") +load(":download_clang.bzl", "download_clang") + +_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" +_TF_NEED_CUDA = "TF_NEED_CUDA" + +def _cc_clang_autoconf(repo_ctx): + if repo_ctx.os.environ.get(_TF_DOWNLOAD_CLANG) != "1": + return + if repo_ctx.os.environ.get(_TF_NEED_CUDA) == "1": + # Clang is handled separately for CUDA configs. + # See cuda_configure.bzl for more details. + return + + download_clang(repo_ctx, out_folder='extra_tools') + overriden_tools = {'gcc': 'extra_tools/bin/clang'} + cc_autoconf_impl(repo_ctx, overriden_tools) + +cc_download_clang_toolchain = repository_rule( + environ = [ + _TF_DOWNLOAD_CLANG, + _TF_NEED_CUDA, + ], + implementation = _cc_clang_autoconf, +) diff --git a/third_party/gpus/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl similarity index 100% rename from third_party/gpus/download_clang.bzl rename to third_party/clang_toolchain/download_clang.bzl diff --git a/third_party/cub/fix_compilation_in_clang.patch b/third_party/cub/fix_compilation_in_clang.patch deleted file mode 100644 index 384e674f2012b2b3ea59c5c1bd205873baa8cf18..0000000000000000000000000000000000000000 --- a/third_party/cub/fix_compilation_in_clang.patch +++ /dev/null @@ -1,23 +0,0 @@ -From 565b77f7c82048871a4d5e3e506dc663d53cd469 Mon Sep 17 00:00:00 2001 -From: Ilya Biryukov -Date: Fri, 26 Jan 2018 18:46:06 +0100 -Subject: [PATCH] Added missing 'template' keyword. - -To unbreak compilation with clang. ---- - cub/device/dispatch/dispatch_radix_sort.cuh | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh -index 7fbc621f..f622e212 100644 ---- a/cub/device/dispatch/dispatch_radix_sort.cuh -+++ b/cub/device/dispatch/dispatch_radix_sort.cuh -@@ -104,7 +104,7 @@ __global__ void DeviceRadixSortUpsweepKernel( - CTA_SYNC(); - - // Write out digit counts (striped) -- upsweep.ExtractCounts(d_spine, gridDim.x, blockIdx.x); -+ upsweep.template ExtractCounts(d_spine, gridDim.x, blockIdx.x); - } - - diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index 8a1c7db2ea14365be53a796a79fce77900e668e1..f8fb6ecb0ccc7f81040370a80c31d03daa659051 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -51,6 +51,9 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data +layers = tf.keras.layers + + def _bundle(lstm_iter): """Concatenate a list of Tensors along 1st axis and split result into two. @@ -78,17 +81,16 @@ def _unbundle(state): return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0) -class Reducer(tfe.Network): +# pylint: disable=not-callable +class Reducer(tf.keras.Model): """A module that applies reduce operation on left and right vectors.""" def __init__(self, size, tracker_size=None): super(Reducer, self).__init__() - self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None)) - self.right = self.track_layer( - tf.layers.Dense(5 * size, activation=None, use_bias=False)) + self.left = layers.Dense(5 * size, activation=None) + self.right = layers.Dense(5 * size, activation=None, use_bias=False) if tracker_size is not None: - self.track = self.track_layer( - tf.layers.Dense(5 * size, activation=None, use_bias=False)) + self.track = layers.Dense(5 * size, activation=None, use_bias=False) else: self.track = None @@ -123,7 +125,7 @@ class Reducer(tfe.Network): return h, c -class Tracker(tfe.Network): +class Tracker(tf.keras.Model): """A module that tracks the history of the sentence with an LSTM.""" def __init__(self, tracker_size, predict): @@ -134,10 +136,10 @@ class Tracker(tfe.Network): predict: (`bool`) Whether prediction mode is enabled. """ super(Tracker, self).__init__() - self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size)) + self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size) self._state_size = tracker_size if predict: - self._transition = self.track_layer(tf.layers.Dense(4)) + self._transition = layers.Dense(4) else: self._transition = None @@ -182,7 +184,7 @@ class Tracker(tfe.Network): return unbundled, None -class SPINN(tfe.Network): +class SPINN(tf.keras.Model): """Stack-augmented Parser-Interpreter Neural Network. See https://arxiv.org/abs/1603.06021 for more details. @@ -204,9 +206,9 @@ class SPINN(tfe.Network): """ super(SPINN, self).__init__() self.config = config - self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker)) + self.reducer = Reducer(config.d_hidden, config.d_tracker) if config.d_tracker is not None: - self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict)) + self.tracker = Tracker(config.d_tracker, config.predict) else: self.tracker = None @@ -248,7 +250,7 @@ class SPINN(tfe.Network): trans = transitions[i] if self.tracker: # Invoke tracker to obtain the current tracker states for the sentences. - tracker_states, trans_hypothesis = self.tracker(buffers, stacks) + tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks) if trans_hypothesis: trans = tf.argmax(trans_hypothesis, axis=-1) else: @@ -264,7 +266,8 @@ class SPINN(tfe.Network): trackings.append(tracking) if rights: - reducer_output = self.reducer(lefts, rights, trackings) + reducer_output = self.reducer( + lefts, right_in=rights, tracking=trackings) reduced = iter(reducer_output) for transition, stack in zip(trans, stacks): @@ -273,7 +276,27 @@ class SPINN(tfe.Network): return _bundle([stack.pop() for stack in stacks])[0] -class SNLIClassifier(tfe.Network): +class Perceptron(tf.keras.Model): + """One layer of the SNLIClassifier multi-layer perceptron.""" + + def __init__(self, dimension, dropout_rate, previous_layer): + """Configure the Perceptron.""" + super(Perceptron, self).__init__() + self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu) + self.batchnorm = layers.BatchNormalization() + self.dropout = layers.Dropout(rate=dropout_rate) + self.previous_layer = previous_layer + + def call(self, x, training): + """Run previous Perceptron layers, then this one.""" + x = self.previous_layer(x, training=training) + x = self.dense(x) + x = self.batchnorm(x, training=training) + x = self.dropout(x, training=training) + return x + + +class SNLIClassifier(tf.keras.Model): """SNLI Classifier Model. A model aimed at solving the SNLI (Standford Natural Language Inference) @@ -304,29 +327,24 @@ class SNLIClassifier(tfe.Network): self.config = config self.embed = tf.constant(embed) - self.projection = self.track_layer(tf.layers.Dense(config.d_proj)) - self.embed_bn = self.track_layer(tf.layers.BatchNormalization()) - self.embed_dropout = self.track_layer( - tf.layers.Dropout(rate=config.embed_dropout)) - self.encoder = self.track_layer(SPINN(config)) - - self.feature_bn = self.track_layer(tf.layers.BatchNormalization()) - self.feature_dropout = self.track_layer( - tf.layers.Dropout(rate=config.mlp_dropout)) - - self.mlp_dense = [] - self.mlp_bn = [] - self.mlp_dropout = [] - for _ in xrange(config.n_mlp_layers): - self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp))) - self.mlp_bn.append( - self.track_layer(tf.layers.BatchNormalization())) - self.mlp_dropout.append( - self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout))) - self.mlp_output = self.track_layer(tf.layers.Dense( + self.projection = layers.Dense(config.d_proj) + self.embed_bn = layers.BatchNormalization() + self.embed_dropout = layers.Dropout(rate=config.embed_dropout) + self.encoder = SPINN(config) + + self.feature_bn = layers.BatchNormalization() + self.feature_dropout = layers.Dropout(rate=config.mlp_dropout) + + current_mlp = lambda result, training: result + for _ in range(config.n_mlp_layers): + current_mlp = Perceptron(dimension=config.d_mlp, + dropout_rate=config.mlp_dropout, + previous_layer=current_mlp) + self.mlp = current_mlp + self.mlp_output = layers.Dense( config.d_out, kernel_initializer=tf.random_uniform_initializer(minval=-5e-3, - maxval=5e-3))) + maxval=5e-3)) def call(self, premise, @@ -370,10 +388,10 @@ class SNLIClassifier(tfe.Network): # Run the batch-normalized and dropout-processed word vectors through the # SPINN encoder. - premise = self.encoder(premise_embed, premise_transition, - training=training) - hypothesis = self.encoder(hypothesis_embed, hypothesis_transition, - training=training) + premise = self.encoder( + premise_embed, transitions=premise_transition, training=training) + hypothesis = self.encoder( + hypothesis_embed, transitions=hypothesis_transition, training=training) # Combine encoder outputs for premises and hypotheses into logits. # Then apply batch normalization and dropuout on the logits. @@ -383,15 +401,12 @@ class SNLIClassifier(tfe.Network): self.feature_bn(logits, training=training), training=training) # Apply the multi-layer perceptron on the logits. - for dense, bn, dropout in zip( - self.mlp_dense, self.mlp_bn, self.mlp_dropout): - logits = tf.nn.elu(dense(logits)) - logits = dropout(bn(logits, training=training), training=training) + logits = self.mlp(logits, training=training) logits = self.mlp_output(logits) return logits -class SNLIClassifierTrainer(object): +class SNLIClassifierTrainer(tfe.Checkpointable): """A class that coordinates the training of an SNLIClassifier.""" def __init__(self, snli_classifier, lr): @@ -450,10 +465,11 @@ class SNLIClassifierTrainer(object): """ with tfe.GradientTape() as tape: tape.watch(self._model.variables) + # TODO(allenl): Allow passing Layer inputs as position arguments. logits = self._model(premise, - premise_transition, - hypothesis, - hypothesis_transition, + premise_transition=premise_transition, + hypothesis=hypothesis, + hypothesis_transition=hypothesis_transition, training=True) loss = self.loss(labels, logits) gradients = tape.gradient(loss, self._model.variables) @@ -517,7 +533,9 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu): snli_data, batch_size): if use_gpu: label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() - logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False) + logits = trainer.model( + prem, premise_transition=prem_trans, hypothesis=hypo, + hypothesis_transition=hypo_trans, training=False) loss_val = trainer.loss(label, logits) batch_size = tf.shape(label)[0] mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) @@ -609,29 +627,30 @@ def train_or_infer_spinn(embed, with tf.device(device), \ summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)): - model = SNLIClassifier(config, embed) - global_step = tf.train.get_or_create_global_step() - trainer = SNLIClassifierTrainer(model, config.lr) + model = SNLIClassifier(config, embed) + global_step = tf.train.get_or_create_global_step() + trainer = SNLIClassifierTrainer(model, config.lr) + checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step) + checkpoint.restore(tf.train.latest_checkpoint(config.logdir)) if inference_sentence_pair: # Inference mode. - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)): - prem, prem_trans = inference_sentence_pair[0] - hypo, hypo_trans = inference_sentence_pair[1] - hypo_trans = inference_sentence_pair[1][1] - inference_logits = model( # pylint: disable=not-callable - tf.constant(prem), tf.constant(prem_trans), - tf.constant(hypo), tf.constant(hypo_trans), training=False) - inference_logits = inference_logits[0][1:] - max_index = tf.argmax(inference_logits) - print("\nInference logits:") - for i, (label, logit) in enumerate( - zip(data.POSSIBLE_LABELS, inference_logits)): - winner_tag = " (winner)" if max_index == i else "" - print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag)) + prem, prem_trans = inference_sentence_pair[0] + hypo, hypo_trans = inference_sentence_pair[1] + hypo_trans = inference_sentence_pair[1][1] + inference_logits = model( + tf.constant(prem), + premise_transition=tf.constant(prem_trans), + hypothesis=tf.constant(hypo), + hypothesis_transition=tf.constant(hypo_trans), + training=False) + inference_logits = inference_logits[0][1:] + max_index = tf.argmax(inference_logits) + print("\nInference logits:") + for i, (label, logit) in enumerate( + zip(data.POSSIBLE_LABELS, inference_logits)): + winner_tag = " (winner)" if max_index == i else "" + print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag)) return inference_logits train_len = train_data.num_batches(config.batch_size) @@ -650,20 +669,15 @@ def train_or_infer_spinn(embed, # remain on CPU. Same in _evaluate_on_dataset(). iterations += 1 - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)): - batch_train_loss, batch_train_logits = trainer.train_batch( - label, prem, prem_trans, hypo, hypo_trans) + batch_train_loss, batch_train_logits = trainer.train_batch( + label, prem, prem_trans, hypo, hypo_trans) batch_size = tf.shape(label)[0] mean_loss(batch_train_loss.numpy(), weights=batch_size.gpu() if use_gpu else batch_size) accuracy(tf.argmax(batch_train_logits, axis=1), label) if iterations % config.save_every == 0: - all_variables = trainer.variables + [global_step] - saver = tfe.Saver(all_variables) - saver.save(os.path.join(config.logdir, "ckpt"), - global_step=global_step) + checkpoint.save(os.path.join(config.logdir, "ckpt")) if iterations % config.dev_every == 0: dev_loss, dev_frac_correct = _evaluate_on_dataset( diff --git a/third_party/gpus/crosstool/CROSSTOOL_clang.tpl b/third_party/gpus/crosstool/CROSSTOOL_clang.tpl index e4363d604577de09241d635b6990c9dd6429efe0..2f09473ee2ddf9a38ca0c7aa11094690607b532f 100644 --- a/third_party/gpus/crosstool/CROSSTOOL_clang.tpl +++ b/third_party/gpus/crosstool/CROSSTOOL_clang.tpl @@ -49,6 +49,7 @@ toolchain { flag_set { action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag: "-lstdc++" } @@ -75,6 +76,7 @@ toolchain { name: "alwayslink" flag_set { action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" action: "c++-link-executable" flag_group { flag: "-Wl,-no-as-needed" @@ -116,6 +118,7 @@ toolchain { } flag_set { action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag: "-Wl,-z,relro,-z,now" } @@ -161,6 +164,7 @@ toolchain { flag_set { action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { # Stamp the binary with a unique identifier. flag: "-Wl,--build-id=md5" @@ -176,6 +180,7 @@ toolchain { action: "c++-compile" action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag:"-no-canonical-prefixes" } @@ -199,6 +204,7 @@ toolchain { flag_set { action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag: "-B/usr/bin/" } @@ -246,6 +252,7 @@ toolchain { } flag_set { action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" action: "c++-link-executable" flag_group { flag: "-Wl,--gc-sections" diff --git a/third_party/gpus/cuda/remote.BUILD.tpl b/third_party/gpus/cuda/remote.BUILD.tpl index d88d512b90c352e6a301ed6efe8266d8dd6bf744..f774def5e6cec25e4920ecce0076340a31c70386 100644 --- a/third_party/gpus/cuda/remote.BUILD.tpl +++ b/third_party/gpus/cuda/remote.BUILD.tpl @@ -41,65 +41,65 @@ config_setting( alias( name = "cuda_headers", - actual = "%{remote_cuda_repo}cuda:cuda_headers", + actual = "%{remote_cuda_repo}/cuda:cuda_headers", ) alias( name = "cudart_static", - actual = "%{remote_cuda_repo}cuda:cudart_static", + actual = "%{remote_cuda_repo}/cuda:cudart_static", ) alias( name = "cuda_driver", - actual = "%{remote_cuda_repo}cuda:cuda_driver", + actual = "%{remote_cuda_repo}/cuda:cuda_driver", ) alias( name = "cudart", - actual = "%{remote_cuda_repo}cuda:cudart", + actual = "%{remote_cuda_repo}/cuda:cudart", ) alias( name = "cublas", - actual = "%{remote_cuda_repo}cuda:cublas", + actual = "%{remote_cuda_repo}/cuda:cublas", ) alias( name = "cusolver", - actual = "%{remote_cuda_repo}cuda:cusolver", + actual = "%{remote_cuda_repo}/cuda:cusolver", ) alias( name = "cudnn", - actual = "%{remote_cuda_repo}cuda:cudnn", + actual = "%{remote_cuda_repo}/cuda:cudnn", ) alias( name = "cufft", - actual = "%{remote_cuda_repo}cuda:cufft", + actual = "%{remote_cuda_repo}/cuda:cufft", ) alias( name = "curand", - actual = "%{remote_cuda_repo}cuda:curand", + actual = "%{remote_cuda_repo}/cuda:curand", ) alias( name = "cuda", - actual = "%{remote_cuda_repo}cuda:cuda", + actual = "%{remote_cuda_repo}/cuda:cuda", ) alias( name = "cupti_headers", - actual = "%{remote_cuda_repo}cuda:cupti_headers", + actual = "%{remote_cuda_repo}/cuda:cupti_headers", ) alias( name = "cupti_dsos", - actual = "%{remote_cuda_repo}cuda:cupti_dsos", + actual = "%{remote_cuda_repo}/cuda:cupti_dsos", ) alias( name = "libdevice_root", - actual = "%{remote_cuda_repo}cuda:libdevice_root", + actual = "%{remote_cuda_repo}/cuda:libdevice_root", ) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index b7c47a19ddcfc69dbee54bf6ca4080489b292c01..ede7e318976527eb4fe6489083dc45896733f7bf 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -38,7 +38,65 @@ _DEFAULT_CUDA_TOOLKIT_PATH = "/usr/local/cuda" _DEFAULT_CUDNN_INSTALL_PATH = "/usr/local/cuda" _DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"] -load(":download_clang.bzl", "download_clang") +# Lookup paths for CUDA / cuDNN libraries, relative to the install directories. +# +# Paths will be tried out in the order listed below. The first successful path +# will be used. For example, when looking for the cudart libraries, the first +# attempt will be lib64/cudart inside the CUDA toolkit. +CUDA_LIB_PATHS = [ + "lib64/", + "lib64/stubs/", + "lib/x86_64-linux-gnu/", + "lib/x64/", + "lib/", + "", +] + +# Lookup paths for cupti.h, relative to the CUDA toolkit directory. +# +# On most systems, the cupti library is not installed in the same directory as +# the other CUDA libraries but rather in a special extras/CUPTI directory. +CUPTI_HEADER_PATHS = [ + "extras/CUPTI/include/", + "include/cuda/CUPTI/", +] + +# Lookup paths for the cupti library, relative to the +# +# On most systems, the cupti library is not installed in the same directory as +# the other CUDA libraries but rather in a special extras/CUPTI directory. +CUPTI_LIB_PATHS = [ + "extras/CUPTI/lib64/", + "lib/x86_64-linux-gnu", + "lib64/", + "extras/CUPTI/libx64/", + "extras/CUPTI/lib/", + "lib/", +] + +# Lookup paths for CUDA headers (cuda.h) relative to the CUDA toolkit directory. +CUDA_INCLUDE_PATHS = [ + "include/", + "include/cuda/" +] + +# Lookup paths for cudnn.h relative to the CUDNN install directory. +CUDNN_INCLUDE_PATHS = [ + "", + "include/", + "include/cuda/", +] + +# Lookup paths for NVVM libdevice relative to the CUDA directory toolkit. +# +# libdevice implements mathematical functions for GPU kernels, and is provided +# in NVVM bitcode (a subset of LLVM bitcode). +NVVM_LIBDEVICE_PATHS = [ + "nvvm/libdevice/", + "share/cuda/", +] + +load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") # TODO(dzc): Once these functions have been factored out of Bazel's # cc_configure.bzl, load them from @bazel_tools instead. @@ -522,31 +580,31 @@ def _find_cuda_lib(lib, repository_ctx, cpu_value, basedir, version="", path: The full path to the library. """ file_name = _lib_name(lib, cpu_value, version, static) - if cpu_value == "Linux": - path = repository_ctx.path("%s/lib64/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path( - "%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name)) + for relative_path in CUDA_LIB_PATHS: + path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name)) if path.exists: return struct(file_name=file_name, path=str(path.realpath)) + auto_configure_fail("Cannot find cuda library %s" % file_name) - elif cpu_value == "Windows": - path = repository_ctx.path("%s/lib/x64/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path("%s/lib/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path("%s/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) +def _find_cupti_header_dir(repository_ctx, cuda_config): + """Returns the path to the directory containing cupti.h - auto_configure_fail("Cannot find cuda library %s" % file_name) + On most systems, the cupti library is not installed in the same directory as + the other CUDA libraries but rather in a special extras/CUPTI directory. + + Args: + repository_ctx: The repository context. + cuda_config: The CUDA config as returned by _get_cuda_config + + Returns: + The path of the directory containing the cupti header. + """ + cuda_toolkit_path = cuda_config.cuda_toolkit_path + for relative_path in CUPTI_HEADER_PATHS: + if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists: + return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1] + auto_configure_fail("Cannot find cupti.h under %s" % cuda_toolkit_path) def _find_cupti_lib(repository_ctx, cuda_config): @@ -566,35 +624,13 @@ def _find_cupti_lib(repository_ctx, cuda_config): """ file_name = _lib_name("cupti", cuda_config.cpu_value, cuda_config.cuda_version) - if cuda_config.cpu_value == "Linux": - path = repository_ctx.path( - "%s/extras/CUPTI/lib64/%s" % (cuda_config.cuda_toolkit_path, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - - path = repository_ctx.path( - "%s/lib/x86_64-linux-gnu/%s" % (cuda_config.cuda_toolkit_path, - file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - - elif cuda_config.cpu_value == "Windows": + cuda_toolkit_path = cuda_config.cuda_toolkit_path + for relative_path in CUPTI_LIB_PATHS: path = repository_ctx.path( - "%s/extras/CUPTI/libx64/%s" % - (cuda_config.cuda_toolkit_path, file_name)) + "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name)) if path.exists: return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path( - "%s/extras/CUPTI/lib/%s" % (cuda_config.cuda_toolkit_path, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - - path = repository_ctx.path( - "%s/lib/%s" % (cuda_config.cuda_toolkit_path, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - auto_configure_fail("Cannot find cupti library %s" % file_name) def _find_libs(repository_ctx, cuda_config): @@ -635,6 +671,23 @@ def _find_libs(repository_ctx, cuda_config): } +def _find_cuda_include_path(repository_ctx, cuda_config): + """Returns the path to the directory containing cuda.h + + Args: + repository_ctx: The repository context. + cuda_config: The CUDA config as returned by _get_cuda_config + + Returns: + The path of the directory containing the CUDA headers. + """ + cuda_toolkit_path = cuda_config.cuda_toolkit_path + for relative_path in CUDA_INCLUDE_PATHS: + if repository_ctx.path("%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists: + return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1] + auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path) + + def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir): """Returns the path to the directory containing cudnn.h @@ -646,15 +699,31 @@ def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir): Returns: The path of the directory containing the cudnn header. """ - if repository_ctx.path(cudnn_install_basedir + "/cudnn.h").exists: - return cudnn_install_basedir - if repository_ctx.path(cudnn_install_basedir + "/include/cudnn.h").exists: - return cudnn_install_basedir + "/include" + for relative_path in CUDA_INCLUDE_PATHS: + if repository_ctx.path("%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists: + return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1] if repository_ctx.path("/usr/include/cudnn.h").exists: return "/usr/include" auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir) +def _find_nvvm_libdevice_dir(repository_ctx, cuda_config): + """Returns the path to the directory containing libdevice in bitcode format. + + Args: + repository_ctx: The repository context. + cuda_config: The CUDA config as returned by _get_cuda_config + + Returns: + The path of the directory containing the CUDA headers. + """ + cuda_toolkit_path = cuda_config.cuda_toolkit_path + for relative_path in NVVM_LIBDEVICE_PATHS: + if repository_ctx.path("%s/%slibdevice.10.bc" % (cuda_toolkit_path, relative_path)).exists: + return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1] + auto_configure_fail("Cannot find libdevice.10.bc under %s" % cuda_toolkit_path) + + def _cudart_static_linkopt(cpu_value): """Returns additional platform-specific linkopts for cudart.""" return "" if cpu_value == "Darwin" else "\"-lrt\"," @@ -925,21 +994,22 @@ def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" cuda_config = _get_cuda_config(repository_ctx) + cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config) cudnn_header_dir = _find_cudnn_header_dir(repository_ctx, cuda_config.cudnn_install_basedir) + cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config) + nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config) # Set up symbolic links for the cuda toolkit by creating genrules to do # symlinking. We create one genrule for each directory we want to track under # cuda_toolkit_path cuda_toolkit_path = cuda_config.cuda_toolkit_path - cuda_include_path = cuda_toolkit_path + "/include" genrules = [symlink_genrule_for_dir(repository_ctx, cuda_include_path, "cuda/include", "cuda-include")] genrules.append(symlink_genrule_for_dir(repository_ctx, - cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm")) + nvvm_libdevice_dir, "cuda/nvvm/libdevice", "cuda-nvvm")) genrules.append(symlink_genrule_for_dir(repository_ctx, - cuda_toolkit_path + "/extras/CUPTI/include", - "cuda/extras/CUPTI/include", "cuda-extras")) + cupti_header_dir, "cuda/extras/CUPTI/include", "cuda-extras")) cuda_libs = _find_libs(repository_ctx, cuda_config) cuda_lib_src = [] @@ -1086,6 +1156,7 @@ cuda_configure = repository_rule( _TF_CUDNN_VERSION, _TF_CUDA_COMPUTE_CAPABILITIES, _TF_CUDA_CONFIG_REPO, + "NVVMIR_LIBRARY_DIR", ], ) diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index 87a23925c4316c3ee107af77272300e34b1bb257..4418ac32fc4b08713ff1d1f0d78042803153c886 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -526,12 +526,12 @@ config_setting( config_setting( name = "armeabi-v7a", - values = {"android_cpu": "armeabi-v7a"}, + values = {"cpu": "armeabi-v7a"}, ) config_setting( name = "arm64-v8a", - values = {"android_cpu": "arm64-v8a"}, + values = {"cpu": "arm64-v8a"}, ) config_setting( diff --git a/third_party/kafka/BUILD b/third_party/kafka/BUILD index a61a9e1f6c2b29ad3b992e810c0cab463dfd7feb..a839ca717e695f35fac684b510f0a022010e0710 100644 --- a/third_party/kafka/BUILD +++ b/third_party/kafka/BUILD @@ -130,12 +130,16 @@ cc_library( ], hdrs = [ "config.h", + "src-cpp/rdkafkacpp.h", + "src-cpp/rdkafkacpp_int.h", + "src/lz4.c", + "src/snappy_compat.h", ], - defines = [ + copts = [ + "-Iexternal/kafka/src", + "-Iexternal/kafka/src-cpp", ], - includes = [ - "src", - "src-cpp", + defines = [ ], linkopts = [ "-lpthread", @@ -143,5 +147,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@boringssl//:ssl", + "@zlib_archive//:zlib", ], ) diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index b27d341404c4ee1ca1e87ff3b9f427ec52eba739..3262562bccca4f2a8b3da860cb38928f144994a9 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -1,7 +1,5 @@ licenses(["notice"]) # 3-Clause BSD -exports_files(["LICENSE"]) - config_setting( name = "using_mkl", values = { @@ -10,17 +8,52 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "using_mkl_lnx_x64", + values = { + "cpu": "k8", + "define": "using_mkl=true", + }, + visibility = ["//visibility:public"], +) + load( "//third_party/mkl:build_defs.bzl", "if_mkl", ) +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], + srcs = ["MKL_LICENSE"] + select({ + "@org_tensorflow//tensorflow:linux_x86_64": [ + "@mkl_linux//:LICENSE", + ], + "@org_tensorflow//tensorflow:darwin": [ + "@mkl_darwin//:LICENSE", + ], + "@org_tensorflow//tensorflow:windows": [ + "@mkl_windows//:LICENSE", + ] + }) +) + cc_library( name = "intel_binary_blob", - srcs = if_mkl([ - "@mkl//:libmklml_intel.so", - "@mkl//:libiomp5.so", - ]), + visibility = ["//visibility:public"], - deps = ["@mkl//:mkl_headers"], + deps = select({ + "@org_tensorflow//tensorflow:linux_x86_64": [ + "@mkl_linux//:mkl_headers", + "@mkl_linux//:mkl_libs_linux", + ], + "@org_tensorflow//tensorflow:darwin": [ + "@mkl_darwin//:mkl_headers", + "@mkl_darwin//:mkl_libs_darwin", + ], + "@org_tensorflow//tensorflow:windows": [ + "@mkl_windows//:mkl_headers", + "@mkl_windows//:mkl_libs_windows", + ] + }) ) diff --git a/third_party/mkl/LICENSE b/third_party/mkl/MKL_LICENSE similarity index 100% rename from third_party/mkl/LICENSE rename to third_party/mkl/MKL_LICENSE diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 8b73ddabdd7ff5de7374ffbbb76e7bf954c27765..53e02769dad5dd74348dec2dcec88010e543f01c 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -24,6 +24,18 @@ def if_mkl(if_true, if_false = []): "//conditions:default": if_false }) +def if_mkl_lnx_x64(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with MKL. + + Returns a select statement which evaluates to if_true if we're building + with MKL enabled. Otherwise, the select statement evaluates to if_false. + + """ + return select({ + str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true, + "//conditions:default": if_false + }) + def _enable_local_mkl(repository_ctx): return _TF_MKL_ROOT in repository_ctx.os.environ diff --git a/third_party/mkl/mkl.BUILD b/third_party/mkl/mkl.BUILD index 8db97232e156b46091b379b0771239f55d6ea5ad..892221ec00295a694ab40868cd886e820768f78f 100644 --- a/third_party/mkl/mkl.BUILD +++ b/third_party/mkl/mkl.BUILD @@ -17,14 +17,29 @@ cc_library( visibility = ["//visibility:public"], ) -filegroup( - name = "libmklml_intel.so", - srcs = ["lib/libmklml_intel.so"], +cc_library( + name = "mkl_libs_linux", + srcs = [ + "lib/libiomp5.so", + "lib/libmklml_intel.so" + ], visibility = ["//visibility:public"], ) -filegroup( - name = "libiomp5.so", - srcs = ["lib/libiomp5.so"], +cc_library( + name = "mkl_libs_darwin", + srcs = [ + "lib/libiomp5.dylib", + "lib/libmklml.dylib" + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "mkl_libs_windows", + srcs = [ + "lib/libiomp5md.lib", + "lib/mklml.lib" + ], visibility = ["//visibility:public"], ) diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD index 58bb7a6a5d0494301aa5b0bd29f858e7d06e69d3..68f24aabaee6ed33fe5b92a3996f7d175b924ea0 100644 --- a/third_party/mkl_dnn/mkldnn.BUILD +++ b/third_party/mkl_dnn/mkldnn.BUILD @@ -1,5 +1,13 @@ exports_files(["LICENSE"]) +config_setting( + name = "clang_linux_x86_64", + values = { + "cpu": "k8", + "define": "using_clang=true", + }, +) + cc_library( name = "mkl_dnn", srcs = glob([ @@ -9,8 +17,11 @@ cc_library( hdrs = glob(["include/*"]), copts = ["-fexceptions"] + select({ "@org_tensorflow//tensorflow:linux_x86_64": [ - "-fopenmp", + "-fopenmp", # only works with gcc ], + # TODO(ibiryukov): enable openmp with clang by including libomp as a + # dependency. + ":clang_linux_x86_64": [], "//conditions:default": [], }), includes = [ diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl index de06ad5f27e7c08aade4a8f51ab60ba52d012b7b..1dd8ab433a37a127b98ae7069bffcbfd4f6d8bd1 100644 --- a/third_party/py/BUILD.tpl +++ b/third_party/py/BUILD.tpl @@ -2,20 +2,26 @@ licenses(["restricted"]) package(default_visibility = ["//visibility:public"]) +# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib +# See https://docs.python.org/3/extending/windows.html +cc_import( + name = "python_lib", + interface_library = select({ + ":windows": ":python_import_lib", + # A placeholder for Unix platforms which makes --no_build happy. + "//conditions:default": "not-existing.lib", + }), + system_provided = 1, +) + cc_library( name = "python_headers", hdrs = [":python_include"], - data = select({ - ":windows": [":python_import_lib"], + deps = select({ + ":windows": [":python_lib"], "//conditions:default": [], }), includes = ["python_include"], - linkopts = select({ - # TODO(pcloudy): Ideally, this should just go into deps after resolving - # https://github.com/bazelbuild/bazel/issues/3237, - ":windows": ["$(locations :python_import_lib)"], - "//conditions:default": [], - }), ) cc_library( diff --git a/third_party/sycl/sycl/BUILD.tpl b/third_party/sycl/sycl/BUILD.tpl index 21b1a2bbf7d320327d8f6e35124e6ef47019130b..b7e9aa8edb4dd1ecc36595ea0a11f442d05cefee 100755 --- a/third_party/sycl/sycl/BUILD.tpl +++ b/third_party/sycl/sycl/BUILD.tpl @@ -21,7 +21,7 @@ config_setting( name = "using_sycl_trisycl", define_values = { "using_sycl": "true", - "using_trisycl": "false", + "using_trisycl": "true", }, ) diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl index 8e76e5d02aeddab66dacaa495a6c493f18a95a69..9b946505a615372aa7de317c8ee390a2cd4b60e9 100644 --- a/third_party/tensorrt/tensorrt_configure.bzl +++ b/third_party/tensorrt/tensorrt_configure.bzl @@ -57,6 +57,10 @@ def _find_trt_header_dir(repository_ctx, trt_install_path): path = "/usr/include/x86_64-linux-gnu" if _headers_exist(repository_ctx, path): return path + if trt_install_path == "/usr/lib/aarch64-linux-gnu": + path = "/usr/include/aarch64-linux-gnu" + if _headers_exist(repository_ctx, path): + return path path = str(repository_ctx.path("%s/../include" % trt_install_path).realpath) if _headers_exist(repository_ctx, path): return path diff --git a/third_party/toolchains/gpus/crosstool/BUILD b/third_party/toolchains/gpus/crosstool/BUILD index a8c6b0f0291363f3a7576a70e78b3428fb984957..1f9065007ca884a46bfa391d1ee8a8f0333da235 100644 --- a/third_party/toolchains/gpus/crosstool/BUILD +++ b/third_party/toolchains/gpus/crosstool/BUILD @@ -50,3 +50,8 @@ filegroup( name = "empty", srcs = [], ) + +filegroup( + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], +) diff --git a/third_party/toolchains/gpus/crosstool/CROSSTOOL b/third_party/toolchains/gpus/crosstool/CROSSTOOL index a47e0c7cd74edcea777d76854c2d7e97d69897fa..d6ee7e38c414dd59b76c7b2b4c95c55831bb30a8 100644 --- a/third_party/toolchains/gpus/crosstool/CROSSTOOL +++ b/third_party/toolchains/gpus/crosstool/CROSSTOOL @@ -53,6 +53,7 @@ toolchain { flag_set { action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag: "-lstdc++" } @@ -79,6 +80,7 @@ toolchain { name: "alwayslink" flag_set { action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" action: "c++-link-executable" flag_group { flag: "-Wl,-no-as-needed" @@ -120,6 +122,7 @@ toolchain { } flag_set { action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag: "-Wl,-z,relro,-z,now" } @@ -141,8 +144,8 @@ toolchain { flag_group { # All warnings are enabled. Maybe enable -Werror as well? flag: "-Wall" - # TODO(ngiraldo): Some parts of the codebase set -Werror and hit this - # warning, so switch it off for now. + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. flag: "-Wno-invalid-partial-specialization" } } @@ -165,6 +168,7 @@ toolchain { flag_set { action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { # Stamp the binary with a unique identifier. flag: "-Wl,--build-id=md5" @@ -180,6 +184,7 @@ toolchain { action: "c++-compile" action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag:"-no-canonical-prefixes" } @@ -203,6 +208,7 @@ toolchain { flag_set { action: "c++-link-executable" action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" flag_group { flag: "-B/usr/bin/" } @@ -250,6 +256,7 @@ toolchain { } flag_set { action: "c++-link-dynamic-library" + action: "c++-link-nodeps-dynamic-library" action: "c++-link-executable" flag_group { flag: "-Wl,--gc-sections" @@ -296,7 +303,7 @@ toolchain { cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/5.4.0" cxx_builtin_include_directory: "/usr/include/c++/5.4.0/backward" cxx_builtin_include_directory: "/usr/local/include" - cxx_builtin_include_directory: "/usr/local/lib/clang/6.0.0/include" + cxx_builtin_include_directory: "/usr/local/lib/clang/7.0.0/include" cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu" cxx_builtin_include_directory: "/usr/include" } diff --git a/third_party/toolchains/gpus/cuda/BUILD b/third_party/toolchains/gpus/cuda/BUILD index 39136de99c901d6d6a9dafefe3163972511ec122..4cb83809383afa52d5a1d98777f8e5bb2d266286 100644 --- a/third_party/toolchains/gpus/cuda/BUILD +++ b/third_party/toolchains/gpus/cuda/BUILD @@ -51,6 +51,7 @@ cc_library( includes = [ ".", "cuda/include", + "cuda/include/crt", ], visibility = ["//visibility:public"], ) @@ -84,8 +85,8 @@ cc_library( cc_library( name = "cudart", - srcs = ["cuda/lib/libcudart.so.8.0"], - data = ["cuda/lib/libcudart.so.8.0"], + srcs = ["cuda/lib/libcudart.so.9.0"], + data = ["cuda/lib/libcudart.so.9.0"], includes = [ ".", "cuda/include", @@ -96,8 +97,8 @@ cc_library( cc_library( name = "cublas", - srcs = ["cuda/lib/libcublas.so.8.0"], - data = ["cuda/lib/libcublas.so.8.0"], + srcs = ["cuda/lib/libcublas.so.9.0"], + data = ["cuda/lib/libcublas.so.9.0"], includes = [ ".", "cuda/include", @@ -108,8 +109,8 @@ cc_library( cc_library( name = "cusolver", - srcs = ["cuda/lib/libcusolver.so.8.0"], - data = ["cuda/lib/libcusolver.so.8.0"], + srcs = ["cuda/lib/libcusolver.so.9.0"], + data = ["cuda/lib/libcusolver.so.9.0"], includes = [ ".", "cuda/include", @@ -121,8 +122,8 @@ cc_library( cc_library( name = "cudnn", - srcs = ["cuda/lib/libcudnn.so.6"], - data = ["cuda/lib/libcudnn.so.6"], + srcs = ["cuda/lib/libcudnn.so.7"], + data = ["cuda/lib/libcudnn.so.7"], includes = [ ".", "cuda/include", @@ -133,8 +134,8 @@ cc_library( cc_library( name = "cufft", - srcs = ["cuda/lib/libcufft.so.8.0"], - data = ["cuda/lib/libcufft.so.8.0"], + srcs = ["cuda/lib/libcufft.so.9.0"], + data = ["cuda/lib/libcufft.so.9.0"], includes = [ ".", "cuda/include", @@ -145,8 +146,8 @@ cc_library( cc_library( name = "curand", - srcs = ["cuda/lib/libcurand.so.8.0"], - data = ["cuda/lib/libcurand.so.8.0"], + srcs = ["cuda/lib/libcurand.so.9.0"], + data = ["cuda/lib/libcurand.so.9.0"], includes = [ ".", "cuda/include", @@ -183,7 +184,7 @@ cc_library( cc_library( name = "cupti_dsos", - data = ["cuda/lib/libcupti.so.8.0"], + data = ["cuda/lib/libcupti.so.9.0"], includes = [ ".", "cuda/include", @@ -200,1063 +201,990 @@ cc_library( genrule( name = "cuda-include", outs = [ - "cuda/include/math_functions.hpp", - "cuda/include/cufft.h", - "cuda/include/nvgraph.h", - "cuda/include/curand_normal.h", - "cuda/include/curand_uniform.h", - "cuda/include/nppi_data_exchange_and_initialization.h", - "cuda/include/cuda_gl_interop.h", - "cuda/include/nppi_compression_functions.h", - "cuda/include/npp.h", + "cuda/include/CL/cl.h", + "cuda/include/CL/cl.hpp", + "cuda/include/CL/cl_egl.h", + "cuda/include/CL/cl_ext.h", + "cuda/include/CL/cl_gl.h", + "cuda/include/CL/cl_gl_ext.h", + "cuda/include/CL/cl_platform.h", + "cuda/include/CL/opencl.h", + "cuda/include/builtin_types.h", + "cuda/include/channel_descriptor.h", + "cuda/include/common_functions.h", + "cuda/include/cooperative_groups.h", + "cuda/include/cooperative_groups_helpers.h", + "cuda/include/crt/common_functions.h", + "cuda/include/crt/device_double_functions.h", + "cuda/include/crt/device_double_functions.hpp", + "cuda/include/crt/device_functions.h", + "cuda/include/crt/device_functions.hpp", + "cuda/include/crt/func_macro.h", + "cuda/include/crt/host_config.h", + "cuda/include/crt/host_defines.h", + "cuda/include/crt/host_runtime.h", + "cuda/include/crt/math_functions.h", + "cuda/include/crt/math_functions.hpp", + "cuda/include/crt/mma.h", + "cuda/include/crt/mma.hpp", + "cuda/include/crt/nvfunctional", + "cuda/include/crt/sm_70_rt.h", + "cuda/include/crt/sm_70_rt.hpp", + "cuda/include/crt/storage_class.h", + "cuda/include/cuComplex.h", + "cuda/include/cublas.h", + "cuda/include/cublasXt.h", + "cuda/include/cublas_api.h", + "cuda/include/cublas_v2.h", "cuda/include/cuda.h", - "cuda/include/nppi_statistics_functions.h", - "cuda/include/vector_functions.hpp", - "cuda/include/sm_32_intrinsics.hpp", - "cuda/include/sm_32_intrinsics.h", - "cuda/include/curand_discrete.h", + "cuda/include/cudaEGL.h", + "cuda/include/cudaGL.h", + "cuda/include/cudaProfiler.h", + "cuda/include/cudaVDPAU.h", + "cuda/include/cuda_device_runtime_api.h", + "cuda/include/cuda_fp16.h", + "cuda/include/cuda_fp16.hpp", + "cuda/include/cuda_gl_interop.h", + "cuda/include/cuda_occupancy.h", + "cuda/include/cuda_profiler_api.h", "cuda/include/cuda_runtime.h", + "cuda/include/cuda_runtime_api.h", + "cuda/include/cuda_surface_types.h", + "cuda/include/cuda_texture_types.h", + "cuda/include/cuda_vdpau_interop.h", + "cuda/include/cudalibxt.h", + "cuda/include/cudnn.h", + "cuda/include/cufft.h", "cuda/include/cufftXt.h", - "cuda/include/sm_61_intrinsics.h", - "cuda/include/texture_fetch_functions.h", + "cuda/include/cufftw.h", + "cuda/include/curand.h", + "cuda/include/curand_discrete.h", + "cuda/include/curand_discrete2.h", + "cuda/include/curand_globals.h", + "cuda/include/curand_kernel.h", + "cuda/include/curand_lognormal.h", "cuda/include/curand_mrg32k3a.h", - "cuda/include/host_defines.h", - "cuda/include/common_functions.h", - "cuda/include/nppi_support_functions.h", - "cuda/include/nppi_linear_transforms.h", - "cuda/include/device_double_functions.hpp", - "cuda/include/math_constants.h", - "cuda/include/nvToolsExtSync.h", - "cuda/include/npps_initialization.h", + "cuda/include/curand_mtgp32.h", + "cuda/include/curand_mtgp32_host.h", + "cuda/include/curand_mtgp32_kernel.h", + "cuda/include/curand_mtgp32dc_p_11213.h", + "cuda/include/curand_normal.h", + "cuda/include/curand_normal_static.h", + "cuda/include/curand_philox4x32_x.h", + "cuda/include/curand_poisson.h", + "cuda/include/curand_precalc.h", + "cuda/include/curand_uniform.h", + "cuda/include/cusolverDn.h", + "cuda/include/cusolverRf.h", + "cuda/include/cusolverSp.h", "cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h", - "cuda/include/texture_indirect_functions.hpp", - "cuda/include/cudaProfiler.h", - "cuda/include/npps_filtering_functions.h", + "cuda/include/cusolver_common.h", + "cuda/include/cusparse.h", "cuda/include/cusparse_v2.h", - "cuda/include/nppi.h", - "cuda/include/surface_indirect_functions.h", - "cuda/include/sm_30_intrinsics.h", + "cuda/include/device_atomic_functions.h", + "cuda/include/device_atomic_functions.hpp", "cuda/include/device_double_functions.h", - "cuda/include/sm_35_intrinsics.h", - "cuda/include/cusolverSp.h", - "cuda/include/library_types.h", - "cuda/include/surface_indirect_functions.hpp", - "cuda/include/cudalibxt.h", - "cuda/include/channel_descriptor.h", + "cuda/include/device_double_functions.hpp", + "cuda/include/device_functions.h", + "cuda/include/device_functions.hpp", "cuda/include/device_functions_decls.h", - "cuda/include/curand_kernel.h", - "cuda/include/curand_mtgp32_host.h", - "cuda/include/nvToolsExtCuda.h", - "cuda/include/nvToolsExt.h", - "cuda/include/cuComplex.h", - "cuda/include/sm_32_atomic_functions.h", - "cuda/include/texture_indirect_functions.h", - "cuda/include/sm_32_atomic_functions.hpp", - "cuda/include/sm_20_intrinsics.hpp", "cuda/include/device_launch_parameters.h", - "cuda/include/curand_mtgp32.h", - "cuda/include/texture_fetch_functions.hpp", - "cuda/include/cuda_occupancy.h", - "cuda/include/CL/opencl.h", - "cuda/include/CL/cl_platform.h", - "cuda/include/CL/cl_egl.h", - "cuda/include/CL/cl_gl.h", - "cuda/include/CL/cl.h", - "cuda/include/CL/cl_gl_ext.h", - "cuda/include/CL/cl_ext.h", - "cuda/include/CL/cl.hpp", + "cuda/include/device_types.h", + "cuda/include/driver_functions.h", + "cuda/include/driver_types.h", + "cuda/include/dynlink_cuda.h", + "cuda/include/dynlink_cuda_cuda.h", + "cuda/include/dynlink_cuviddec.h", + "cuda/include/dynlink_nvcuvid.h", + "cuda/include/fatBinaryCtl.h", + "cuda/include/fatbinary.h", "cuda/include/host_config.h", - "cuda/include/cuda_surface_types.h", + "cuda/include/host_defines.h", + "cuda/include/library_types.h", + "cuda/include/math_constants.h", "cuda/include/math_functions.h", + "cuda/include/math_functions.hpp", + "cuda/include/math_functions_dbl_ptx3.h", + "cuda/include/math_functions_dbl_ptx3.hpp", + "cuda/include/mma.h", + "cuda/include/npp.h", + "cuda/include/nppcore.h", + "cuda/include/nppdefs.h", + "cuda/include/nppi.h", + "cuda/include/nppi_arithmetic_and_logical_operations.h", + "cuda/include/nppi_color_conversion.h", + "cuda/include/nppi_compression_functions.h", + "cuda/include/nppi_computer_vision.h", + "cuda/include/nppi_data_exchange_and_initialization.h", + "cuda/include/nppi_filtering_functions.h", + "cuda/include/nppi_geometry_transforms.h", + "cuda/include/nppi_linear_transforms.h", + "cuda/include/nppi_morphological_operations.h", + "cuda/include/nppi_statistics_functions.h", + "cuda/include/nppi_support_functions.h", + "cuda/include/nppi_threshold_and_compare_operations.h", + "cuda/include/npps.h", + "cuda/include/npps_arithmetic_and_logical_operations.h", + "cuda/include/npps_conversion_functions.h", + "cuda/include/npps_filtering_functions.h", + "cuda/include/npps_initialization.h", + "cuda/include/npps_statistics_functions.h", + "cuda/include/npps_support_functions.h", + "cuda/include/nppversion.h", + "cuda/include/nvToolsExt.h", + "cuda/include/nvToolsExtCuda.h", + "cuda/include/nvToolsExtCudaRt.h", "cuda/include/nvToolsExtMeta.h", + "cuda/include/nvToolsExtSync.h", + "cuda/include/nvblas.h", + "cuda/include/nvfunctional", + "cuda/include/nvgraph.h", + "cuda/include/nvml.h", + "cuda/include/nvrtc.h", + "cuda/include/sm_20_atomic_functions.h", "cuda/include/sm_20_atomic_functions.hpp", - "cuda/include/device_functions.h", - "cuda/include/device_types.h", - "cuda/include/npps_conversion_functions.h", - "cuda/include/curand_precalc.h", - "cuda/include/cusolverRf.h", + "cuda/include/sm_20_intrinsics.h", + "cuda/include/sm_20_intrinsics.hpp", + "cuda/include/sm_30_intrinsics.h", + "cuda/include/sm_30_intrinsics.hpp", + "cuda/include/sm_32_atomic_functions.h", + "cuda/include/sm_32_atomic_functions.hpp", + "cuda/include/sm_32_intrinsics.h", + "cuda/include/sm_32_intrinsics.hpp", + "cuda/include/sm_35_atomic_functions.h", + "cuda/include/sm_35_intrinsics.h", + "cuda/include/sm_60_atomic_functions.h", "cuda/include/sm_60_atomic_functions.hpp", - "cuda/include/cuviddec.h", - "cuda/include/curand_discrete2.h", - "cuda/include/device_functions.hpp", - "cuda/include/thrust/transform_scan.h", - "cuda/include/thrust/system_error.h", - "cuda/include/thrust/device_malloc.h", - "cuda/include/thrust/partition.h", - "cuda/include/thrust/unique.h", - "cuda/include/thrust/device_delete.h", - "cuda/include/thrust/execution_policy.h", + "cuda/include/sm_61_intrinsics.h", + "cuda/include/sm_61_intrinsics.hpp", + "cuda/include/sobol_direction_vectors.h", + "cuda/include/surface_functions.h", + "cuda/include/surface_functions.hpp", + "cuda/include/surface_indirect_functions.h", + "cuda/include/surface_indirect_functions.hpp", + "cuda/include/surface_types.h", + "cuda/include/texture_fetch_functions.h", + "cuda/include/texture_fetch_functions.hpp", + "cuda/include/texture_indirect_functions.h", + "cuda/include/texture_indirect_functions.hpp", + "cuda/include/texture_types.h", "cuda/include/thrust/adjacent_difference.h", - "cuda/include/thrust/sequence.h", - "cuda/include/thrust/merge.h", - "cuda/include/thrust/device_new.h", - "cuda/include/thrust/transform_reduce.h", - "cuda/include/thrust/device_vector.h", - "cuda/include/thrust/gather.h", - "cuda/include/thrust/sort.h", - "cuda/include/thrust/scan.h", - "cuda/include/thrust/detail/temporary_array.h", - "cuda/include/thrust/detail/util/align.h", - "cuda/include/thrust/detail/util/blocking.h", - "cuda/include/thrust/detail/transform.inl", - "cuda/include/thrust/detail/device_vector.inl", + "cuda/include/thrust/advance.h", + "cuda/include/thrust/binary_search.h", + "cuda/include/thrust/complex.h", + "cuda/include/thrust/copy.h", + "cuda/include/thrust/count.h", + "cuda/include/thrust/detail/adjacent_difference.inl", + "cuda/include/thrust/detail/advance.inl", + "cuda/include/thrust/detail/allocator/allocator_traits.h", + "cuda/include/thrust/detail/allocator/allocator_traits.inl", + "cuda/include/thrust/detail/allocator/copy_construct_range.h", + "cuda/include/thrust/detail/allocator/copy_construct_range.inl", + "cuda/include/thrust/detail/allocator/default_construct_range.h", + "cuda/include/thrust/detail/allocator/default_construct_range.inl", + "cuda/include/thrust/detail/allocator/destroy_range.h", + "cuda/include/thrust/detail/allocator/destroy_range.inl", + "cuda/include/thrust/detail/allocator/fill_construct_range.h", + "cuda/include/thrust/detail/allocator/fill_construct_range.inl", + "cuda/include/thrust/detail/allocator/malloc_allocator.h", + "cuda/include/thrust/detail/allocator/malloc_allocator.inl", + "cuda/include/thrust/detail/allocator/no_throw_allocator.h", + "cuda/include/thrust/detail/allocator/tagged_allocator.h", + "cuda/include/thrust/detail/allocator/tagged_allocator.inl", + "cuda/include/thrust/detail/allocator/temporary_allocator.h", + "cuda/include/thrust/detail/allocator/temporary_allocator.inl", "cuda/include/thrust/detail/binary_search.inl", - "cuda/include/thrust/detail/overlapped_copy.h", - "cuda/include/thrust/detail/vector_base.inl", - "cuda/include/thrust/detail/device_reference.inl", - "cuda/include/thrust/detail/functional/actor.h", - "cuda/include/thrust/detail/functional/value.h", - "cuda/include/thrust/detail/functional/operators.h", - "cuda/include/thrust/detail/functional/operators/logical_operators.h", - "cuda/include/thrust/detail/functional/operators/relational_operators.h", - "cuda/include/thrust/detail/functional/operators/assignment_operator.h", - "cuda/include/thrust/detail/functional/operators/bitwise_operators.h", - "cuda/include/thrust/detail/functional/operators/operator_adaptors.h", - "cuda/include/thrust/detail/functional/operators/arithmetic_operators.h", - "cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h", - "cuda/include/thrust/detail/functional/argument.h", - "cuda/include/thrust/detail/functional/placeholder.h", - "cuda/include/thrust/detail/functional/actor.inl", - "cuda/include/thrust/detail/functional/composite.h", - "cuda/include/thrust/detail/static_map.h", - "cuda/include/thrust/detail/type_traits/has_nested_type.h", - "cuda/include/thrust/detail/type_traits/is_call_possible.h", - "cuda/include/thrust/detail/type_traits/function_traits.h", - "cuda/include/thrust/detail/type_traits/pointer_traits.h", - "cuda/include/thrust/detail/type_traits/has_member_function.h", - "cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h", - "cuda/include/thrust/detail/type_traits/minimum_type.h", - "cuda/include/thrust/detail/type_traits/has_trivial_assign.h", - "cuda/include/thrust/detail/type_traits/is_metafunction_defined.h", - "cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h", - "cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h", - "cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h", - "cuda/include/thrust/detail/reference.h", - "cuda/include/thrust/detail/inner_product.inl", - "cuda/include/thrust/detail/use_default.h", - "cuda/include/thrust/detail/sequence.inl", - "cuda/include/thrust/detail/sort.inl", - "cuda/include/thrust/detail/equal.inl", - "cuda/include/thrust/detail/execution_policy.h", - "cuda/include/thrust/detail/integer_traits.h", - "cuda/include/thrust/detail/type_traits.h", - "cuda/include/thrust/detail/reverse.inl", - "cuda/include/thrust/detail/tabulate.inl", - "cuda/include/thrust/detail/unique.inl", - "cuda/include/thrust/detail/scatter.inl", - "cuda/include/thrust/detail/set_operations.inl", - "cuda/include/thrust/detail/device_malloc.inl", - "cuda/include/thrust/detail/copy_if.inl", - "cuda/include/thrust/detail/fill.inl", - "cuda/include/thrust/detail/temporary_array.inl", - "cuda/include/thrust/detail/transform_scan.inl", - "cuda/include/thrust/detail/minmax.h", - "cuda/include/thrust/detail/swap.inl", - "cuda/include/thrust/detail/pointer.inl", - "cuda/include/thrust/detail/transform_reduce.inl", - "cuda/include/thrust/detail/config.h", - "cuda/include/thrust/detail/distance.inl", - "cuda/include/thrust/detail/pair.inl", - "cuda/include/thrust/detail/allocator/temporary_allocator.h", - "cuda/include/thrust/detail/allocator/tagged_allocator.h", - "cuda/include/thrust/detail/allocator/destroy_range.inl", - "cuda/include/thrust/detail/allocator/destroy_range.h", - "cuda/include/thrust/detail/allocator/no_throw_allocator.h", - "cuda/include/thrust/detail/allocator/default_construct_range.inl", - "cuda/include/thrust/detail/allocator/fill_construct_range.inl", - "cuda/include/thrust/detail/allocator/tagged_allocator.inl", - "cuda/include/thrust/detail/allocator/malloc_allocator.h", - "cuda/include/thrust/detail/allocator/allocator_traits.h", - "cuda/include/thrust/detail/allocator/copy_construct_range.h", - "cuda/include/thrust/detail/allocator/allocator_traits.inl", - "cuda/include/thrust/detail/allocator/default_construct_range.h", - "cuda/include/thrust/detail/allocator/copy_construct_range.inl", - "cuda/include/thrust/detail/allocator/malloc_allocator.inl", - "cuda/include/thrust/detail/allocator/temporary_allocator.inl", - "cuda/include/thrust/detail/allocator/fill_construct_range.h", - "cuda/include/thrust/detail/temporary_buffer.h", - "cuda/include/thrust/detail/reduce.inl", - "cuda/include/thrust/detail/device_new.inl", - "cuda/include/thrust/detail/pointer.h", - "cuda/include/thrust/detail/for_each.inl", - "cuda/include/thrust/detail/generate.inl", - "cuda/include/thrust/detail/dispatch/is_trivial_copy.h", - "cuda/include/thrust/detail/adjacent_difference.inl", - "cuda/include/thrust/detail/tuple_meta_transform.h", - "cuda/include/thrust/detail/functional.inl", - "cuda/include/thrust/detail/remove.inl", - "cuda/include/thrust/detail/tuple_transform.h", - "cuda/include/thrust/detail/merge.inl", - "cuda/include/thrust/detail/extrema.inl", - "cuda/include/thrust/detail/trivial_sequence.h", - "cuda/include/thrust/detail/vector_base.h", - "cuda/include/thrust/detail/count.inl", - "cuda/include/thrust/detail/uninitialized_copy.inl", - "cuda/include/thrust/detail/function.h", - "cuda/include/thrust/detail/swap_ranges.inl", - "cuda/include/thrust/detail/device_delete.inl", - "cuda/include/thrust/detail/static_assert.h", - "cuda/include/thrust/detail/logical.inl", - "cuda/include/thrust/detail/seq.h", - "cuda/include/thrust/detail/mpl/math.h", - "cuda/include/thrust/detail/mismatch.inl", - "cuda/include/thrust/detail/internal_functional.h", - "cuda/include/thrust/detail/get_iterator_value.h", - "cuda/include/thrust/detail/copy.inl", - "cuda/include/thrust/detail/copy.h", + "cuda/include/thrust/detail/complex/arithmetic.h", + "cuda/include/thrust/detail/complex/c99math.h", + "cuda/include/thrust/detail/complex/catrig.h", "cuda/include/thrust/detail/complex/catrigf.h", - "cuda/include/thrust/detail/complex/cpowf.h", - "cuda/include/thrust/detail/complex/csqrtf.h", + "cuda/include/thrust/detail/complex/ccosh.h", "cuda/include/thrust/detail/complex/ccoshf.h", - "cuda/include/thrust/detail/complex/csinhf.h", + "cuda/include/thrust/detail/complex/cexp.h", + "cuda/include/thrust/detail/complex/cexpf.h", + "cuda/include/thrust/detail/complex/clog.h", "cuda/include/thrust/detail/complex/clogf.h", - "cuda/include/thrust/detail/complex/ccosh.h", - "cuda/include/thrust/detail/complex/arithmetic.h", - "cuda/include/thrust/detail/complex/csqrt.h", - "cuda/include/thrust/detail/complex/cpow.h", "cuda/include/thrust/detail/complex/complex.inl", - "cuda/include/thrust/detail/complex/math_private.h", - "cuda/include/thrust/detail/complex/c99math.h", + "cuda/include/thrust/detail/complex/cpow.h", + "cuda/include/thrust/detail/complex/cpowf.h", "cuda/include/thrust/detail/complex/cproj.h", - "cuda/include/thrust/detail/complex/catrig.h", - "cuda/include/thrust/detail/complex/ctanhf.h", - "cuda/include/thrust/detail/complex/cexpf.h", "cuda/include/thrust/detail/complex/csinh.h", - "cuda/include/thrust/detail/complex/stream.h", + "cuda/include/thrust/detail/complex/csinhf.h", + "cuda/include/thrust/detail/complex/csqrt.h", + "cuda/include/thrust/detail/complex/csqrtf.h", "cuda/include/thrust/detail/complex/ctanh.h", - "cuda/include/thrust/detail/complex/cexp.h", - "cuda/include/thrust/detail/complex/clog.h", - "cuda/include/thrust/detail/range/head_flags.h", - "cuda/include/thrust/detail/range/tail_flags.h", - "cuda/include/thrust/detail/execute_with_allocator.h", - "cuda/include/thrust/detail/integer_math.h", - "cuda/include/thrust/detail/swap.h", - "cuda/include/thrust/detail/uninitialized_fill.inl", - "cuda/include/thrust/detail/scan.inl", - "cuda/include/thrust/detail/gather.inl", - "cuda/include/thrust/detail/reference_forward_declaration.h", - "cuda/include/thrust/detail/numeric_traits.h", - "cuda/include/thrust/detail/reference.inl", - "cuda/include/thrust/detail/cstdint.h", - "cuda/include/thrust/detail/device_free.inl", - "cuda/include/thrust/detail/copy_if.h", - "cuda/include/thrust/detail/partition.inl", - "cuda/include/thrust/detail/find.inl", - "cuda/include/thrust/detail/config/forceinline.h", - "cuda/include/thrust/detail/config/debug.h", - "cuda/include/thrust/detail/config/config.h", - "cuda/include/thrust/detail/config/host_device.h", - "cuda/include/thrust/detail/config/host_system.h", + "cuda/include/thrust/detail/complex/ctanhf.h", + "cuda/include/thrust/detail/complex/math_private.h", + "cuda/include/thrust/detail/complex/stream.h", + "cuda/include/thrust/detail/config.h", "cuda/include/thrust/detail/config/compiler.h", - "cuda/include/thrust/detail/config/device_system.h", "cuda/include/thrust/detail/config/compiler_fence.h", + "cuda/include/thrust/detail/config/config.h", + "cuda/include/thrust/detail/config/debug.h", + "cuda/include/thrust/detail/config/device_system.h", "cuda/include/thrust/detail/config/exec_check_disable.h", - "cuda/include/thrust/detail/config/simple_defines.h", + "cuda/include/thrust/detail/config/forceinline.h", "cuda/include/thrust/detail/config/global_workarounds.h", - "cuda/include/thrust/detail/replace.inl", + "cuda/include/thrust/detail/config/host_device.h", + "cuda/include/thrust/detail/config/host_system.h", + "cuda/include/thrust/detail/config/simple_defines.h", + "cuda/include/thrust/detail/contiguous_storage.h", + "cuda/include/thrust/detail/contiguous_storage.inl", + "cuda/include/thrust/detail/copy.h", + "cuda/include/thrust/detail/copy.inl", + "cuda/include/thrust/detail/copy_if.h", + "cuda/include/thrust/detail/copy_if.inl", + "cuda/include/thrust/detail/count.inl", + "cuda/include/thrust/detail/cstdint.h", + "cuda/include/thrust/detail/device_delete.inl", + "cuda/include/thrust/detail/device_free.inl", + "cuda/include/thrust/detail/device_malloc.inl", + "cuda/include/thrust/detail/device_new.inl", "cuda/include/thrust/detail/device_ptr.inl", - "cuda/include/thrust/detail/tuple.inl", - "cuda/include/thrust/detail/malloc_and_free.h", + "cuda/include/thrust/detail/device_reference.inl", + "cuda/include/thrust/detail/device_vector.inl", + "cuda/include/thrust/detail/dispatch/is_trivial_copy.h", + "cuda/include/thrust/detail/distance.inl", + "cuda/include/thrust/detail/equal.inl", + "cuda/include/thrust/detail/execute_with_allocator.h", + "cuda/include/thrust/detail/execution_policy.h", + "cuda/include/thrust/detail/extrema.inl", + "cuda/include/thrust/detail/fill.inl", + "cuda/include/thrust/detail/find.inl", + "cuda/include/thrust/detail/for_each.inl", + "cuda/include/thrust/detail/function.h", + "cuda/include/thrust/detail/functional.inl", + "cuda/include/thrust/detail/functional/actor.h", + "cuda/include/thrust/detail/functional/actor.inl", + "cuda/include/thrust/detail/functional/argument.h", + "cuda/include/thrust/detail/functional/composite.h", + "cuda/include/thrust/detail/functional/operators.h", + "cuda/include/thrust/detail/functional/operators/arithmetic_operators.h", + "cuda/include/thrust/detail/functional/operators/assignment_operator.h", + "cuda/include/thrust/detail/functional/operators/bitwise_operators.h", + "cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h", + "cuda/include/thrust/detail/functional/operators/logical_operators.h", + "cuda/include/thrust/detail/functional/operators/operator_adaptors.h", + "cuda/include/thrust/detail/functional/operators/relational_operators.h", + "cuda/include/thrust/detail/functional/placeholder.h", + "cuda/include/thrust/detail/functional/value.h", + "cuda/include/thrust/detail/gather.inl", + "cuda/include/thrust/detail/generate.inl", + "cuda/include/thrust/detail/get_iterator_value.h", "cuda/include/thrust/detail/host_vector.inl", + "cuda/include/thrust/detail/inner_product.inl", + "cuda/include/thrust/detail/integer_math.h", + "cuda/include/thrust/detail/integer_traits.h", + "cuda/include/thrust/detail/internal_functional.h", + "cuda/include/thrust/detail/logical.inl", + "cuda/include/thrust/detail/malloc_and_free.h", + "cuda/include/thrust/detail/merge.inl", + "cuda/include/thrust/detail/minmax.h", + "cuda/include/thrust/detail/mismatch.inl", + "cuda/include/thrust/detail/mpl/math.h", + "cuda/include/thrust/detail/numeric_traits.h", + "cuda/include/thrust/detail/overlapped_copy.h", + "cuda/include/thrust/detail/pair.inl", + "cuda/include/thrust/detail/partition.inl", + "cuda/include/thrust/detail/pointer.h", + "cuda/include/thrust/detail/pointer.inl", + "cuda/include/thrust/detail/range/head_flags.h", + "cuda/include/thrust/detail/range/tail_flags.h", "cuda/include/thrust/detail/raw_pointer_cast.h", - "cuda/include/thrust/detail/advance.inl", - "cuda/include/thrust/detail/contiguous_storage.h", "cuda/include/thrust/detail/raw_reference_cast.h", - "cuda/include/thrust/detail/contiguous_storage.inl", - "cuda/include/thrust/reverse.h", - "cuda/include/thrust/device_malloc_allocator.h", - "cuda/include/thrust/scatter.h", - "cuda/include/thrust/pair.h", - "cuda/include/thrust/advance.h", - "cuda/include/thrust/find.h", - "cuda/include/thrust/device_ptr.h", - "cuda/include/thrust/generate.h", - "cuda/include/thrust/uninitialized_fill.h", - "cuda/include/thrust/system/system_error.h", - "cuda/include/thrust/system/detail/bad_alloc.h", - "cuda/include/thrust/system/detail/adl/transform_scan.h", - "cuda/include/thrust/system/detail/adl/unique_by_key.h", - "cuda/include/thrust/system/detail/adl/partition.h", - "cuda/include/thrust/system/detail/adl/unique.h", - "cuda/include/thrust/system/detail/adl/adjacent_difference.h", - "cuda/include/thrust/system/detail/adl/sequence.h", - "cuda/include/thrust/system/detail/adl/merge.h", - "cuda/include/thrust/system/detail/adl/transform_reduce.h", - "cuda/include/thrust/system/detail/adl/gather.h", - "cuda/include/thrust/system/detail/adl/sort.h", - "cuda/include/thrust/system/detail/adl/scan.h", - "cuda/include/thrust/system/detail/adl/temporary_buffer.h", - "cuda/include/thrust/system/detail/adl/scan_by_key.h", - "cuda/include/thrust/system/detail/adl/reverse.h", - "cuda/include/thrust/system/detail/adl/assign_value.h", - "cuda/include/thrust/system/detail/adl/scatter.h", - "cuda/include/thrust/system/detail/adl/find.h", - "cuda/include/thrust/system/detail/adl/generate.h", - "cuda/include/thrust/system/detail/adl/uninitialized_fill.h", - "cuda/include/thrust/system/detail/adl/remove.h", - "cuda/include/thrust/system/detail/adl/tabulate.h", - "cuda/include/thrust/system/detail/adl/for_each.h", - "cuda/include/thrust/system/detail/adl/reduce_by_key.h", - "cuda/include/thrust/system/detail/adl/reduce.h", - "cuda/include/thrust/system/detail/adl/equal.h", - "cuda/include/thrust/system/detail/adl/copy.h", - "cuda/include/thrust/system/detail/adl/swap_ranges.h", - "cuda/include/thrust/system/detail/adl/uninitialized_copy.h", - "cuda/include/thrust/system/detail/adl/binary_search.h", - "cuda/include/thrust/system/detail/adl/set_operations.h", - "cuda/include/thrust/system/detail/adl/mismatch.h", - "cuda/include/thrust/system/detail/adl/extrema.h", - "cuda/include/thrust/system/detail/adl/count.h", - "cuda/include/thrust/system/detail/adl/replace.h", + "cuda/include/thrust/detail/reduce.inl", + "cuda/include/thrust/detail/reference.h", + "cuda/include/thrust/detail/reference.inl", + "cuda/include/thrust/detail/reference_forward_declaration.h", + "cuda/include/thrust/detail/remove.inl", + "cuda/include/thrust/detail/replace.inl", + "cuda/include/thrust/detail/reverse.inl", + "cuda/include/thrust/detail/scan.inl", + "cuda/include/thrust/detail/scatter.inl", + "cuda/include/thrust/detail/seq.h", + "cuda/include/thrust/detail/sequence.inl", + "cuda/include/thrust/detail/set_operations.inl", + "cuda/include/thrust/detail/sort.inl", + "cuda/include/thrust/detail/static_assert.h", + "cuda/include/thrust/detail/static_map.h", + "cuda/include/thrust/detail/swap.h", + "cuda/include/thrust/detail/swap.inl", + "cuda/include/thrust/detail/swap_ranges.inl", + "cuda/include/thrust/detail/tabulate.inl", + "cuda/include/thrust/detail/temporary_array.h", + "cuda/include/thrust/detail/temporary_array.inl", + "cuda/include/thrust/detail/temporary_buffer.h", + "cuda/include/thrust/detail/transform.inl", + "cuda/include/thrust/detail/transform_reduce.inl", + "cuda/include/thrust/detail/transform_scan.inl", + "cuda/include/thrust/detail/trivial_sequence.h", + "cuda/include/thrust/detail/tuple.inl", + "cuda/include/thrust/detail/tuple_meta_transform.h", + "cuda/include/thrust/detail/tuple_transform.h", + "cuda/include/thrust/detail/type_traits.h", + "cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h", + "cuda/include/thrust/detail/type_traits/function_traits.h", + "cuda/include/thrust/detail/type_traits/has_member_function.h", + "cuda/include/thrust/detail/type_traits/has_nested_type.h", + "cuda/include/thrust/detail/type_traits/has_trivial_assign.h", + "cuda/include/thrust/detail/type_traits/is_call_possible.h", + "cuda/include/thrust/detail/type_traits/is_metafunction_defined.h", + "cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h", + "cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h", + "cuda/include/thrust/detail/type_traits/minimum_type.h", + "cuda/include/thrust/detail/type_traits/pointer_traits.h", + "cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h", + "cuda/include/thrust/detail/uninitialized_copy.inl", + "cuda/include/thrust/detail/uninitialized_fill.inl", + "cuda/include/thrust/detail/unique.inl", + "cuda/include/thrust/detail/use_default.h", + "cuda/include/thrust/detail/util/align.h", + "cuda/include/thrust/detail/util/blocking.h", + "cuda/include/thrust/detail/vector_base.h", + "cuda/include/thrust/detail/vector_base.inl", + "cuda/include/thrust/device_allocator.h", + "cuda/include/thrust/device_delete.h", + "cuda/include/thrust/device_free.h", + "cuda/include/thrust/device_malloc.h", + "cuda/include/thrust/device_malloc_allocator.h", + "cuda/include/thrust/device_new.h", + "cuda/include/thrust/device_new_allocator.h", + "cuda/include/thrust/device_ptr.h", + "cuda/include/thrust/device_reference.h", + "cuda/include/thrust/device_vector.h", + "cuda/include/thrust/distance.h", + "cuda/include/thrust/equal.h", + "cuda/include/thrust/execution_policy.h", + "cuda/include/thrust/extrema.h", + "cuda/include/thrust/fill.h", + "cuda/include/thrust/find.h", + "cuda/include/thrust/for_each.h", + "cuda/include/thrust/functional.h", + "cuda/include/thrust/gather.h", + "cuda/include/thrust/generate.h", + "cuda/include/thrust/host_vector.h", + "cuda/include/thrust/inner_product.h", + "cuda/include/thrust/iterator/constant_iterator.h", + "cuda/include/thrust/iterator/counting_iterator.h", + "cuda/include/thrust/iterator/detail/any_assign.h", + "cuda/include/thrust/iterator/detail/any_system_tag.h", + "cuda/include/thrust/iterator/detail/constant_iterator_base.h", + "cuda/include/thrust/iterator/detail/counting_iterator.inl", + "cuda/include/thrust/iterator/detail/device_system_tag.h", + "cuda/include/thrust/iterator/detail/discard_iterator_base.h", + "cuda/include/thrust/iterator/detail/distance_from_result.h", + "cuda/include/thrust/iterator/detail/host_system_tag.h", + "cuda/include/thrust/iterator/detail/is_iterator_category.h", + "cuda/include/thrust/iterator/detail/is_trivial_iterator.h", + "cuda/include/thrust/iterator/detail/iterator_adaptor_base.h", + "cuda/include/thrust/iterator/detail/iterator_category_to_system.h", + "cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h", + "cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h", + "cuda/include/thrust/iterator/detail/iterator_facade_category.h", + "cuda/include/thrust/iterator/detail/iterator_traits.inl", + "cuda/include/thrust/iterator/detail/iterator_traversal_tags.h", + "cuda/include/thrust/iterator/detail/join_iterator.h", + "cuda/include/thrust/iterator/detail/minimum_category.h", + "cuda/include/thrust/iterator/detail/minimum_system.h", + "cuda/include/thrust/iterator/detail/normal_iterator.h", + "cuda/include/thrust/iterator/detail/permutation_iterator_base.h", + "cuda/include/thrust/iterator/detail/retag.h", + "cuda/include/thrust/iterator/detail/reverse_iterator.inl", + "cuda/include/thrust/iterator/detail/reverse_iterator_base.h", + "cuda/include/thrust/iterator/detail/tagged_iterator.h", + "cuda/include/thrust/iterator/detail/transform_iterator.inl", + "cuda/include/thrust/iterator/detail/transform_output_iterator.inl", + "cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h", + "cuda/include/thrust/iterator/detail/universal_categories.h", + "cuda/include/thrust/iterator/detail/zip_iterator.inl", + "cuda/include/thrust/iterator/detail/zip_iterator_base.h", + "cuda/include/thrust/iterator/discard_iterator.h", + "cuda/include/thrust/iterator/iterator_adaptor.h", + "cuda/include/thrust/iterator/iterator_categories.h", + "cuda/include/thrust/iterator/iterator_facade.h", + "cuda/include/thrust/iterator/iterator_traits.h", + "cuda/include/thrust/iterator/permutation_iterator.h", + "cuda/include/thrust/iterator/retag.h", + "cuda/include/thrust/iterator/reverse_iterator.h", + "cuda/include/thrust/iterator/transform_iterator.h", + "cuda/include/thrust/iterator/transform_output_iterator.h", + "cuda/include/thrust/iterator/zip_iterator.h", + "cuda/include/thrust/logical.h", + "cuda/include/thrust/memory.h", + "cuda/include/thrust/merge.h", + "cuda/include/thrust/mismatch.h", + "cuda/include/thrust/pair.h", + "cuda/include/thrust/partition.h", + "cuda/include/thrust/random.h", + "cuda/include/thrust/random/detail/discard_block_engine.inl", + "cuda/include/thrust/random/detail/linear_congruential_engine.inl", + "cuda/include/thrust/random/detail/linear_congruential_engine_discard.h", + "cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl", + "cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h", + "cuda/include/thrust/random/detail/mod.h", + "cuda/include/thrust/random/detail/normal_distribution.inl", + "cuda/include/thrust/random/detail/normal_distribution_base.h", + "cuda/include/thrust/random/detail/random_core_access.h", + "cuda/include/thrust/random/detail/subtract_with_carry_engine.inl", + "cuda/include/thrust/random/detail/uniform_int_distribution.inl", + "cuda/include/thrust/random/detail/uniform_real_distribution.inl", + "cuda/include/thrust/random/detail/xor_combine_engine.inl", + "cuda/include/thrust/random/detail/xor_combine_engine_max.h", + "cuda/include/thrust/random/discard_block_engine.h", + "cuda/include/thrust/random/linear_congruential_engine.h", + "cuda/include/thrust/random/linear_feedback_shift_engine.h", + "cuda/include/thrust/random/normal_distribution.h", + "cuda/include/thrust/random/subtract_with_carry_engine.h", + "cuda/include/thrust/random/uniform_int_distribution.h", + "cuda/include/thrust/random/uniform_real_distribution.h", + "cuda/include/thrust/random/xor_combine_engine.h", + "cuda/include/thrust/reduce.h", + "cuda/include/thrust/remove.h", + "cuda/include/thrust/replace.h", + "cuda/include/thrust/reverse.h", + "cuda/include/thrust/scan.h", + "cuda/include/thrust/scatter.h", + "cuda/include/thrust/sequence.h", + "cuda/include/thrust/set_operations.h", + "cuda/include/thrust/sort.h", + "cuda/include/thrust/swap.h", + "cuda/include/thrust/system/cpp/detail/adjacent_difference.h", + "cuda/include/thrust/system/cpp/detail/assign_value.h", + "cuda/include/thrust/system/cpp/detail/binary_search.h", + "cuda/include/thrust/system/cpp/detail/copy.h", + "cuda/include/thrust/system/cpp/detail/copy_if.h", + "cuda/include/thrust/system/cpp/detail/count.h", + "cuda/include/thrust/system/cpp/detail/equal.h", + "cuda/include/thrust/system/cpp/detail/execution_policy.h", + "cuda/include/thrust/system/cpp/detail/extrema.h", + "cuda/include/thrust/system/cpp/detail/fill.h", + "cuda/include/thrust/system/cpp/detail/find.h", + "cuda/include/thrust/system/cpp/detail/for_each.h", + "cuda/include/thrust/system/cpp/detail/gather.h", + "cuda/include/thrust/system/cpp/detail/generate.h", + "cuda/include/thrust/system/cpp/detail/get_value.h", + "cuda/include/thrust/system/cpp/detail/inner_product.h", + "cuda/include/thrust/system/cpp/detail/iter_swap.h", + "cuda/include/thrust/system/cpp/detail/logical.h", + "cuda/include/thrust/system/cpp/detail/malloc_and_free.h", + "cuda/include/thrust/system/cpp/detail/memory.inl", + "cuda/include/thrust/system/cpp/detail/merge.h", + "cuda/include/thrust/system/cpp/detail/mismatch.h", + "cuda/include/thrust/system/cpp/detail/par.h", + "cuda/include/thrust/system/cpp/detail/partition.h", + "cuda/include/thrust/system/cpp/detail/reduce.h", + "cuda/include/thrust/system/cpp/detail/reduce_by_key.h", + "cuda/include/thrust/system/cpp/detail/remove.h", + "cuda/include/thrust/system/cpp/detail/replace.h", + "cuda/include/thrust/system/cpp/detail/reverse.h", + "cuda/include/thrust/system/cpp/detail/scan.h", + "cuda/include/thrust/system/cpp/detail/scan_by_key.h", + "cuda/include/thrust/system/cpp/detail/scatter.h", + "cuda/include/thrust/system/cpp/detail/sequence.h", + "cuda/include/thrust/system/cpp/detail/set_operations.h", + "cuda/include/thrust/system/cpp/detail/sort.h", + "cuda/include/thrust/system/cpp/detail/swap_ranges.h", + "cuda/include/thrust/system/cpp/detail/tabulate.h", + "cuda/include/thrust/system/cpp/detail/temporary_buffer.h", + "cuda/include/thrust/system/cpp/detail/transform.h", + "cuda/include/thrust/system/cpp/detail/transform_reduce.h", + "cuda/include/thrust/system/cpp/detail/transform_scan.h", + "cuda/include/thrust/system/cpp/detail/uninitialized_copy.h", + "cuda/include/thrust/system/cpp/detail/uninitialized_fill.h", + "cuda/include/thrust/system/cpp/detail/unique.h", + "cuda/include/thrust/system/cpp/detail/unique_by_key.h", + "cuda/include/thrust/system/cpp/detail/vector.inl", + "cuda/include/thrust/system/cpp/execution_policy.h", + "cuda/include/thrust/system/cpp/memory.h", + "cuda/include/thrust/system/cpp/vector.h", + "cuda/include/thrust/system/cuda/config.h", + "cuda/include/thrust/system/cuda/detail/adjacent_difference.h", + "cuda/include/thrust/system/cuda/detail/assign_value.h", + "cuda/include/thrust/system/cuda/detail/binary_search.h", + "cuda/include/thrust/system/cuda/detail/copy.h", + "cuda/include/thrust/system/cuda/detail/copy_if.h", + "cuda/include/thrust/system/cuda/detail/core/agent_launcher.h", + "cuda/include/thrust/system/cuda/detail/core/alignment.h", + "cuda/include/thrust/system/cuda/detail/core/triple_chevron_launch.h", + "cuda/include/thrust/system/cuda/detail/core/util.h", + "cuda/include/thrust/system/cuda/detail/count.h", + "cuda/include/thrust/system/cuda/detail/cross_system.h", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_csrt.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_row_based.cuh", + "cuda/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh", + "cuda/include/thrust/system/cuda/detail/cub/cub.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_csrt.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_row_based.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh", + "cuda/include/thrust/system/cuda/detail/cub/host/mutex.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_device.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_type.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh", + "cuda/include/thrust/system/cuda/detail/equal.h", + "cuda/include/thrust/system/cuda/detail/error.inl", + "cuda/include/thrust/system/cuda/detail/execution_policy.h", + "cuda/include/thrust/system/cuda/detail/extrema.h", + "cuda/include/thrust/system/cuda/detail/fill.h", + "cuda/include/thrust/system/cuda/detail/find.h", + "cuda/include/thrust/system/cuda/detail/for_each.h", + "cuda/include/thrust/system/cuda/detail/gather.h", + "cuda/include/thrust/system/cuda/detail/generate.h", + "cuda/include/thrust/system/cuda/detail/get_value.h", + "cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h", + "cuda/include/thrust/system/cuda/detail/guarded_driver_types.h", + "cuda/include/thrust/system/cuda/detail/inner_product.h", + "cuda/include/thrust/system/cuda/detail/internal/copy_cross_system.h", + "cuda/include/thrust/system/cuda/detail/internal/copy_device_to_device.h", + "cuda/include/thrust/system/cuda/detail/iter_swap.h", + "cuda/include/thrust/system/cuda/detail/logical.h", + "cuda/include/thrust/system/cuda/detail/malloc_and_free.h", + "cuda/include/thrust/system/cuda/detail/memory.inl", + "cuda/include/thrust/system/cuda/detail/memory_buffer.h", + "cuda/include/thrust/system/cuda/detail/merge.h", + "cuda/include/thrust/system/cuda/detail/mismatch.h", + "cuda/include/thrust/system/cuda/detail/par.h", + "cuda/include/thrust/system/cuda/detail/par_to_seq.h", + "cuda/include/thrust/system/cuda/detail/parallel_for.h", + "cuda/include/thrust/system/cuda/detail/partition.h", + "cuda/include/thrust/system/cuda/detail/reduce.h", + "cuda/include/thrust/system/cuda/detail/reduce_by_key.h", + "cuda/include/thrust/system/cuda/detail/remove.h", + "cuda/include/thrust/system/cuda/detail/replace.h", + "cuda/include/thrust/system/cuda/detail/reverse.h", + "cuda/include/thrust/system/cuda/detail/scan.h", + "cuda/include/thrust/system/cuda/detail/scan_by_key.h", + "cuda/include/thrust/system/cuda/detail/scatter.h", + "cuda/include/thrust/system/cuda/detail/sequence.h", + "cuda/include/thrust/system/cuda/detail/set_operations.h", + "cuda/include/thrust/system/cuda/detail/sort.h", + "cuda/include/thrust/system/cuda/detail/swap_ranges.h", + "cuda/include/thrust/system/cuda/detail/tabulate.h", + "cuda/include/thrust/system/cuda/detail/temporary_buffer.h", + "cuda/include/thrust/system/cuda/detail/terminate.h", + "cuda/include/thrust/system/cuda/detail/transform.h", + "cuda/include/thrust/system/cuda/detail/transform_reduce.h", + "cuda/include/thrust/system/cuda/detail/transform_scan.h", + "cuda/include/thrust/system/cuda/detail/uninitialized_copy.h", + "cuda/include/thrust/system/cuda/detail/uninitialized_fill.h", + "cuda/include/thrust/system/cuda/detail/unique.h", + "cuda/include/thrust/system/cuda/detail/unique_by_key.h", + "cuda/include/thrust/system/cuda/detail/util.h", + "cuda/include/thrust/system/cuda/detail/vector.inl", + "cuda/include/thrust/system/cuda/error.h", + "cuda/include/thrust/system/cuda/execution_policy.h", + "cuda/include/thrust/system/cuda/experimental/pinned_allocator.h", + "cuda/include/thrust/system/cuda/memory.h", + "cuda/include/thrust/system/cuda/vector.h", + "cuda/include/thrust/system/detail/adl/adjacent_difference.h", + "cuda/include/thrust/system/detail/adl/assign_value.h", + "cuda/include/thrust/system/detail/adl/binary_search.h", + "cuda/include/thrust/system/detail/adl/copy.h", + "cuda/include/thrust/system/detail/adl/copy_if.h", + "cuda/include/thrust/system/detail/adl/count.h", + "cuda/include/thrust/system/detail/adl/equal.h", + "cuda/include/thrust/system/detail/adl/extrema.h", + "cuda/include/thrust/system/detail/adl/fill.h", + "cuda/include/thrust/system/detail/adl/find.h", + "cuda/include/thrust/system/detail/adl/for_each.h", + "cuda/include/thrust/system/detail/adl/gather.h", + "cuda/include/thrust/system/detail/adl/generate.h", "cuda/include/thrust/system/detail/adl/get_value.h", "cuda/include/thrust/system/detail/adl/inner_product.h", - "cuda/include/thrust/system/detail/adl/copy_if.h", - "cuda/include/thrust/system/detail/adl/logical.h", "cuda/include/thrust/system/detail/adl/iter_swap.h", + "cuda/include/thrust/system/detail/adl/logical.h", "cuda/include/thrust/system/detail/adl/malloc_and_free.h", - "cuda/include/thrust/system/detail/adl/fill.h", + "cuda/include/thrust/system/detail/adl/merge.h", + "cuda/include/thrust/system/detail/adl/mismatch.h", + "cuda/include/thrust/system/detail/adl/partition.h", + "cuda/include/thrust/system/detail/adl/reduce.h", + "cuda/include/thrust/system/detail/adl/reduce_by_key.h", + "cuda/include/thrust/system/detail/adl/remove.h", + "cuda/include/thrust/system/detail/adl/replace.h", + "cuda/include/thrust/system/detail/adl/reverse.h", + "cuda/include/thrust/system/detail/adl/scan.h", + "cuda/include/thrust/system/detail/adl/scan_by_key.h", + "cuda/include/thrust/system/detail/adl/scatter.h", + "cuda/include/thrust/system/detail/adl/sequence.h", + "cuda/include/thrust/system/detail/adl/set_operations.h", + "cuda/include/thrust/system/detail/adl/sort.h", + "cuda/include/thrust/system/detail/adl/swap_ranges.h", + "cuda/include/thrust/system/detail/adl/tabulate.h", + "cuda/include/thrust/system/detail/adl/temporary_buffer.h", "cuda/include/thrust/system/detail/adl/transform.h", + "cuda/include/thrust/system/detail/adl/transform_reduce.h", + "cuda/include/thrust/system/detail/adl/transform_scan.h", + "cuda/include/thrust/system/detail/adl/uninitialized_copy.h", + "cuda/include/thrust/system/detail/adl/uninitialized_fill.h", + "cuda/include/thrust/system/detail/adl/unique.h", + "cuda/include/thrust/system/detail/adl/unique_by_key.h", + "cuda/include/thrust/system/detail/bad_alloc.h", "cuda/include/thrust/system/detail/errno.h", "cuda/include/thrust/system/detail/error_category.inl", - "cuda/include/thrust/system/detail/sequential/transform_scan.h", - "cuda/include/thrust/system/detail/sequential/unique_by_key.h", - "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h", - "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl", - "cuda/include/thrust/system/detail/sequential/stable_merge_sort.h", - "cuda/include/thrust/system/detail/sequential/sort.inl", - "cuda/include/thrust/system/detail/sequential/partition.h", - "cuda/include/thrust/system/detail/sequential/unique.h", - "cuda/include/thrust/system/detail/sequential/execution_policy.h", - "cuda/include/thrust/system/detail/sequential/adjacent_difference.h", - "cuda/include/thrust/system/detail/sequential/sequence.h", - "cuda/include/thrust/system/detail/sequential/merge.h", - "cuda/include/thrust/system/detail/sequential/transform_reduce.h", - "cuda/include/thrust/system/detail/sequential/gather.h", - "cuda/include/thrust/system/detail/sequential/sort.h", - "cuda/include/thrust/system/detail/sequential/copy_backward.h", - "cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl", - "cuda/include/thrust/system/detail/sequential/scan.h", - "cuda/include/thrust/system/detail/sequential/temporary_buffer.h", - "cuda/include/thrust/system/detail/sequential/scan_by_key.h", - "cuda/include/thrust/system/detail/sequential/reverse.h", - "cuda/include/thrust/system/detail/sequential/assign_value.h", - "cuda/include/thrust/system/detail/sequential/scatter.h", - "cuda/include/thrust/system/detail/sequential/find.h", - "cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl", - "cuda/include/thrust/system/detail/sequential/merge.inl", - "cuda/include/thrust/system/detail/sequential/generate.h", - "cuda/include/thrust/system/detail/sequential/uninitialized_fill.h", - "cuda/include/thrust/system/detail/sequential/general_copy.h", - "cuda/include/thrust/system/detail/sequential/insertion_sort.h", - "cuda/include/thrust/system/detail/sequential/remove.h", - "cuda/include/thrust/system/detail/sequential/tabulate.h", - "cuda/include/thrust/system/detail/sequential/for_each.h", - "cuda/include/thrust/system/detail/sequential/reduce_by_key.h", - "cuda/include/thrust/system/detail/sequential/reduce.h", - "cuda/include/thrust/system/detail/sequential/equal.h", - "cuda/include/thrust/system/detail/sequential/stable_radix_sort.h", - "cuda/include/thrust/system/detail/sequential/copy.inl", - "cuda/include/thrust/system/detail/sequential/copy.h", - "cuda/include/thrust/system/detail/sequential/swap_ranges.h", - "cuda/include/thrust/system/detail/sequential/uninitialized_copy.h", - "cuda/include/thrust/system/detail/sequential/binary_search.h", - "cuda/include/thrust/system/detail/sequential/set_operations.h", - "cuda/include/thrust/system/detail/sequential/mismatch.h", - "cuda/include/thrust/system/detail/sequential/extrema.h", - "cuda/include/thrust/system/detail/sequential/count.h", - "cuda/include/thrust/system/detail/sequential/trivial_copy.h", - "cuda/include/thrust/system/detail/sequential/replace.h", - "cuda/include/thrust/system/detail/sequential/get_value.h", - "cuda/include/thrust/system/detail/sequential/inner_product.h", - "cuda/include/thrust/system/detail/sequential/copy_if.h", - "cuda/include/thrust/system/detail/sequential/logical.h", - "cuda/include/thrust/system/detail/sequential/iter_swap.h", - "cuda/include/thrust/system/detail/sequential/malloc_and_free.h", - "cuda/include/thrust/system/detail/sequential/fill.h", - "cuda/include/thrust/system/detail/sequential/transform.h", - "cuda/include/thrust/system/detail/error_condition.inl", - "cuda/include/thrust/system/detail/internal/decompose.h", "cuda/include/thrust/system/detail/error_code.inl", - "cuda/include/thrust/system/detail/generic/transform_scan.h", - "cuda/include/thrust/system/detail/generic/memory.inl", - "cuda/include/thrust/system/detail/generic/transform.inl", - "cuda/include/thrust/system/detail/generic/binary_search.inl", - "cuda/include/thrust/system/detail/generic/scan_by_key.inl", - "cuda/include/thrust/system/detail/generic/unique_by_key.h", - "cuda/include/thrust/system/detail/generic/inner_product.inl", - "cuda/include/thrust/system/detail/generic/select_system.h", - "cuda/include/thrust/system/detail/generic/sequence.inl", - "cuda/include/thrust/system/detail/generic/sort.inl", - "cuda/include/thrust/system/detail/generic/equal.inl", - "cuda/include/thrust/system/detail/generic/partition.h", - "cuda/include/thrust/system/detail/generic/unique.h", + "cuda/include/thrust/system/detail/error_condition.inl", "cuda/include/thrust/system/detail/generic/adjacent_difference.h", - "cuda/include/thrust/system/detail/generic/tag.h", - "cuda/include/thrust/system/detail/generic/unique_by_key.inl", - "cuda/include/thrust/system/detail/generic/sequence.h", - "cuda/include/thrust/system/detail/generic/type_traits.h", - "cuda/include/thrust/system/detail/generic/merge.h", - "cuda/include/thrust/system/detail/generic/reverse.inl", - "cuda/include/thrust/system/detail/generic/tabulate.inl", - "cuda/include/thrust/system/detail/generic/unique.inl", - "cuda/include/thrust/system/detail/generic/scatter.inl", - "cuda/include/thrust/system/detail/generic/set_operations.inl", - "cuda/include/thrust/system/detail/generic/copy_if.inl", - "cuda/include/thrust/system/detail/generic/transform_reduce.h", - "cuda/include/thrust/system/detail/generic/transform_scan.inl", - "cuda/include/thrust/system/detail/generic/gather.h", - "cuda/include/thrust/system/detail/generic/reduce_by_key.inl", - "cuda/include/thrust/system/detail/generic/transform_reduce.inl", - "cuda/include/thrust/system/detail/generic/sort.h", - "cuda/include/thrust/system/detail/generic/distance.inl", - "cuda/include/thrust/system/detail/generic/scan.h", - "cuda/include/thrust/system/detail/generic/temporary_buffer.h", - "cuda/include/thrust/system/detail/generic/reduce.inl", - "cuda/include/thrust/system/detail/generic/scan_by_key.h", - "cuda/include/thrust/system/detail/generic/reverse.h", - "cuda/include/thrust/system/detail/generic/temporary_buffer.inl", - "cuda/include/thrust/system/detail/generic/scatter.h", - "cuda/include/thrust/system/detail/generic/generate.inl", "cuda/include/thrust/system/detail/generic/adjacent_difference.inl", - "cuda/include/thrust/system/detail/generic/remove.inl", "cuda/include/thrust/system/detail/generic/advance.h", - "cuda/include/thrust/system/detail/generic/find.h", - "cuda/include/thrust/system/detail/generic/merge.inl", - "cuda/include/thrust/system/detail/generic/scalar/binary_search.inl", - "cuda/include/thrust/system/detail/generic/scalar/binary_search.h", - "cuda/include/thrust/system/detail/generic/extrema.inl", - "cuda/include/thrust/system/detail/generic/generate.h", - "cuda/include/thrust/system/detail/generic/uninitialized_fill.h", + "cuda/include/thrust/system/detail/generic/advance.inl", + "cuda/include/thrust/system/detail/generic/binary_search.h", + "cuda/include/thrust/system/detail/generic/binary_search.inl", + "cuda/include/thrust/system/detail/generic/copy.h", + "cuda/include/thrust/system/detail/generic/copy.inl", + "cuda/include/thrust/system/detail/generic/copy_if.h", + "cuda/include/thrust/system/detail/generic/copy_if.inl", + "cuda/include/thrust/system/detail/generic/count.h", "cuda/include/thrust/system/detail/generic/count.inl", - "cuda/include/thrust/system/detail/generic/remove.h", - "cuda/include/thrust/system/detail/generic/uninitialized_copy.inl", - "cuda/include/thrust/system/detail/generic/tabulate.h", - "cuda/include/thrust/system/detail/generic/for_each.h", "cuda/include/thrust/system/detail/generic/distance.h", - "cuda/include/thrust/system/detail/generic/swap_ranges.inl", - "cuda/include/thrust/system/detail/generic/reduce_by_key.h", - "cuda/include/thrust/system/detail/generic/reduce.h", + "cuda/include/thrust/system/detail/generic/distance.inl", "cuda/include/thrust/system/detail/generic/equal.h", - "cuda/include/thrust/system/detail/generic/mismatch.inl", - "cuda/include/thrust/system/detail/generic/copy.inl", - "cuda/include/thrust/system/detail/generic/copy.h", - "cuda/include/thrust/system/detail/generic/swap_ranges.h", - "cuda/include/thrust/system/detail/generic/uninitialized_copy.h", - "cuda/include/thrust/system/detail/generic/binary_search.h", - "cuda/include/thrust/system/detail/generic/set_operations.h", - "cuda/include/thrust/system/detail/generic/uninitialized_fill.inl", - "cuda/include/thrust/system/detail/generic/mismatch.h", - "cuda/include/thrust/system/detail/generic/scan.inl", - "cuda/include/thrust/system/detail/generic/gather.inl", + "cuda/include/thrust/system/detail/generic/equal.inl", "cuda/include/thrust/system/detail/generic/extrema.h", - "cuda/include/thrust/system/detail/generic/count.h", - "cuda/include/thrust/system/detail/generic/replace.h", + "cuda/include/thrust/system/detail/generic/extrema.inl", + "cuda/include/thrust/system/detail/generic/fill.h", + "cuda/include/thrust/system/detail/generic/find.h", + "cuda/include/thrust/system/detail/generic/find.inl", + "cuda/include/thrust/system/detail/generic/for_each.h", + "cuda/include/thrust/system/detail/generic/gather.h", + "cuda/include/thrust/system/detail/generic/gather.inl", + "cuda/include/thrust/system/detail/generic/generate.h", + "cuda/include/thrust/system/detail/generic/generate.inl", "cuda/include/thrust/system/detail/generic/inner_product.h", - "cuda/include/thrust/system/detail/generic/copy_if.h", + "cuda/include/thrust/system/detail/generic/inner_product.inl", "cuda/include/thrust/system/detail/generic/logical.h", - "cuda/include/thrust/system/detail/generic/partition.inl", "cuda/include/thrust/system/detail/generic/memory.h", - "cuda/include/thrust/system/detail/generic/find.inl", + "cuda/include/thrust/system/detail/generic/memory.inl", + "cuda/include/thrust/system/detail/generic/merge.h", + "cuda/include/thrust/system/detail/generic/merge.inl", + "cuda/include/thrust/system/detail/generic/mismatch.h", + "cuda/include/thrust/system/detail/generic/mismatch.inl", + "cuda/include/thrust/system/detail/generic/partition.h", + "cuda/include/thrust/system/detail/generic/partition.inl", + "cuda/include/thrust/system/detail/generic/reduce.h", + "cuda/include/thrust/system/detail/generic/reduce.inl", + "cuda/include/thrust/system/detail/generic/reduce_by_key.h", + "cuda/include/thrust/system/detail/generic/reduce_by_key.inl", + "cuda/include/thrust/system/detail/generic/remove.h", + "cuda/include/thrust/system/detail/generic/remove.inl", + "cuda/include/thrust/system/detail/generic/replace.h", "cuda/include/thrust/system/detail/generic/replace.inl", - "cuda/include/thrust/system/detail/generic/advance.inl", - "cuda/include/thrust/system/detail/generic/fill.h", + "cuda/include/thrust/system/detail/generic/reverse.h", + "cuda/include/thrust/system/detail/generic/reverse.inl", + "cuda/include/thrust/system/detail/generic/scalar/binary_search.h", + "cuda/include/thrust/system/detail/generic/scalar/binary_search.inl", + "cuda/include/thrust/system/detail/generic/scan.h", + "cuda/include/thrust/system/detail/generic/scan.inl", + "cuda/include/thrust/system/detail/generic/scan_by_key.h", + "cuda/include/thrust/system/detail/generic/scan_by_key.inl", + "cuda/include/thrust/system/detail/generic/scatter.h", + "cuda/include/thrust/system/detail/generic/scatter.inl", + "cuda/include/thrust/system/detail/generic/select_system.h", + "cuda/include/thrust/system/detail/generic/sequence.h", + "cuda/include/thrust/system/detail/generic/sequence.inl", + "cuda/include/thrust/system/detail/generic/set_operations.h", + "cuda/include/thrust/system/detail/generic/set_operations.inl", + "cuda/include/thrust/system/detail/generic/sort.h", + "cuda/include/thrust/system/detail/generic/sort.inl", + "cuda/include/thrust/system/detail/generic/swap_ranges.h", + "cuda/include/thrust/system/detail/generic/swap_ranges.inl", + "cuda/include/thrust/system/detail/generic/tabulate.h", + "cuda/include/thrust/system/detail/generic/tabulate.inl", + "cuda/include/thrust/system/detail/generic/tag.h", + "cuda/include/thrust/system/detail/generic/temporary_buffer.h", + "cuda/include/thrust/system/detail/generic/temporary_buffer.inl", "cuda/include/thrust/system/detail/generic/transform.h", + "cuda/include/thrust/system/detail/generic/transform.inl", + "cuda/include/thrust/system/detail/generic/transform_reduce.h", + "cuda/include/thrust/system/detail/generic/transform_reduce.inl", + "cuda/include/thrust/system/detail/generic/transform_scan.h", + "cuda/include/thrust/system/detail/generic/transform_scan.inl", + "cuda/include/thrust/system/detail/generic/type_traits.h", + "cuda/include/thrust/system/detail/generic/uninitialized_copy.h", + "cuda/include/thrust/system/detail/generic/uninitialized_copy.inl", + "cuda/include/thrust/system/detail/generic/uninitialized_fill.h", + "cuda/include/thrust/system/detail/generic/uninitialized_fill.inl", + "cuda/include/thrust/system/detail/generic/unique.h", + "cuda/include/thrust/system/detail/generic/unique.inl", + "cuda/include/thrust/system/detail/generic/unique_by_key.h", + "cuda/include/thrust/system/detail/generic/unique_by_key.inl", + "cuda/include/thrust/system/detail/internal/decompose.h", + "cuda/include/thrust/system/detail/sequential/adjacent_difference.h", + "cuda/include/thrust/system/detail/sequential/assign_value.h", + "cuda/include/thrust/system/detail/sequential/binary_search.h", + "cuda/include/thrust/system/detail/sequential/copy.h", + "cuda/include/thrust/system/detail/sequential/copy.inl", + "cuda/include/thrust/system/detail/sequential/copy_backward.h", + "cuda/include/thrust/system/detail/sequential/copy_if.h", + "cuda/include/thrust/system/detail/sequential/count.h", + "cuda/include/thrust/system/detail/sequential/equal.h", + "cuda/include/thrust/system/detail/sequential/execution_policy.h", + "cuda/include/thrust/system/detail/sequential/extrema.h", + "cuda/include/thrust/system/detail/sequential/fill.h", + "cuda/include/thrust/system/detail/sequential/find.h", + "cuda/include/thrust/system/detail/sequential/for_each.h", + "cuda/include/thrust/system/detail/sequential/gather.h", + "cuda/include/thrust/system/detail/sequential/general_copy.h", + "cuda/include/thrust/system/detail/sequential/generate.h", + "cuda/include/thrust/system/detail/sequential/get_value.h", + "cuda/include/thrust/system/detail/sequential/inner_product.h", + "cuda/include/thrust/system/detail/sequential/insertion_sort.h", + "cuda/include/thrust/system/detail/sequential/iter_swap.h", + "cuda/include/thrust/system/detail/sequential/logical.h", + "cuda/include/thrust/system/detail/sequential/malloc_and_free.h", + "cuda/include/thrust/system/detail/sequential/merge.h", + "cuda/include/thrust/system/detail/sequential/merge.inl", + "cuda/include/thrust/system/detail/sequential/mismatch.h", + "cuda/include/thrust/system/detail/sequential/partition.h", + "cuda/include/thrust/system/detail/sequential/reduce.h", + "cuda/include/thrust/system/detail/sequential/reduce_by_key.h", + "cuda/include/thrust/system/detail/sequential/remove.h", + "cuda/include/thrust/system/detail/sequential/replace.h", + "cuda/include/thrust/system/detail/sequential/reverse.h", + "cuda/include/thrust/system/detail/sequential/scan.h", + "cuda/include/thrust/system/detail/sequential/scan_by_key.h", + "cuda/include/thrust/system/detail/sequential/scatter.h", + "cuda/include/thrust/system/detail/sequential/sequence.h", + "cuda/include/thrust/system/detail/sequential/set_operations.h", + "cuda/include/thrust/system/detail/sequential/sort.h", + "cuda/include/thrust/system/detail/sequential/sort.inl", + "cuda/include/thrust/system/detail/sequential/stable_merge_sort.h", + "cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl", + "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h", + "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl", + "cuda/include/thrust/system/detail/sequential/stable_radix_sort.h", + "cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl", + "cuda/include/thrust/system/detail/sequential/swap_ranges.h", + "cuda/include/thrust/system/detail/sequential/tabulate.h", + "cuda/include/thrust/system/detail/sequential/temporary_buffer.h", + "cuda/include/thrust/system/detail/sequential/transform.h", + "cuda/include/thrust/system/detail/sequential/transform_reduce.h", + "cuda/include/thrust/system/detail/sequential/transform_scan.h", + "cuda/include/thrust/system/detail/sequential/trivial_copy.h", + "cuda/include/thrust/system/detail/sequential/uninitialized_copy.h", + "cuda/include/thrust/system/detail/sequential/uninitialized_fill.h", + "cuda/include/thrust/system/detail/sequential/unique.h", + "cuda/include/thrust/system/detail/sequential/unique_by_key.h", "cuda/include/thrust/system/detail/system_error.inl", - "cuda/include/thrust/system/omp/execution_policy.h", - "cuda/include/thrust/system/omp/vector.h", - "cuda/include/thrust/system/omp/detail/transform_scan.h", - "cuda/include/thrust/system/omp/detail/memory.inl", - "cuda/include/thrust/system/omp/detail/reduce_intervals.inl", - "cuda/include/thrust/system/omp/detail/unique_by_key.h", - "cuda/include/thrust/system/omp/detail/sort.inl", - "cuda/include/thrust/system/omp/detail/partition.h", - "cuda/include/thrust/system/omp/detail/unique.h", - "cuda/include/thrust/system/omp/detail/execution_policy.h", + "cuda/include/thrust/system/error_code.h", "cuda/include/thrust/system/omp/detail/adjacent_difference.h", - "cuda/include/thrust/system/omp/detail/unique_by_key.inl", - "cuda/include/thrust/system/omp/detail/sequence.h", - "cuda/include/thrust/system/omp/detail/merge.h", - "cuda/include/thrust/system/omp/detail/unique.inl", + "cuda/include/thrust/system/omp/detail/assign_value.h", + "cuda/include/thrust/system/omp/detail/binary_search.h", + "cuda/include/thrust/system/omp/detail/copy.h", + "cuda/include/thrust/system/omp/detail/copy.inl", + "cuda/include/thrust/system/omp/detail/copy_if.h", "cuda/include/thrust/system/omp/detail/copy_if.inl", - "cuda/include/thrust/system/omp/detail/transform_reduce.h", - "cuda/include/thrust/system/omp/detail/gather.h", - "cuda/include/thrust/system/omp/detail/reduce_by_key.inl", - "cuda/include/thrust/system/omp/detail/sort.h", - "cuda/include/thrust/system/omp/detail/scan.h", - "cuda/include/thrust/system/omp/detail/temporary_buffer.h", + "cuda/include/thrust/system/omp/detail/count.h", "cuda/include/thrust/system/omp/detail/default_decomposition.h", - "cuda/include/thrust/system/omp/detail/reduce.inl", - "cuda/include/thrust/system/omp/detail/scan_by_key.h", - "cuda/include/thrust/system/omp/detail/reverse.h", - "cuda/include/thrust/system/omp/detail/assign_value.h", - "cuda/include/thrust/system/omp/detail/scatter.h", - "cuda/include/thrust/system/omp/detail/for_each.inl", "cuda/include/thrust/system/omp/detail/default_decomposition.inl", - "cuda/include/thrust/system/omp/detail/remove.inl", - "cuda/include/thrust/system/omp/detail/vector.inl", - "cuda/include/thrust/system/omp/detail/find.h", - "cuda/include/thrust/system/omp/detail/generate.h", - "cuda/include/thrust/system/omp/detail/uninitialized_fill.h", - "cuda/include/thrust/system/omp/detail/remove.h", - "cuda/include/thrust/system/omp/detail/tabulate.h", - "cuda/include/thrust/system/omp/detail/for_each.h", - "cuda/include/thrust/system/omp/detail/reduce_by_key.h", - "cuda/include/thrust/system/omp/detail/reduce.h", "cuda/include/thrust/system/omp/detail/equal.h", - "cuda/include/thrust/system/omp/detail/copy.inl", - "cuda/include/thrust/system/omp/detail/copy.h", - "cuda/include/thrust/system/omp/detail/swap_ranges.h", - "cuda/include/thrust/system/omp/detail/uninitialized_copy.h", - "cuda/include/thrust/system/omp/detail/binary_search.h", - "cuda/include/thrust/system/omp/detail/set_operations.h", - "cuda/include/thrust/system/omp/detail/mismatch.h", + "cuda/include/thrust/system/omp/detail/execution_policy.h", "cuda/include/thrust/system/omp/detail/extrema.h", - "cuda/include/thrust/system/omp/detail/count.h", - "cuda/include/thrust/system/omp/detail/replace.h", + "cuda/include/thrust/system/omp/detail/fill.h", + "cuda/include/thrust/system/omp/detail/find.h", + "cuda/include/thrust/system/omp/detail/for_each.h", + "cuda/include/thrust/system/omp/detail/for_each.inl", + "cuda/include/thrust/system/omp/detail/gather.h", + "cuda/include/thrust/system/omp/detail/generate.h", "cuda/include/thrust/system/omp/detail/get_value.h", "cuda/include/thrust/system/omp/detail/inner_product.h", - "cuda/include/thrust/system/omp/detail/copy_if.h", - "cuda/include/thrust/system/omp/detail/logical.h", - "cuda/include/thrust/system/omp/detail/partition.inl", "cuda/include/thrust/system/omp/detail/iter_swap.h", + "cuda/include/thrust/system/omp/detail/logical.h", + "cuda/include/thrust/system/omp/detail/malloc_and_free.h", + "cuda/include/thrust/system/omp/detail/memory.inl", + "cuda/include/thrust/system/omp/detail/merge.h", + "cuda/include/thrust/system/omp/detail/mismatch.h", "cuda/include/thrust/system/omp/detail/par.h", + "cuda/include/thrust/system/omp/detail/partition.h", + "cuda/include/thrust/system/omp/detail/partition.inl", + "cuda/include/thrust/system/omp/detail/reduce.h", + "cuda/include/thrust/system/omp/detail/reduce.inl", + "cuda/include/thrust/system/omp/detail/reduce_by_key.h", + "cuda/include/thrust/system/omp/detail/reduce_by_key.inl", "cuda/include/thrust/system/omp/detail/reduce_intervals.h", - "cuda/include/thrust/system/omp/detail/malloc_and_free.h", - "cuda/include/thrust/system/omp/detail/fill.h", + "cuda/include/thrust/system/omp/detail/reduce_intervals.inl", + "cuda/include/thrust/system/omp/detail/remove.h", + "cuda/include/thrust/system/omp/detail/remove.inl", + "cuda/include/thrust/system/omp/detail/replace.h", + "cuda/include/thrust/system/omp/detail/reverse.h", + "cuda/include/thrust/system/omp/detail/scan.h", + "cuda/include/thrust/system/omp/detail/scan_by_key.h", + "cuda/include/thrust/system/omp/detail/scatter.h", + "cuda/include/thrust/system/omp/detail/sequence.h", + "cuda/include/thrust/system/omp/detail/set_operations.h", + "cuda/include/thrust/system/omp/detail/sort.h", + "cuda/include/thrust/system/omp/detail/sort.inl", + "cuda/include/thrust/system/omp/detail/swap_ranges.h", + "cuda/include/thrust/system/omp/detail/tabulate.h", + "cuda/include/thrust/system/omp/detail/temporary_buffer.h", "cuda/include/thrust/system/omp/detail/transform.h", - "cuda/include/thrust/system/omp/memory.h", - "cuda/include/thrust/system/tbb/execution_policy.h", - "cuda/include/thrust/system/tbb/vector.h", - "cuda/include/thrust/system/tbb/detail/transform_scan.h", - "cuda/include/thrust/system/tbb/detail/memory.inl", - "cuda/include/thrust/system/tbb/detail/unique_by_key.h", - "cuda/include/thrust/system/tbb/detail/sort.inl", - "cuda/include/thrust/system/tbb/detail/partition.h", - "cuda/include/thrust/system/tbb/detail/unique.h", - "cuda/include/thrust/system/tbb/detail/execution_policy.h", + "cuda/include/thrust/system/omp/detail/transform_reduce.h", + "cuda/include/thrust/system/omp/detail/transform_scan.h", + "cuda/include/thrust/system/omp/detail/uninitialized_copy.h", + "cuda/include/thrust/system/omp/detail/uninitialized_fill.h", + "cuda/include/thrust/system/omp/detail/unique.h", + "cuda/include/thrust/system/omp/detail/unique.inl", + "cuda/include/thrust/system/omp/detail/unique_by_key.h", + "cuda/include/thrust/system/omp/detail/unique_by_key.inl", + "cuda/include/thrust/system/omp/detail/vector.inl", + "cuda/include/thrust/system/omp/execution_policy.h", + "cuda/include/thrust/system/omp/memory.h", + "cuda/include/thrust/system/omp/vector.h", + "cuda/include/thrust/system/system_error.h", "cuda/include/thrust/system/tbb/detail/adjacent_difference.h", - "cuda/include/thrust/system/tbb/detail/unique_by_key.inl", - "cuda/include/thrust/system/tbb/detail/sequence.h", - "cuda/include/thrust/system/tbb/detail/merge.h", - "cuda/include/thrust/system/tbb/detail/unique.inl", - "cuda/include/thrust/system/tbb/detail/copy_if.inl", - "cuda/include/thrust/system/tbb/detail/transform_reduce.h", - "cuda/include/thrust/system/tbb/detail/gather.h", - "cuda/include/thrust/system/tbb/detail/reduce_by_key.inl", - "cuda/include/thrust/system/tbb/detail/sort.h", - "cuda/include/thrust/system/tbb/detail/scan.h", - "cuda/include/thrust/system/tbb/detail/temporary_buffer.h", - "cuda/include/thrust/system/tbb/detail/reduce.inl", - "cuda/include/thrust/system/tbb/detail/scan_by_key.h", - "cuda/include/thrust/system/tbb/detail/reverse.h", "cuda/include/thrust/system/tbb/detail/assign_value.h", - "cuda/include/thrust/system/tbb/detail/scatter.h", - "cuda/include/thrust/system/tbb/detail/for_each.inl", - "cuda/include/thrust/system/tbb/detail/remove.inl", - "cuda/include/thrust/system/tbb/detail/vector.inl", - "cuda/include/thrust/system/tbb/detail/find.h", - "cuda/include/thrust/system/tbb/detail/merge.inl", - "cuda/include/thrust/system/tbb/detail/generate.h", - "cuda/include/thrust/system/tbb/detail/uninitialized_fill.h", - "cuda/include/thrust/system/tbb/detail/remove.h", - "cuda/include/thrust/system/tbb/detail/tabulate.h", - "cuda/include/thrust/system/tbb/detail/for_each.h", - "cuda/include/thrust/system/tbb/detail/reduce_by_key.h", - "cuda/include/thrust/system/tbb/detail/reduce.h", - "cuda/include/thrust/system/tbb/detail/equal.h", - "cuda/include/thrust/system/tbb/detail/copy.inl", - "cuda/include/thrust/system/tbb/detail/copy.h", - "cuda/include/thrust/system/tbb/detail/swap_ranges.h", - "cuda/include/thrust/system/tbb/detail/uninitialized_copy.h", "cuda/include/thrust/system/tbb/detail/binary_search.h", - "cuda/include/thrust/system/tbb/detail/set_operations.h", - "cuda/include/thrust/system/tbb/detail/mismatch.h", - "cuda/include/thrust/system/tbb/detail/scan.inl", - "cuda/include/thrust/system/tbb/detail/extrema.h", + "cuda/include/thrust/system/tbb/detail/copy.h", + "cuda/include/thrust/system/tbb/detail/copy.inl", + "cuda/include/thrust/system/tbb/detail/copy_if.h", + "cuda/include/thrust/system/tbb/detail/copy_if.inl", "cuda/include/thrust/system/tbb/detail/count.h", - "cuda/include/thrust/system/tbb/detail/replace.h", + "cuda/include/thrust/system/tbb/detail/equal.h", + "cuda/include/thrust/system/tbb/detail/execution_policy.h", + "cuda/include/thrust/system/tbb/detail/extrema.h", + "cuda/include/thrust/system/tbb/detail/fill.h", + "cuda/include/thrust/system/tbb/detail/find.h", + "cuda/include/thrust/system/tbb/detail/for_each.h", + "cuda/include/thrust/system/tbb/detail/for_each.inl", + "cuda/include/thrust/system/tbb/detail/gather.h", + "cuda/include/thrust/system/tbb/detail/generate.h", "cuda/include/thrust/system/tbb/detail/get_value.h", "cuda/include/thrust/system/tbb/detail/inner_product.h", - "cuda/include/thrust/system/tbb/detail/copy_if.h", - "cuda/include/thrust/system/tbb/detail/logical.h", - "cuda/include/thrust/system/tbb/detail/partition.inl", "cuda/include/thrust/system/tbb/detail/iter_swap.h", + "cuda/include/thrust/system/tbb/detail/logical.h", + "cuda/include/thrust/system/tbb/detail/malloc_and_free.h", + "cuda/include/thrust/system/tbb/detail/memory.inl", + "cuda/include/thrust/system/tbb/detail/merge.h", + "cuda/include/thrust/system/tbb/detail/merge.inl", + "cuda/include/thrust/system/tbb/detail/mismatch.h", "cuda/include/thrust/system/tbb/detail/par.h", + "cuda/include/thrust/system/tbb/detail/partition.h", + "cuda/include/thrust/system/tbb/detail/partition.inl", + "cuda/include/thrust/system/tbb/detail/reduce.h", + "cuda/include/thrust/system/tbb/detail/reduce.inl", + "cuda/include/thrust/system/tbb/detail/reduce_by_key.h", + "cuda/include/thrust/system/tbb/detail/reduce_by_key.inl", "cuda/include/thrust/system/tbb/detail/reduce_intervals.h", - "cuda/include/thrust/system/tbb/detail/malloc_and_free.h", - "cuda/include/thrust/system/tbb/detail/fill.h", + "cuda/include/thrust/system/tbb/detail/remove.h", + "cuda/include/thrust/system/tbb/detail/remove.inl", + "cuda/include/thrust/system/tbb/detail/replace.h", + "cuda/include/thrust/system/tbb/detail/reverse.h", + "cuda/include/thrust/system/tbb/detail/scan.h", + "cuda/include/thrust/system/tbb/detail/scan.inl", + "cuda/include/thrust/system/tbb/detail/scan_by_key.h", + "cuda/include/thrust/system/tbb/detail/scatter.h", + "cuda/include/thrust/system/tbb/detail/sequence.h", + "cuda/include/thrust/system/tbb/detail/set_operations.h", + "cuda/include/thrust/system/tbb/detail/sort.h", + "cuda/include/thrust/system/tbb/detail/sort.inl", + "cuda/include/thrust/system/tbb/detail/swap_ranges.h", + "cuda/include/thrust/system/tbb/detail/tabulate.h", + "cuda/include/thrust/system/tbb/detail/temporary_buffer.h", "cuda/include/thrust/system/tbb/detail/transform.h", - "cuda/include/thrust/system/tbb/memory.h", - "cuda/include/thrust/system/error_code.h", - "cuda/include/thrust/system/cpp/execution_policy.h", - "cuda/include/thrust/system/cpp/vector.h", - "cuda/include/thrust/system/cpp/detail/transform_scan.h", - "cuda/include/thrust/system/cpp/detail/memory.inl", - "cuda/include/thrust/system/cpp/detail/unique_by_key.h", - "cuda/include/thrust/system/cpp/detail/partition.h", - "cuda/include/thrust/system/cpp/detail/unique.h", - "cuda/include/thrust/system/cpp/detail/execution_policy.h", - "cuda/include/thrust/system/cpp/detail/adjacent_difference.h", - "cuda/include/thrust/system/cpp/detail/sequence.h", - "cuda/include/thrust/system/cpp/detail/merge.h", - "cuda/include/thrust/system/cpp/detail/transform_reduce.h", - "cuda/include/thrust/system/cpp/detail/gather.h", - "cuda/include/thrust/system/cpp/detail/sort.h", - "cuda/include/thrust/system/cpp/detail/scan.h", - "cuda/include/thrust/system/cpp/detail/temporary_buffer.h", - "cuda/include/thrust/system/cpp/detail/scan_by_key.h", - "cuda/include/thrust/system/cpp/detail/reverse.h", - "cuda/include/thrust/system/cpp/detail/assign_value.h", - "cuda/include/thrust/system/cpp/detail/scatter.h", - "cuda/include/thrust/system/cpp/detail/vector.inl", - "cuda/include/thrust/system/cpp/detail/find.h", - "cuda/include/thrust/system/cpp/detail/generate.h", - "cuda/include/thrust/system/cpp/detail/uninitialized_fill.h", - "cuda/include/thrust/system/cpp/detail/remove.h", - "cuda/include/thrust/system/cpp/detail/tabulate.h", - "cuda/include/thrust/system/cpp/detail/for_each.h", - "cuda/include/thrust/system/cpp/detail/reduce_by_key.h", - "cuda/include/thrust/system/cpp/detail/reduce.h", - "cuda/include/thrust/system/cpp/detail/equal.h", - "cuda/include/thrust/system/cpp/detail/copy.h", - "cuda/include/thrust/system/cpp/detail/swap_ranges.h", - "cuda/include/thrust/system/cpp/detail/uninitialized_copy.h", - "cuda/include/thrust/system/cpp/detail/binary_search.h", - "cuda/include/thrust/system/cpp/detail/set_operations.h", - "cuda/include/thrust/system/cpp/detail/mismatch.h", - "cuda/include/thrust/system/cpp/detail/extrema.h", - "cuda/include/thrust/system/cpp/detail/count.h", - "cuda/include/thrust/system/cpp/detail/replace.h", - "cuda/include/thrust/system/cpp/detail/get_value.h", - "cuda/include/thrust/system/cpp/detail/inner_product.h", - "cuda/include/thrust/system/cpp/detail/copy_if.h", - "cuda/include/thrust/system/cpp/detail/logical.h", - "cuda/include/thrust/system/cpp/detail/iter_swap.h", - "cuda/include/thrust/system/cpp/detail/par.h", - "cuda/include/thrust/system/cpp/detail/malloc_and_free.h", - "cuda/include/thrust/system/cpp/detail/fill.h", - "cuda/include/thrust/system/cpp/detail/transform.h", - "cuda/include/thrust/system/cpp/memory.h", - "cuda/include/thrust/system/cuda/execution_policy.h", - "cuda/include/thrust/system/cuda/vector.h", - "cuda/include/thrust/system/cuda/error.h", - "cuda/include/thrust/system/cuda/detail/copy_device_to_device.h", - "cuda/include/thrust/system/cuda/detail/transform_scan.h", - "cuda/include/thrust/system/cuda/detail/memory.inl", - "cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh", - "cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh", - "cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh", - "cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_device.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh", - "cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_type.cuh", - "cuda/include/thrust/system/cuda/detail/cub/host/spinlock.cuh", - "cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh", - "cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh", - "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh", - "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh", - "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh", - "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh", - "cuda/include/thrust/system/cuda/detail/cub/cub.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_shift.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh", - "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh", - "cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh", - "cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh", - "cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh", - "cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh", - "cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh", - "cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh", - "cuda/include/thrust/system/cuda/detail/reduce_intervals.inl", - "cuda/include/thrust/system/cuda/detail/copy_cross_system.inl", - "cuda/include/thrust/system/cuda/detail/unique_by_key.h", - "cuda/include/thrust/system/cuda/detail/bulk.h", - "cuda/include/thrust/system/cuda/detail/sort.inl", - "cuda/include/thrust/system/cuda/detail/partition.h", - "cuda/include/thrust/system/cuda/detail/unique.h", - "cuda/include/thrust/system/cuda/detail/execution_policy.h", - "cuda/include/thrust/system/cuda/detail/cuda_launch_config.h", - "cuda/include/thrust/system/cuda/detail/cub.h", - "cuda/include/thrust/system/cuda/detail/adjacent_difference.h", - "cuda/include/thrust/system/cuda/detail/sequence.h", - "cuda/include/thrust/system/cuda/detail/merge.h", - "cuda/include/thrust/system/cuda/detail/set_symmetric_difference.inl", - "cuda/include/thrust/system/cuda/detail/copy_if.inl", - "cuda/include/thrust/system/cuda/detail/transform_reduce.h", - "cuda/include/thrust/system/cuda/detail/error.inl", - "cuda/include/thrust/system/cuda/detail/gather.h", - "cuda/include/thrust/system/cuda/detail/reduce_by_key.inl", - "cuda/include/thrust/system/cuda/detail/sort.h", - "cuda/include/thrust/system/cuda/detail/synchronize.h", - "cuda/include/thrust/system/cuda/detail/scan.h", - "cuda/include/thrust/system/cuda/detail/temporary_indirect_permutation.h", - "cuda/include/thrust/system/cuda/detail/extern_shared_ptr.h", - "cuda/include/thrust/system/cuda/detail/detail/set_operation.inl", - "cuda/include/thrust/system/cuda/detail/detail/balanced_path.h", - "cuda/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h", - "cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h", - "cuda/include/thrust/system/cuda/detail/detail/set_operation.h", - "cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl", - "cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.h", - "cuda/include/thrust/system/cuda/detail/detail/launch_closure.inl", - "cuda/include/thrust/system/cuda/detail/detail/merge.h", - "cuda/include/thrust/system/cuda/detail/detail/alignment.h", - "cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl", - "cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.h", - "cuda/include/thrust/system/cuda/detail/detail/launch_calculator.inl", - "cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl", - "cuda/include/thrust/system/cuda/detail/detail/launch_closure.h", - "cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.h", - "cuda/include/thrust/system/cuda/detail/detail/uninitialized.h", - "cuda/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h", - "cuda/include/thrust/system/cuda/detail/detail/launch_calculator.h", - "cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.inl", - "cuda/include/thrust/system/cuda/detail/temporary_buffer.h", - "cuda/include/thrust/system/cuda/detail/default_decomposition.h", - "cuda/include/thrust/system/cuda/detail/reduce.inl", - "cuda/include/thrust/system/cuda/detail/scan_by_key.h", - "cuda/include/thrust/system/cuda/detail/reverse.h", - "cuda/include/thrust/system/cuda/detail/assign_value.h", - "cuda/include/thrust/system/cuda/detail/scatter.h", - "cuda/include/thrust/system/cuda/detail/reduce_intervals.hpp", - "cuda/include/thrust/system/cuda/detail/for_each.inl", - "cuda/include/thrust/system/cuda/detail/default_decomposition.inl", - "cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h", - "cuda/include/thrust/system/cuda/detail/adjacent_difference.inl", - "cuda/include/thrust/system/cuda/detail/vector.inl", - "cuda/include/thrust/system/cuda/detail/throw_on_error.h", - "cuda/include/thrust/system/cuda/detail/find.h", - "cuda/include/thrust/system/cuda/detail/terminate.h", - "cuda/include/thrust/system/cuda/detail/merge.inl", - "cuda/include/thrust/system/cuda/detail/trivial_copy.inl", - "cuda/include/thrust/system/cuda/detail/generate.h", - "cuda/include/thrust/system/cuda/detail/execute_on_stream.h", - "cuda/include/thrust/system/cuda/detail/uninitialized_fill.h", - "cuda/include/thrust/system/cuda/detail/remove.h", - "cuda/include/thrust/system/cuda/detail/tabulate.h", - "cuda/include/thrust/system/cuda/detail/for_each.h", - "cuda/include/thrust/system/cuda/detail/reduce_by_key.h", - "cuda/include/thrust/system/cuda/detail/decomposition.h", - "cuda/include/thrust/system/cuda/detail/reduce.h", - "cuda/include/thrust/system/cuda/detail/equal.h", - "cuda/include/thrust/system/cuda/detail/runtime_introspection.h", - "cuda/include/thrust/system/cuda/detail/copy.inl", - "cuda/include/thrust/system/cuda/detail/copy.h", - "cuda/include/thrust/system/cuda/detail/swap_ranges.h", - "cuda/include/thrust/system/cuda/detail/uninitialized_copy.h", - "cuda/include/thrust/system/cuda/detail/binary_search.h", - "cuda/include/thrust/system/cuda/detail/runtime_introspection.inl", - "cuda/include/thrust/system/cuda/detail/set_operations.h", - "cuda/include/thrust/system/cuda/detail/mismatch.h", - "cuda/include/thrust/system/cuda/detail/scan.inl", - "cuda/include/thrust/system/cuda/detail/synchronize.inl", - "cuda/include/thrust/system/cuda/detail/extrema.h", - "cuda/include/thrust/system/cuda/detail/set_union.inl", - "cuda/include/thrust/system/cuda/detail/set_intersection.inl", - "cuda/include/thrust/system/cuda/detail/count.h", - "cuda/include/thrust/system/cuda/detail/trivial_copy.h", - "cuda/include/thrust/system/cuda/detail/copy_device_to_device.inl", - "cuda/include/thrust/system/cuda/detail/replace.h", - "cuda/include/thrust/system/cuda/detail/bulk/malloc.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/config.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/closure.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl", - "cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/async.inl", - "cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/iterator.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/bulk.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/execution_policy.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/uninitialized.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/async.hpp", - "cuda/include/thrust/system/cuda/detail/bulk/future.hpp", - "cuda/include/thrust/system/cuda/detail/guarded_driver_types.h", - "cuda/include/thrust/system/cuda/detail/get_value.h", - "cuda/include/thrust/system/cuda/detail/inner_product.h", - "cuda/include/thrust/system/cuda/detail/copy_if.h", - "cuda/include/thrust/system/cuda/detail/logical.h", - "cuda/include/thrust/system/cuda/detail/iter_swap.h", - "cuda/include/thrust/system/cuda/detail/block/merge.h", - "cuda/include/thrust/system/cuda/detail/block/inclusive_scan.h", - "cuda/include/thrust/system/cuda/detail/block/merge.inl", - "cuda/include/thrust/system/cuda/detail/block/merging_sort.h", - "cuda/include/thrust/system/cuda/detail/block/exclusive_scan.h", - "cuda/include/thrust/system/cuda/detail/block/reduce.h", - "cuda/include/thrust/system/cuda/detail/block/copy.h", - "cuda/include/thrust/system/cuda/detail/block/odd_even_sort.h", - "cuda/include/thrust/system/cuda/detail/par.h", - "cuda/include/thrust/system/cuda/detail/copy_cross_system.h", - "cuda/include/thrust/system/cuda/detail/reduce_intervals.h", - "cuda/include/thrust/system/cuda/detail/malloc_and_free.h", - "cuda/include/thrust/system/cuda/detail/fill.h", - "cuda/include/thrust/system/cuda/detail/set_difference.inl", - "cuda/include/thrust/system/cuda/detail/transform.h", - "cuda/include/thrust/system/cuda/experimental/pinned_allocator.h", - "cuda/include/thrust/system/cuda/memory.h", - "cuda/include/thrust/remove.h", + "cuda/include/thrust/system/tbb/detail/transform_reduce.h", + "cuda/include/thrust/system/tbb/detail/transform_scan.h", + "cuda/include/thrust/system/tbb/detail/uninitialized_copy.h", + "cuda/include/thrust/system/tbb/detail/uninitialized_fill.h", + "cuda/include/thrust/system/tbb/detail/unique.h", + "cuda/include/thrust/system/tbb/detail/unique.inl", + "cuda/include/thrust/system/tbb/detail/unique_by_key.h", + "cuda/include/thrust/system/tbb/detail/unique_by_key.inl", + "cuda/include/thrust/system/tbb/detail/vector.inl", + "cuda/include/thrust/system/tbb/execution_policy.h", + "cuda/include/thrust/system/tbb/memory.h", + "cuda/include/thrust/system/tbb/vector.h", + "cuda/include/thrust/system_error.h", "cuda/include/thrust/tabulate.h", - "cuda/include/thrust/for_each.h", - "cuda/include/thrust/distance.h", - "cuda/include/thrust/reduce.h", - "cuda/include/thrust/equal.h", - "cuda/include/thrust/complex.h", - "cuda/include/thrust/device_allocator.h", - "cuda/include/thrust/copy.h", + "cuda/include/thrust/transform.h", + "cuda/include/thrust/transform_reduce.h", + "cuda/include/thrust/transform_scan.h", + "cuda/include/thrust/tuple.h", "cuda/include/thrust/uninitialized_copy.h", - "cuda/include/thrust/device_reference.h", - "cuda/include/thrust/binary_search.h", - "cuda/include/thrust/set_operations.h", - "cuda/include/thrust/swap.h", - "cuda/include/thrust/mismatch.h", - "cuda/include/thrust/extrema.h", - "cuda/include/thrust/count.h", - "cuda/include/thrust/device_free.h", - "cuda/include/thrust/random/discard_block_engine.h", - "cuda/include/thrust/random/normal_distribution.h", - "cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h", - "cuda/include/thrust/random/detail/subtract_with_carry_engine.inl", - "cuda/include/thrust/random/detail/xor_combine_engine_max.h", - "cuda/include/thrust/random/detail/linear_congruential_engine_discard.h", - "cuda/include/thrust/random/detail/uniform_int_distribution.inl", - "cuda/include/thrust/random/detail/discard_block_engine.inl", - "cuda/include/thrust/random/detail/uniform_real_distribution.inl", - "cuda/include/thrust/random/detail/random_core_access.h", - "cuda/include/thrust/random/detail/mod.h", - "cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl", - "cuda/include/thrust/random/detail/linear_congruential_engine.inl", - "cuda/include/thrust/random/detail/xor_combine_engine.inl", - "cuda/include/thrust/random/detail/normal_distribution.inl", - "cuda/include/thrust/random/detail/normal_distribution_base.h", - "cuda/include/thrust/random/uniform_int_distribution.h", - "cuda/include/thrust/random/linear_feedback_shift_engine.h", - "cuda/include/thrust/random/xor_combine_engine.h", - "cuda/include/thrust/random/subtract_with_carry_engine.h", - "cuda/include/thrust/random/linear_congruential_engine.h", - "cuda/include/thrust/random/uniform_real_distribution.h", - "cuda/include/thrust/functional.h", - "cuda/include/thrust/replace.h", - "cuda/include/thrust/device_new_allocator.h", - "cuda/include/thrust/host_vector.h", + "cuda/include/thrust/uninitialized_fill.h", + "cuda/include/thrust/unique.h", "cuda/include/thrust/version.h", - "cuda/include/thrust/inner_product.h", - "cuda/include/thrust/iterator/iterator_traits.h", - "cuda/include/thrust/iterator/discard_iterator.h", - "cuda/include/thrust/iterator/retag.h", - "cuda/include/thrust/iterator/permutation_iterator.h", - "cuda/include/thrust/iterator/transform_iterator.h", - "cuda/include/thrust/iterator/detail/reverse_iterator.inl", - "cuda/include/thrust/iterator/detail/zip_iterator.inl", - "cuda/include/thrust/iterator/detail/counting_iterator.inl", - "cuda/include/thrust/iterator/detail/distance_from_result.h", - "cuda/include/thrust/iterator/detail/host_system_tag.h", - "cuda/include/thrust/iterator/detail/iterator_traversal_tags.h", - "cuda/include/thrust/iterator/detail/retag.h", - "cuda/include/thrust/iterator/detail/tagged_iterator.h", - "cuda/include/thrust/iterator/detail/iterator_traits.inl", - "cuda/include/thrust/iterator/detail/minimum_category.h", - "cuda/include/thrust/iterator/detail/discard_iterator_base.h", - "cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h", - "cuda/include/thrust/iterator/detail/zip_iterator_base.h", - "cuda/include/thrust/iterator/detail/normal_iterator.h", - "cuda/include/thrust/iterator/detail/join_iterator.h", - "cuda/include/thrust/iterator/detail/device_system_tag.h", - "cuda/include/thrust/iterator/detail/universal_categories.h", - "cuda/include/thrust/iterator/detail/reverse_iterator_base.h", - "cuda/include/thrust/iterator/detail/minimum_system.h", - "cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h", - "cuda/include/thrust/iterator/detail/is_iterator_category.h", - "cuda/include/thrust/iterator/detail/permutation_iterator_base.h", - "cuda/include/thrust/iterator/detail/any_assign.h", - "cuda/include/thrust/iterator/detail/any_system_tag.h", - "cuda/include/thrust/iterator/detail/is_trivial_iterator.h", - "cuda/include/thrust/iterator/detail/iterator_category_to_system.h", - "cuda/include/thrust/iterator/detail/iterator_adaptor_base.h", - "cuda/include/thrust/iterator/detail/constant_iterator_base.h", - "cuda/include/thrust/iterator/detail/transform_iterator.inl", - "cuda/include/thrust/iterator/detail/iterator_facade_category.h", - "cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h", - "cuda/include/thrust/iterator/constant_iterator.h", - "cuda/include/thrust/iterator/counting_iterator.h", - "cuda/include/thrust/iterator/iterator_adaptor.h", - "cuda/include/thrust/iterator/iterator_facade.h", - "cuda/include/thrust/iterator/iterator_categories.h", - "cuda/include/thrust/iterator/reverse_iterator.h", - "cuda/include/thrust/iterator/zip_iterator.h", - "cuda/include/thrust/logical.h", - "cuda/include/thrust/tuple.h", - "cuda/include/thrust/memory.h", - "cuda/include/thrust/random.h", - "cuda/include/thrust/fill.h", - "cuda/include/thrust/transform.h", - "cuda/include/texture_types.h", - "cuda/include/nppversion.h", - "cuda/include/cuda_texture_types.h", - "cuda/include/fatbinary.h", - "cuda/include/cublasXt.h", - "cuda/include/cuda_fp16.h", "cuda/include/vector_functions.h", - "cuda/include/cusparse.h", - "cuda/include/nppi_filtering_functions.h", - "cuda/include/nppi_morphological_operations.h", - "cuda/include/sobol_direction_vectors.h", - "cuda/include/nvblas.h", - "cuda/include/curand_mtgp32dc_p_11213.h", - "cuda/include/nvcuvid.h", - "cuda/include/cuda_runtime_api.h", - "cuda/include/curand_mtgp32_kernel.h", - "cuda/include/cublas_v2.h", - "cuda/include/builtin_types.h", - "cuda/include/nppi_geometry_transforms.h", - "cuda/include/npps_support_functions.h", - "cuda/include/cufftw.h", - "cuda/include/cuda_device_runtime_api.h", - "cuda/include/sm_30_intrinsics.hpp", + "cuda/include/vector_functions.hpp", "cuda/include/vector_types.h", - "cuda/include/sm_35_atomic_functions.h", - "cuda/include/sm_20_intrinsics.h", - "cuda/include/driver_types.h", - "cuda/include/nvToolsExtCudaRt.h", - "cuda/include/curand_globals.h", - "cuda/include/device_atomic_functions.h", - "cuda/include/surface_types.h", - "cuda/include/nvrtc.h", - "cuda/include/nppdefs.h", - "cuda/include/sm_60_atomic_functions.h", - "cuda/include/driver_functions.h", - "cuda/include/cusolver_common.h", - "cuda/include/cublas.h", - "cuda/include/curand_lognormal.h", - "cuda/include/device_atomic_functions.hpp", - "cuda/include/crt/device_runtime.h", - "cuda/include/crt/storage_class.h", - "cuda/include/crt/func_macro.h", - "cuda/include/crt/host_runtime.h", - "cuda/include/nppi_arithmetic_and_logical_operations.h", - "cuda/include/npps_arithmetic_and_logical_operations.h", - "cuda/include/nppi_computer_vision.h", - "cuda/include/surface_functions.hpp", - "cuda/include/surface_functions.h", - "cuda/include/curand_normal_static.h", - "cuda/include/curand.h", - "cuda/include/math_functions_dbl_ptx3.h", - "cuda/include/curand_philox4x32_x.h", - "cuda/include/nppi_threshold_and_compare_operations.h", - "cuda/include/nvml.h", - "cuda/include/npps.h", - "cuda/include/cuda_vdpau_interop.h", - "cuda/include/sm_61_intrinsics.hpp", - "cuda/include/cublas_api.h", - "cuda/include/nppi_color_conversion.h", - "cuda/include/math_functions_dbl_ptx3.hpp", - "cuda/include/nppcore.h", - "cuda/include/cudaGL.h", - "cuda/include/fatBinaryCtl.h", - "cuda/include/npps_statistics_functions.h", - "cuda/include/cudaVDPAU.h", - "cuda/include/curand_poisson.h", - "cuda/include/cusolverDn.h", - "cuda/include/cuda_profiler_api.h", - "cuda/include/sm_20_atomic_functions.h", - "cuda/include/nvfunctional", ], cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/include/math_functions.hpp" "$(@D)/cuda/include/math_functions.hpp" && cp "/usr/local/cuda-8.0/include/cufft.h" "$(@D)/cuda/include/cufft.h" && cp "/usr/local/cuda-8.0/include/nvgraph.h" "$(@D)/cuda/include/nvgraph.h" && cp "/usr/local/cuda-8.0/include/curand_normal.h" "$(@D)/cuda/include/curand_normal.h" && cp "/usr/local/cuda-8.0/include/curand_uniform.h" "$(@D)/cuda/include/curand_uniform.h" && cp "/usr/local/cuda-8.0/include/nppi_data_exchange_and_initialization.h" "$(@D)/cuda/include/nppi_data_exchange_and_initialization.h" && cp "/usr/local/cuda-8.0/include/cuda_gl_interop.h" "$(@D)/cuda/include/cuda_gl_interop.h" && cp "/usr/local/cuda-8.0/include/nppi_compression_functions.h" "$(@D)/cuda/include/nppi_compression_functions.h" && cp "/usr/local/cuda-8.0/include/npp.h" "$(@D)/cuda/include/npp.h" && cp "/usr/local/cuda-8.0/include/cuda.h" "$(@D)/cuda/include/cuda.h" && cp "/usr/local/cuda-8.0/include/nppi_statistics_functions.h" "$(@D)/cuda/include/nppi_statistics_functions.h" && cp "/usr/local/cuda-8.0/include/vector_functions.hpp" "$(@D)/cuda/include/vector_functions.hpp" && cp "/usr/local/cuda-8.0/include/sm_32_intrinsics.hpp" "$(@D)/cuda/include/sm_32_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/sm_32_intrinsics.h" "$(@D)/cuda/include/sm_32_intrinsics.h" && cp "/usr/local/cuda-8.0/include/curand_discrete.h" "$(@D)/cuda/include/curand_discrete.h" && cp "/usr/local/cuda-8.0/include/cuda_runtime.h" "$(@D)/cuda/include/cuda_runtime.h" && cp "/usr/local/cuda-8.0/include/cufftXt.h" "$(@D)/cuda/include/cufftXt.h" && cp "/usr/local/cuda-8.0/include/sm_61_intrinsics.h" "$(@D)/cuda/include/sm_61_intrinsics.h" && cp "/usr/local/cuda-8.0/include/texture_fetch_functions.h" "$(@D)/cuda/include/texture_fetch_functions.h" && cp "/usr/local/cuda-8.0/include/curand_mrg32k3a.h" "$(@D)/cuda/include/curand_mrg32k3a.h" && cp "/usr/local/cuda-8.0/include/host_defines.h" "$(@D)/cuda/include/host_defines.h" && cp "/usr/local/cuda-8.0/include/common_functions.h" "$(@D)/cuda/include/common_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_support_functions.h" "$(@D)/cuda/include/nppi_support_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_linear_transforms.h" "$(@D)/cuda/include/nppi_linear_transforms.h" && cp "/usr/local/cuda-8.0/include/device_double_functions.hpp" "$(@D)/cuda/include/device_double_functions.hpp" && cp "/usr/local/cuda-8.0/include/math_constants.h" "$(@D)/cuda/include/math_constants.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtSync.h" "$(@D)/cuda/include/nvToolsExtSync.h" && cp "/usr/local/cuda-8.0/include/npps_initialization.h" "$(@D)/cuda/include/npps_initialization.h" && cp "/usr/local/cuda-8.0/include/cusolverSp_LOWLEVEL_PREVIEW.h" "$(@D)/cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h" && cp "/usr/local/cuda-8.0/include/texture_indirect_functions.hpp" "$(@D)/cuda/include/texture_indirect_functions.hpp" && cp "/usr/local/cuda-8.0/include/cudaProfiler.h" "$(@D)/cuda/include/cudaProfiler.h" && cp "/usr/local/cuda-8.0/include/npps_filtering_functions.h" "$(@D)/cuda/include/npps_filtering_functions.h" && cp "/usr/local/cuda-8.0/include/cusparse_v2.h" "$(@D)/cuda/include/cusparse_v2.h" && cp "/usr/local/cuda-8.0/include/nppi.h" "$(@D)/cuda/include/nppi.h" && cp "/usr/local/cuda-8.0/include/surface_indirect_functions.h" "$(@D)/cuda/include/surface_indirect_functions.h" && cp "/usr/local/cuda-8.0/include/sm_30_intrinsics.h" "$(@D)/cuda/include/sm_30_intrinsics.h" && cp "/usr/local/cuda-8.0/include/device_double_functions.h" "$(@D)/cuda/include/device_double_functions.h" && cp "/usr/local/cuda-8.0/include/sm_35_intrinsics.h" "$(@D)/cuda/include/sm_35_intrinsics.h" && cp "/usr/local/cuda-8.0/include/cusolverSp.h" "$(@D)/cuda/include/cusolverSp.h" && cp "/usr/local/cuda-8.0/include/library_types.h" "$(@D)/cuda/include/library_types.h" && cp "/usr/local/cuda-8.0/include/surface_indirect_functions.hpp" "$(@D)/cuda/include/surface_indirect_functions.hpp" && cp "/usr/local/cuda-8.0/include/cudalibxt.h" "$(@D)/cuda/include/cudalibxt.h" && cp "/usr/local/cuda-8.0/include/channel_descriptor.h" "$(@D)/cuda/include/channel_descriptor.h" && cp "/usr/local/cuda-8.0/include/device_functions_decls.h" "$(@D)/cuda/include/device_functions_decls.h" && cp "/usr/local/cuda-8.0/include/curand_kernel.h" "$(@D)/cuda/include/curand_kernel.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32_host.h" "$(@D)/cuda/include/curand_mtgp32_host.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtCuda.h" "$(@D)/cuda/include/nvToolsExtCuda.h" && cp "/usr/local/cuda-8.0/include/nvToolsExt.h" "$(@D)/cuda/include/nvToolsExt.h" && cp "/usr/local/cuda-8.0/include/cuComplex.h" "$(@D)/cuda/include/cuComplex.h" && cp "/usr/local/cuda-8.0/include/sm_32_atomic_functions.h" "$(@D)/cuda/include/sm_32_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/texture_indirect_functions.h" "$(@D)/cuda/include/texture_indirect_functions.h" && cp "/usr/local/cuda-8.0/include/sm_32_atomic_functions.hpp" "$(@D)/cuda/include/sm_32_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/sm_20_intrinsics.hpp" "$(@D)/cuda/include/sm_20_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/device_launch_parameters.h" "$(@D)/cuda/include/device_launch_parameters.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32.h" "$(@D)/cuda/include/curand_mtgp32.h" && cp "/usr/local/cuda-8.0/include/texture_fetch_functions.hpp" "$(@D)/cuda/include/texture_fetch_functions.hpp" && cp "/usr/local/cuda-8.0/include/cuda_occupancy.h" "$(@D)/cuda/include/cuda_occupancy.h" && cp "/usr/local/cuda-8.0/include/CL/opencl.h" "$(@D)/cuda/include/CL/opencl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_platform.h" "$(@D)/cuda/include/CL/cl_platform.h" && cp "/usr/local/cuda-8.0/include/CL/cl_egl.h" "$(@D)/cuda/include/CL/cl_egl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_gl.h" "$(@D)/cuda/include/CL/cl_gl.h" && cp "/usr/local/cuda-8.0/include/CL/cl.h" "$(@D)/cuda/include/CL/cl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_gl_ext.h" "$(@D)/cuda/include/CL/cl_gl_ext.h" && cp "/usr/local/cuda-8.0/include/CL/cl_ext.h" "$(@D)/cuda/include/CL/cl_ext.h" && cp "/usr/local/cuda-8.0/include/CL/cl.hpp" "$(@D)/cuda/include/CL/cl.hpp" && cp "/usr/local/cuda-8.0/include/host_config.h" "$(@D)/cuda/include/host_config.h" && cp "/usr/local/cuda-8.0/include/cuda_surface_types.h" "$(@D)/cuda/include/cuda_surface_types.h" && cp "/usr/local/cuda-8.0/include/math_functions.h" "$(@D)/cuda/include/math_functions.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtMeta.h" "$(@D)/cuda/include/nvToolsExtMeta.h" && cp "/usr/local/cuda-8.0/include/sm_20_atomic_functions.hpp" "$(@D)/cuda/include/sm_20_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/device_functions.h" "$(@D)/cuda/include/device_functions.h" && cp "/usr/local/cuda-8.0/include/device_types.h" "$(@D)/cuda/include/device_types.h" && cp "/usr/local/cuda-8.0/include/npps_conversion_functions.h" "$(@D)/cuda/include/npps_conversion_functions.h" && cp "/usr/local/cuda-8.0/include/curand_precalc.h" "$(@D)/cuda/include/curand_precalc.h" && cp "/usr/local/cuda-8.0/include/cusolverRf.h" "$(@D)/cuda/include/cusolverRf.h" && cp "/usr/local/cuda-8.0/include/sm_60_atomic_functions.hpp" "$(@D)/cuda/include/sm_60_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/cuviddec.h" "$(@D)/cuda/include/cuviddec.h" && cp "/usr/local/cuda-8.0/include/curand_discrete2.h" "$(@D)/cuda/include/curand_discrete2.h" && cp "/usr/local/cuda-8.0/include/device_functions.hpp" "$(@D)/cuda/include/device_functions.hpp" && cp "/usr/local/cuda-8.0/include/thrust/transform_scan.h" "$(@D)/cuda/include/thrust/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system_error.h" "$(@D)/cuda/include/thrust/system_error.h" && cp "/usr/local/cuda-8.0/include/thrust/device_malloc.h" "$(@D)/cuda/include/thrust/device_malloc.h" && cp "/usr/local/cuda-8.0/include/thrust/partition.h" "$(@D)/cuda/include/thrust/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/unique.h" "$(@D)/cuda/include/thrust/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/device_delete.h" "$(@D)/cuda/include/thrust/device_delete.h" && cp "/usr/local/cuda-8.0/include/thrust/execution_policy.h" "$(@D)/cuda/include/thrust/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/adjacent_difference.h" "$(@D)/cuda/include/thrust/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/sequence.h" "$(@D)/cuda/include/thrust/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/merge.h" "$(@D)/cuda/include/thrust/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/device_new.h" "$(@D)/cuda/include/thrust/device_new.h" && cp "/usr/local/cuda-8.0/include/thrust/transform_reduce.h" "$(@D)/cuda/include/thrust/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/device_vector.h" "$(@D)/cuda/include/thrust/device_vector.h" && cp "/usr/local/cuda-8.0/include/thrust/gather.h" "$(@D)/cuda/include/thrust/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/sort.h" "$(@D)/cuda/include/thrust/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/scan.h" "$(@D)/cuda/include/thrust/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_array.h" "$(@D)/cuda/include/thrust/detail/temporary_array.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/util/align.h" "$(@D)/cuda/include/thrust/detail/util/align.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/util/blocking.h" "$(@D)/cuda/include/thrust/detail/util/blocking.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform.inl" "$(@D)/cuda/include/thrust/detail/transform.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_vector.inl" "$(@D)/cuda/include/thrust/detail/device_vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/binary_search.inl" "$(@D)/cuda/include/thrust/detail/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/overlapped_copy.h" "$(@D)/cuda/include/thrust/detail/overlapped_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/vector_base.inl" "$(@D)/cuda/include/thrust/detail/vector_base.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_reference.inl" "$(@D)/cuda/include/thrust/detail/device_reference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/actor.h" "$(@D)/cuda/include/thrust/detail/functional/actor.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/value.h" "$(@D)/cuda/include/thrust/detail/functional/value.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/logical_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/logical_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/relational_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/relational_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/assignment_operator.h" "$(@D)/cuda/include/thrust/detail/functional/operators/assignment_operator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/bitwise_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/bitwise_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/operator_adaptors.h" "$(@D)/cuda/include/thrust/detail/functional/operators/operator_adaptors.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/arithmetic_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/arithmetic_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/compound_assignment_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/argument.h" "$(@D)/cuda/include/thrust/detail/functional/argument.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/placeholder.h" "$(@D)/cuda/include/thrust/detail/functional/placeholder.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/actor.inl" "$(@D)/cuda/include/thrust/detail/functional/actor.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/composite.h" "$(@D)/cuda/include/thrust/detail/functional/composite.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/static_map.h" "$(@D)/cuda/include/thrust/detail/static_map.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_nested_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_nested_type.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/is_call_possible.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_call_possible.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/function_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/function_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/pointer_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/pointer_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_member_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_member_function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" "$(@D)/cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/minimum_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/minimum_type.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_trivial_assign.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_trivial_assign.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/is_metafunction_defined.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_metafunction_defined.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/iterator/is_output_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/result_of_adaptable_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference.h" "$(@D)/cuda/include/thrust/detail/reference.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/inner_product.inl" "$(@D)/cuda/include/thrust/detail/inner_product.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/use_default.h" "$(@D)/cuda/include/thrust/detail/use_default.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/sequence.inl" "$(@D)/cuda/include/thrust/detail/sequence.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/sort.inl" "$(@D)/cuda/include/thrust/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/equal.inl" "$(@D)/cuda/include/thrust/detail/equal.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/execution_policy.h" "$(@D)/cuda/include/thrust/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/integer_traits.h" "$(@D)/cuda/include/thrust/detail/integer_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reverse.inl" "$(@D)/cuda/include/thrust/detail/reverse.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tabulate.inl" "$(@D)/cuda/include/thrust/detail/tabulate.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/unique.inl" "$(@D)/cuda/include/thrust/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/scatter.inl" "$(@D)/cuda/include/thrust/detail/scatter.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/set_operations.inl" "$(@D)/cuda/include/thrust/detail/set_operations.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_malloc.inl" "$(@D)/cuda/include/thrust/detail/device_malloc.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy_if.inl" "$(@D)/cuda/include/thrust/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/fill.inl" "$(@D)/cuda/include/thrust/detail/fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_array.inl" "$(@D)/cuda/include/thrust/detail/temporary_array.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform_scan.inl" "$(@D)/cuda/include/thrust/detail/transform_scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/minmax.h" "$(@D)/cuda/include/thrust/detail/minmax.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap.inl" "$(@D)/cuda/include/thrust/detail/swap.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pointer.inl" "$(@D)/cuda/include/thrust/detail/pointer.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform_reduce.inl" "$(@D)/cuda/include/thrust/detail/transform_reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/config.h" "$(@D)/cuda/include/thrust/detail/config.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/distance.inl" "$(@D)/cuda/include/thrust/detail/distance.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pair.inl" "$(@D)/cuda/include/thrust/detail/pair.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/temporary_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/tagged_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/destroy_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/destroy_range.h" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/no_throw_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/no_throw_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/default_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/fill_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/tagged_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/malloc_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/allocator_traits.h" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/copy_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/allocator_traits.inl" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/default_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/copy_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/malloc_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/temporary_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/fill_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reduce.inl" "$(@D)/cuda/include/thrust/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_new.inl" "$(@D)/cuda/include/thrust/detail/device_new.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pointer.h" "$(@D)/cuda/include/thrust/detail/pointer.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/for_each.inl" "$(@D)/cuda/include/thrust/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/generate.inl" "$(@D)/cuda/include/thrust/detail/generate.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/dispatch/is_trivial_copy.h" "$(@D)/cuda/include/thrust/detail/dispatch/is_trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/detail/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple_meta_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_meta_transform.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional.inl" "$(@D)/cuda/include/thrust/detail/functional.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/remove.inl" "$(@D)/cuda/include/thrust/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_transform.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/merge.inl" "$(@D)/cuda/include/thrust/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/extrema.inl" "$(@D)/cuda/include/thrust/detail/extrema.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/trivial_sequence.h" "$(@D)/cuda/include/thrust/detail/trivial_sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/vector_base.h" "$(@D)/cuda/include/thrust/detail/vector_base.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/count.inl" "$(@D)/cuda/include/thrust/detail/count.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/function.h" "$(@D)/cuda/include/thrust/detail/function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap_ranges.inl" "$(@D)/cuda/include/thrust/detail/swap_ranges.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_delete.inl" "$(@D)/cuda/include/thrust/detail/device_delete.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/static_assert.h" "$(@D)/cuda/include/thrust/detail/static_assert.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/logical.inl" "$(@D)/cuda/include/thrust/detail/logical.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/seq.h" "$(@D)/cuda/include/thrust/detail/seq.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/mpl/math.h" "$(@D)/cuda/include/thrust/detail/mpl/math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/mismatch.inl" "$(@D)/cuda/include/thrust/detail/mismatch.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/internal_functional.h" "$(@D)/cuda/include/thrust/detail/internal_functional.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/get_iterator_value.h" "$(@D)/cuda/include/thrust/detail/get_iterator_value.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy.inl" "$(@D)/cuda/include/thrust/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy.h" "$(@D)/cuda/include/thrust/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/catrigf.h" "$(@D)/cuda/include/thrust/detail/complex/catrigf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cpowf.h" "$(@D)/cuda/include/thrust/detail/complex/cpowf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csqrtf.h" "$(@D)/cuda/include/thrust/detail/complex/csqrtf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ccoshf.h" "$(@D)/cuda/include/thrust/detail/complex/ccoshf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csinhf.h" "$(@D)/cuda/include/thrust/detail/complex/csinhf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/clogf.h" "$(@D)/cuda/include/thrust/detail/complex/clogf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ccosh.h" "$(@D)/cuda/include/thrust/detail/complex/ccosh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/arithmetic.h" "$(@D)/cuda/include/thrust/detail/complex/arithmetic.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csqrt.h" "$(@D)/cuda/include/thrust/detail/complex/csqrt.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cpow.h" "$(@D)/cuda/include/thrust/detail/complex/cpow.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/complex.inl" "$(@D)/cuda/include/thrust/detail/complex/complex.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/math_private.h" "$(@D)/cuda/include/thrust/detail/complex/math_private.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/c99math.h" "$(@D)/cuda/include/thrust/detail/complex/c99math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cproj.h" "$(@D)/cuda/include/thrust/detail/complex/cproj.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/catrig.h" "$(@D)/cuda/include/thrust/detail/complex/catrig.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ctanhf.h" "$(@D)/cuda/include/thrust/detail/complex/ctanhf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cexpf.h" "$(@D)/cuda/include/thrust/detail/complex/cexpf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csinh.h" "$(@D)/cuda/include/thrust/detail/complex/csinh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/stream.h" "$(@D)/cuda/include/thrust/detail/complex/stream.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ctanh.h" "$(@D)/cuda/include/thrust/detail/complex/ctanh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cexp.h" "$(@D)/cuda/include/thrust/detail/complex/cexp.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/clog.h" "$(@D)/cuda/include/thrust/detail/complex/clog.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/range/head_flags.h" "$(@D)/cuda/include/thrust/detail/range/head_flags.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/range/tail_flags.h" "$(@D)/cuda/include/thrust/detail/range/tail_flags.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/execute_with_allocator.h" "$(@D)/cuda/include/thrust/detail/execute_with_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/integer_math.h" "$(@D)/cuda/include/thrust/detail/integer_math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap.h" "$(@D)/cuda/include/thrust/detail/swap.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/scan.inl" "$(@D)/cuda/include/thrust/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/gather.inl" "$(@D)/cuda/include/thrust/detail/gather.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference_forward_declaration.h" "$(@D)/cuda/include/thrust/detail/reference_forward_declaration.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/numeric_traits.h" "$(@D)/cuda/include/thrust/detail/numeric_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference.inl" "$(@D)/cuda/include/thrust/detail/reference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/cstdint.h" "$(@D)/cuda/include/thrust/detail/cstdint.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_free.inl" "$(@D)/cuda/include/thrust/detail/device_free.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy_if.h" "$(@D)/cuda/include/thrust/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/partition.inl" "$(@D)/cuda/include/thrust/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/find.inl" "$(@D)/cuda/include/thrust/detail/find.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/forceinline.h" "$(@D)/cuda/include/thrust/detail/config/forceinline.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/debug.h" "$(@D)/cuda/include/thrust/detail/config/debug.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/config.h" "$(@D)/cuda/include/thrust/detail/config/config.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/host_device.h" "$(@D)/cuda/include/thrust/detail/config/host_device.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/host_system.h" "$(@D)/cuda/include/thrust/detail/config/host_system.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/compiler.h" "$(@D)/cuda/include/thrust/detail/config/compiler.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/device_system.h" "$(@D)/cuda/include/thrust/detail/config/device_system.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/compiler_fence.h" "$(@D)/cuda/include/thrust/detail/config/compiler_fence.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/exec_check_disable.h" "$(@D)/cuda/include/thrust/detail/config/exec_check_disable.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/simple_defines.h" "$(@D)/cuda/include/thrust/detail/config/simple_defines.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/global_workarounds.h" "$(@D)/cuda/include/thrust/detail/config/global_workarounds.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/replace.inl" "$(@D)/cuda/include/thrust/detail/replace.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_ptr.inl" "$(@D)/cuda/include/thrust/detail/device_ptr.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple.inl" "$(@D)/cuda/include/thrust/detail/tuple.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/host_vector.inl" "$(@D)/cuda/include/thrust/detail/host_vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/raw_pointer_cast.h" "$(@D)/cuda/include/thrust/detail/raw_pointer_cast.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/advance.inl" "$(@D)/cuda/include/thrust/detail/advance.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/contiguous_storage.h" "$(@D)/cuda/include/thrust/detail/contiguous_storage.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/raw_reference_cast.h" "$(@D)/cuda/include/thrust/detail/raw_reference_cast.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/contiguous_storage.inl" "$(@D)/cuda/include/thrust/detail/contiguous_storage.inl" && cp "/usr/local/cuda-8.0/include/thrust/reverse.h" "$(@D)/cuda/include/thrust/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/device_malloc_allocator.h" "$(@D)/cuda/include/thrust/device_malloc_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/scatter.h" "$(@D)/cuda/include/thrust/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/pair.h" "$(@D)/cuda/include/thrust/pair.h" && cp "/usr/local/cuda-8.0/include/thrust/advance.h" "$(@D)/cuda/include/thrust/advance.h" && cp "/usr/local/cuda-8.0/include/thrust/find.h" "$(@D)/cuda/include/thrust/find.h" && cp "/usr/local/cuda-8.0/include/thrust/device_ptr.h" "$(@D)/cuda/include/thrust/device_ptr.h" && cp "/usr/local/cuda-8.0/include/thrust/generate.h" "$(@D)/cuda/include/thrust/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/uninitialized_fill.h" "$(@D)/cuda/include/thrust/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/system_error.h" "$(@D)/cuda/include/thrust/system/system_error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/bad_alloc.h" "$(@D)/cuda/include/thrust/system/detail/bad_alloc.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/partition.h" "$(@D)/cuda/include/thrust/system/detail/adl/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/unique.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/adl/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/sequence.h" "$(@D)/cuda/include/thrust/system/detail/adl/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/merge.h" "$(@D)/cuda/include/thrust/system/detail/adl/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/gather.h" "$(@D)/cuda/include/thrust/system/detail/adl/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/sort.h" "$(@D)/cuda/include/thrust/system/detail/adl/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/adl/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reverse.h" "$(@D)/cuda/include/thrust/system/detail/adl/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scatter.h" "$(@D)/cuda/include/thrust/system/detail/adl/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/find.h" "$(@D)/cuda/include/thrust/system/detail/adl/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/generate.h" "$(@D)/cuda/include/thrust/system/detail/adl/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/remove.h" "$(@D)/cuda/include/thrust/system/detail/adl/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/adl/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/for_each.h" "$(@D)/cuda/include/thrust/system/detail/adl/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/equal.h" "$(@D)/cuda/include/thrust/system/detail/adl/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/adl/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/adl/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/adl/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/adl/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/extrema.h" "$(@D)/cuda/include/thrust/system/detail/adl/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/count.h" "$(@D)/cuda/include/thrust/system/detail/adl/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/replace.h" "$(@D)/cuda/include/thrust/system/detail/adl/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/get_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/adl/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/logical.h" "$(@D)/cuda/include/thrust/system/detail/adl/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/adl/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/adl/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/errno.h" "$(@D)/cuda/include/thrust/system/detail/errno.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_category.inl" "$(@D)/cuda/include/thrust/system/detail/error_category.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/partition.h" "$(@D)/cuda/include/thrust/system/detail/sequential/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/unique.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/execution_policy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/sequential/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sequence.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/merge.h" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/gather.h" "$(@D)/cuda/include/thrust/system/detail/sequential/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy_backward.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_backward.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/sequential/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reverse.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scatter.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/find.h" "$(@D)/cuda/include/thrust/system/detail/sequential/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/merge.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/generate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/general_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/general_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/insertion_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/insertion_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/remove.h" "$(@D)/cuda/include/thrust/system/detail/sequential/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/for_each.h" "$(@D)/cuda/include/thrust/system/detail/sequential/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/equal.h" "$(@D)/cuda/include/thrust/system/detail/sequential/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/sequential/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/sequential/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/sequential/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/sequential/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/extrema.h" "$(@D)/cuda/include/thrust/system/detail/sequential/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/count.h" "$(@D)/cuda/include/thrust/system/detail/sequential/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/trivial_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/replace.h" "$(@D)/cuda/include/thrust/system/detail/sequential/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/get_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/sequential/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/logical.h" "$(@D)/cuda/include/thrust/system/detail/sequential/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/sequential/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/sequential/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_condition.inl" "$(@D)/cuda/include/thrust/system/detail/error_condition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/internal/decompose.h" "$(@D)/cuda/include/thrust/system/detail/internal/decompose.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_code.inl" "$(@D)/cuda/include/thrust/system/detail/error_code.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/memory.inl" "$(@D)/cuda/include/thrust/system/detail/generic/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/inner_product.inl" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/select_system.h" "$(@D)/cuda/include/thrust/system/detail/generic/select_system.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sequence.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sort.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/equal.inl" "$(@D)/cuda/include/thrust/system/detail/generic/equal.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/partition.h" "$(@D)/cuda/include/thrust/system/detail/generic/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tag.h" "$(@D)/cuda/include/thrust/system/detail/generic/tag.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sequence.h" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/type_traits.h" "$(@D)/cuda/include/thrust/system/detail/generic/type_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/merge.h" "$(@D)/cuda/include/thrust/system/detail/generic/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reverse.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tabulate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scatter.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/set_operations.inl" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy_if.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/gather.h" "$(@D)/cuda/include/thrust/system/detail/generic/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sort.h" "$(@D)/cuda/include/thrust/system/detail/generic/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/distance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/distance.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reverse.h" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/temporary_buffer.inl" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scatter.h" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/generate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/generate.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/remove.inl" "$(@D)/cuda/include/thrust/system/detail/generic/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/advance.h" "$(@D)/cuda/include/thrust/system/detail/generic/advance.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/find.h" "$(@D)/cuda/include/thrust/system/detail/generic/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/merge.inl" "$(@D)/cuda/include/thrust/system/detail/generic/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scalar/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scalar/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/extrema.inl" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/generate.h" "$(@D)/cuda/include/thrust/system/detail/generic/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/count.inl" "$(@D)/cuda/include/thrust/system/detail/generic/count.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/remove.h" "$(@D)/cuda/include/thrust/system/detail/generic/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/for_each.h" "$(@D)/cuda/include/thrust/system/detail/generic/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/distance.h" "$(@D)/cuda/include/thrust/system/detail/generic/distance.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/swap_ranges.inl" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/equal.h" "$(@D)/cuda/include/thrust/system/detail/generic/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/mismatch.inl" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/gather.inl" "$(@D)/cuda/include/thrust/system/detail/generic/gather.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/extrema.h" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/count.h" "$(@D)/cuda/include/thrust/system/detail/generic/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/replace.h" "$(@D)/cuda/include/thrust/system/detail/generic/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/logical.h" "$(@D)/cuda/include/thrust/system/detail/generic/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/partition.inl" "$(@D)/cuda/include/thrust/system/detail/generic/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/memory.h" "$(@D)/cuda/include/thrust/system/detail/generic/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/find.inl" "$(@D)/cuda/include/thrust/system/detail/generic/find.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/replace.inl" "$(@D)/cuda/include/thrust/system/detail/generic/replace.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/advance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/advance.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/system_error.inl" "$(@D)/cuda/include/thrust/system/detail/system_error.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/vector.h" "$(@D)/cuda/include/thrust/system/omp/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/omp/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sort.inl" "$(@D)/cuda/include/thrust/system/omp/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/partition.h" "$(@D)/cuda/include/thrust/system/omp/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/omp/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/omp/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/merge.h" "$(@D)/cuda/include/thrust/system/omp/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/gather.h" "$(@D)/cuda/include/thrust/system/omp/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sort.h" "$(@D)/cuda/include/thrust/system/omp/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/omp/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/omp/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/omp/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/remove.inl" "$(@D)/cuda/include/thrust/system/omp/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/omp/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/find.h" "$(@D)/cuda/include/thrust/system/omp/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/generate.h" "$(@D)/cuda/include/thrust/system/omp/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/remove.h" "$(@D)/cuda/include/thrust/system/omp/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/omp/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/equal.h" "$(@D)/cuda/include/thrust/system/omp/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/omp/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/omp/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/omp/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/omp/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/omp/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/count.h" "$(@D)/cuda/include/thrust/system/omp/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/replace.h" "$(@D)/cuda/include/thrust/system/omp/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/omp/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/logical.h" "$(@D)/cuda/include/thrust/system/omp/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/partition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/omp/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/par.h" "$(@D)/cuda/include/thrust/system/omp/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/omp/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/memory.h" "$(@D)/cuda/include/thrust/system/omp/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/vector.h" "$(@D)/cuda/include/thrust/system/tbb/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/memory.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sort.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/partition.h" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/tbb/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sequence.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/merge.h" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/gather.h" "$(@D)/cuda/include/thrust/system/tbb/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sort.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/tbb/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reverse.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scatter.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/remove.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/vector.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/find.h" "$(@D)/cuda/include/thrust/system/tbb/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/merge.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/generate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/remove.h" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/for_each.h" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/equal.h" "$(@D)/cuda/include/thrust/system/tbb/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/tbb/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/tbb/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/tbb/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/tbb/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/extrema.h" "$(@D)/cuda/include/thrust/system/tbb/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/count.h" "$(@D)/cuda/include/thrust/system/tbb/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/replace.h" "$(@D)/cuda/include/thrust/system/tbb/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/get_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/tbb/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/logical.h" "$(@D)/cuda/include/thrust/system/tbb/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/partition.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/tbb/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/par.h" "$(@D)/cuda/include/thrust/system/tbb/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/tbb/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/memory.h" "$(@D)/cuda/include/thrust/system/tbb/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/error_code.h" "$(@D)/cuda/include/thrust/system/error_code.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/vector.h" "$(@D)/cuda/include/thrust/system/cpp/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/partition.h" "$(@D)/cuda/include/thrust/system/cpp/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/unique.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cpp/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/merge.h" "$(@D)/cuda/include/thrust/system/cpp/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/gather.h" "$(@D)/cuda/include/thrust/system/cpp/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/sort.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cpp/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/find.h" "$(@D)/cuda/include/thrust/system/cpp/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/generate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/remove.h" "$(@D)/cuda/include/thrust/system/cpp/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cpp/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/equal.h" "$(@D)/cuda/include/thrust/system/cpp/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cpp/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cpp/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cpp/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cpp/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cpp/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/count.h" "$(@D)/cuda/include/thrust/system/cpp/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/replace.h" "$(@D)/cuda/include/thrust/system/cpp/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cpp/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/logical.h" "$(@D)/cuda/include/thrust/system/cpp/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cpp/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/par.h" "$(@D)/cuda/include/thrust/system/cpp/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cpp/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/memory.h" "$(@D)/cuda/include/thrust/system/cpp/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/vector.h" "$(@D)/cuda/include/thrust/system/cuda/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/error.h" "$(@D)/cuda/include/thrust/system/cuda/error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_device_to_device.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_device_to_device.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_allocator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_device.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_device.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_macro.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_namespace.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_type.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_type.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/host/spinlock.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/host/spinlock.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_ptx.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_debug.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/cub.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/cub.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_shift.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_shift.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_arch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_cross_system.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_cross_system.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk.h" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/partition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/unique.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cuda_launch_config.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cuda_launch_config.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cub.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_symmetric_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_symmetric_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/error.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/error.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/gather.h" "$(@D)/cuda/include/thrust/system/cuda/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/synchronize.h" "$(@D)/cuda/include/thrust/system/cuda/detail/synchronize.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/temporary_indirect_permutation.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_indirect_permutation.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/extern_shared_ptr.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extern_shared_ptr.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/set_operation.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/set_operation.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/balanced_path.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/balanced_path.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/set_operation.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/set_operation.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_closure.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_closure.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/alignment.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/alignment.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_sort_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_calculator.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_calculator.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_closure.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_closure.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/uninitialized.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/uninitialized.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_calculator.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_calculator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_sort_each.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/default_decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/default_decomposition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/throw_on_error.h" "$(@D)/cuda/include/thrust/system/cuda/detail/throw_on_error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/find.h" "$(@D)/cuda/include/thrust/system/cuda/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/terminate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/terminate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/merge.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/trivial_copy.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/trivial_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/generate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/execute_on_stream.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execute_on_stream.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/remove.h" "$(@D)/cuda/include/thrust/system/cuda/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/decomposition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/equal.h" "$(@D)/cuda/include/thrust/system/cuda/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/runtime_introspection.h" "$(@D)/cuda/include/thrust/system/cuda/detail/runtime_introspection.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cuda/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cuda/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/runtime_introspection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/runtime_introspection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cuda/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/synchronize.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/synchronize.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_union.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_union.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_intersection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_intersection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/count.h" "$(@D)/cuda/include/thrust/system/cuda/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/trivial_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_device_to_device.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_device_to_device.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/replace.h" "$(@D)/cuda/include/thrust/system/cuda/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/malloc.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/malloc.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/config.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/config.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/closure.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/closure.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/async.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/async.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/bulk.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/bulk.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/execution_policy.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/execution_policy.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/uninitialized.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/uninitialized.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/async.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/async.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/future.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/future.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/guarded_driver_types.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_driver_types.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cuda/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/logical.h" "$(@D)/cuda/include/thrust/system/cuda/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cuda/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/inclusive_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/inclusive_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merge.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merging_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merging_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/exclusive_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/exclusive_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/odd_even_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/odd_even_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/par.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_cross_system.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cuda/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/experimental/pinned_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/experimental/pinned_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/memory.h" "$(@D)/cuda/include/thrust/system/cuda/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/remove.h" "$(@D)/cuda/include/thrust/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/tabulate.h" "$(@D)/cuda/include/thrust/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/for_each.h" "$(@D)/cuda/include/thrust/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/distance.h" "$(@D)/cuda/include/thrust/distance.h" && cp "/usr/local/cuda-8.0/include/thrust/reduce.h" "$(@D)/cuda/include/thrust/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/equal.h" "$(@D)/cuda/include/thrust/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/complex.h" "$(@D)/cuda/include/thrust/complex.h" && cp "/usr/local/cuda-8.0/include/thrust/device_allocator.h" "$(@D)/cuda/include/thrust/device_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/copy.h" "$(@D)/cuda/include/thrust/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/uninitialized_copy.h" "$(@D)/cuda/include/thrust/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/device_reference.h" "$(@D)/cuda/include/thrust/device_reference.h" && cp "/usr/local/cuda-8.0/include/thrust/binary_search.h" "$(@D)/cuda/include/thrust/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/set_operations.h" "$(@D)/cuda/include/thrust/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/swap.h" "$(@D)/cuda/include/thrust/swap.h" && cp "/usr/local/cuda-8.0/include/thrust/mismatch.h" "$(@D)/cuda/include/thrust/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/extrema.h" "$(@D)/cuda/include/thrust/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/count.h" "$(@D)/cuda/include/thrust/count.h" && cp "/usr/local/cuda-8.0/include/thrust/device_free.h" "$(@D)/cuda/include/thrust/device_free.h" && cp "/usr/local/cuda-8.0/include/thrust/random/discard_block_engine.h" "$(@D)/cuda/include/thrust/random/discard_block_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/normal_distribution.h" "$(@D)/cuda/include/thrust/random/normal_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/subtract_with_carry_engine.inl" "$(@D)/cuda/include/thrust/random/detail/subtract_with_carry_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/xor_combine_engine_max.h" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine_max.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_congruential_engine_discard.h" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine_discard.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/uniform_int_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_int_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/discard_block_engine.inl" "$(@D)/cuda/include/thrust/random/detail/discard_block_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/uniform_real_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_real_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/random_core_access.h" "$(@D)/cuda/include/thrust/random/detail/random_core_access.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/mod.h" "$(@D)/cuda/include/thrust/random/detail/mod.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_feedback_shift_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_congruential_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/xor_combine_engine.inl" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/normal_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/normal_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/normal_distribution_base.h" "$(@D)/cuda/include/thrust/random/detail/normal_distribution_base.h" && cp "/usr/local/cuda-8.0/include/thrust/random/uniform_int_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_int_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/random/linear_feedback_shift_engine.h" "$(@D)/cuda/include/thrust/random/linear_feedback_shift_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/xor_combine_engine.h" "$(@D)/cuda/include/thrust/random/xor_combine_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/subtract_with_carry_engine.h" "$(@D)/cuda/include/thrust/random/subtract_with_carry_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/linear_congruential_engine.h" "$(@D)/cuda/include/thrust/random/linear_congruential_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/uniform_real_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_real_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/functional.h" "$(@D)/cuda/include/thrust/functional.h" && cp "/usr/local/cuda-8.0/include/thrust/replace.h" "$(@D)/cuda/include/thrust/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/device_new_allocator.h" "$(@D)/cuda/include/thrust/device_new_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/host_vector.h" "$(@D)/cuda/include/thrust/host_vector.h" && cp "/usr/local/cuda-8.0/include/thrust/version.h" "$(@D)/cuda/include/thrust/version.h" && cp "/usr/local/cuda-8.0/include/thrust/inner_product.h" "$(@D)/cuda/include/thrust/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_traits.h" "$(@D)/cuda/include/thrust/iterator/iterator_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/discard_iterator.h" "$(@D)/cuda/include/thrust/iterator/discard_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/retag.h" "$(@D)/cuda/include/thrust/iterator/retag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/permutation_iterator.h" "$(@D)/cuda/include/thrust/iterator/permutation_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/transform_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/reverse_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/zip_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/counting_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/counting_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/distance_from_result.h" "$(@D)/cuda/include/thrust/iterator/detail/distance_from_result.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/host_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/host_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_traversal_tags.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traversal_tags.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/retag.h" "$(@D)/cuda/include/thrust/iterator/detail/retag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/tagged_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/tagged_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_traits.inl" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traits.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/minimum_category.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/discard_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/discard_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_to_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/zip_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/normal_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/normal_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/join_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/join_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/device_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/device_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/universal_categories.h" "$(@D)/cuda/include/thrust/iterator/detail/universal_categories.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/reverse_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/minimum_system.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_system.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/tuple_of_iterator_references.h" "$(@D)/cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/is_iterator_category.h" "$(@D)/cuda/include/thrust/iterator/detail/is_iterator_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/permutation_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/permutation_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/any_assign.h" "$(@D)/cuda/include/thrust/iterator/detail/any_assign.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/any_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/any_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/is_trivial_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/is_trivial_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_to_system.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_system.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_adaptor_base.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_adaptor_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/constant_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/constant_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/transform_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_facade_category.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_facade_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/constant_iterator.h" "$(@D)/cuda/include/thrust/iterator/constant_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/counting_iterator.h" "$(@D)/cuda/include/thrust/iterator/counting_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_adaptor.h" "$(@D)/cuda/include/thrust/iterator/iterator_adaptor.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_facade.h" "$(@D)/cuda/include/thrust/iterator/iterator_facade.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_categories.h" "$(@D)/cuda/include/thrust/iterator/iterator_categories.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/reverse_iterator.h" "$(@D)/cuda/include/thrust/iterator/reverse_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/zip_iterator.h" "$(@D)/cuda/include/thrust/iterator/zip_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/logical.h" "$(@D)/cuda/include/thrust/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/tuple.h" "$(@D)/cuda/include/thrust/tuple.h" && cp "/usr/local/cuda-8.0/include/thrust/memory.h" "$(@D)/cuda/include/thrust/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/random.h" "$(@D)/cuda/include/thrust/random.h" && cp "/usr/local/cuda-8.0/include/thrust/fill.h" "$(@D)/cuda/include/thrust/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/transform.h" "$(@D)/cuda/include/thrust/transform.h" && cp "/usr/local/cuda-8.0/include/texture_types.h" "$(@D)/cuda/include/texture_types.h" && cp "/usr/local/cuda-8.0/include/nppversion.h" "$(@D)/cuda/include/nppversion.h" && cp "/usr/local/cuda-8.0/include/cuda_texture_types.h" "$(@D)/cuda/include/cuda_texture_types.h" && cp "/usr/local/cuda-8.0/include/fatbinary.h" "$(@D)/cuda/include/fatbinary.h" && cp "/usr/local/cuda-8.0/include/cublasXt.h" "$(@D)/cuda/include/cublasXt.h" && cp "/usr/local/cuda-8.0/include/cuda_fp16.h" "$(@D)/cuda/include/cuda_fp16.h" && cp "/usr/local/cuda-8.0/include/vector_functions.h" "$(@D)/cuda/include/vector_functions.h" && cp "/usr/local/cuda-8.0/include/cusparse.h" "$(@D)/cuda/include/cusparse.h" && cp "/usr/local/cuda-8.0/include/nppi_filtering_functions.h" "$(@D)/cuda/include/nppi_filtering_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_morphological_operations.h" "$(@D)/cuda/include/nppi_morphological_operations.h" && cp "/usr/local/cuda-8.0/include/sobol_direction_vectors.h" "$(@D)/cuda/include/sobol_direction_vectors.h" && cp "/usr/local/cuda-8.0/include/nvblas.h" "$(@D)/cuda/include/nvblas.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32dc_p_11213.h" "$(@D)/cuda/include/curand_mtgp32dc_p_11213.h" && cp "/usr/local/cuda-8.0/include/nvcuvid.h" "$(@D)/cuda/include/nvcuvid.h" && cp "/usr/local/cuda-8.0/include/cuda_runtime_api.h" "$(@D)/cuda/include/cuda_runtime_api.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32_kernel.h" "$(@D)/cuda/include/curand_mtgp32_kernel.h" && cp "/usr/local/cuda-8.0/include/cublas_v2.h" "$(@D)/cuda/include/cublas_v2.h" && cp "/usr/local/cuda-8.0/include/builtin_types.h" "$(@D)/cuda/include/builtin_types.h" && cp "/usr/local/cuda-8.0/include/nppi_geometry_transforms.h" "$(@D)/cuda/include/nppi_geometry_transforms.h" && cp "/usr/local/cuda-8.0/include/npps_support_functions.h" "$(@D)/cuda/include/npps_support_functions.h" && cp "/usr/local/cuda-8.0/include/cufftw.h" "$(@D)/cuda/include/cufftw.h" && cp "/usr/local/cuda-8.0/include/cuda_device_runtime_api.h" "$(@D)/cuda/include/cuda_device_runtime_api.h" && cp "/usr/local/cuda-8.0/include/sm_30_intrinsics.hpp" "$(@D)/cuda/include/sm_30_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/vector_types.h" "$(@D)/cuda/include/vector_types.h" && cp "/usr/local/cuda-8.0/include/sm_35_atomic_functions.h" "$(@D)/cuda/include/sm_35_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/sm_20_intrinsics.h" "$(@D)/cuda/include/sm_20_intrinsics.h" && cp "/usr/local/cuda-8.0/include/driver_types.h" "$(@D)/cuda/include/driver_types.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvToolsExtCudaRt.h" && cp "/usr/local/cuda-8.0/include/curand_globals.h" "$(@D)/cuda/include/curand_globals.h" && cp "/usr/local/cuda-8.0/include/device_atomic_functions.h" "$(@D)/cuda/include/device_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/surface_types.h" "$(@D)/cuda/include/surface_types.h" && cp "/usr/local/cuda-8.0/include/nvrtc.h" "$(@D)/cuda/include/nvrtc.h" && cp "/usr/local/cuda-8.0/include/nppdefs.h" "$(@D)/cuda/include/nppdefs.h" && cp "/usr/local/cuda-8.0/include/sm_60_atomic_functions.h" "$(@D)/cuda/include/sm_60_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/driver_functions.h" "$(@D)/cuda/include/driver_functions.h" && cp "/usr/local/cuda-8.0/include/cusolver_common.h" "$(@D)/cuda/include/cusolver_common.h" && cp "/usr/local/cuda-8.0/include/cublas.h" "$(@D)/cuda/include/cublas.h" && cp "/usr/local/cuda-8.0/include/curand_lognormal.h" "$(@D)/cuda/include/curand_lognormal.h" && cp "/usr/local/cuda-8.0/include/device_atomic_functions.hpp" "$(@D)/cuda/include/device_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/crt/device_runtime.h" "$(@D)/cuda/include/crt/device_runtime.h" && cp "/usr/local/cuda-8.0/include/crt/storage_class.h" "$(@D)/cuda/include/crt/storage_class.h" && cp "/usr/local/cuda-8.0/include/crt/func_macro.h" "$(@D)/cuda/include/crt/func_macro.h" && cp "/usr/local/cuda-8.0/include/crt/host_runtime.h" "$(@D)/cuda/include/crt/host_runtime.h" && cp "/usr/local/cuda-8.0/include/nppi_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/nppi_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-8.0/include/npps_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/npps_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-8.0/include/nppi_computer_vision.h" "$(@D)/cuda/include/nppi_computer_vision.h" && cp "/usr/local/cuda-8.0/include/surface_functions.hpp" "$(@D)/cuda/include/surface_functions.hpp" && cp "/usr/local/cuda-8.0/include/surface_functions.h" "$(@D)/cuda/include/surface_functions.h" && cp "/usr/local/cuda-8.0/include/curand_normal_static.h" "$(@D)/cuda/include/curand_normal_static.h" && cp "/usr/local/cuda-8.0/include/curand.h" "$(@D)/cuda/include/curand.h" && cp "/usr/local/cuda-8.0/include/math_functions_dbl_ptx3.h" "$(@D)/cuda/include/math_functions_dbl_ptx3.h" && cp "/usr/local/cuda-8.0/include/curand_philox4x32_x.h" "$(@D)/cuda/include/curand_philox4x32_x.h" && cp "/usr/local/cuda-8.0/include/nppi_threshold_and_compare_operations.h" "$(@D)/cuda/include/nppi_threshold_and_compare_operations.h" && cp "/usr/local/cuda-8.0/include/nvml.h" "$(@D)/cuda/include/nvml.h" && cp "/usr/local/cuda-8.0/include/npps.h" "$(@D)/cuda/include/npps.h" && cp "/usr/local/cuda-8.0/include/cuda_vdpau_interop.h" "$(@D)/cuda/include/cuda_vdpau_interop.h" && cp "/usr/local/cuda-8.0/include/sm_61_intrinsics.hpp" "$(@D)/cuda/include/sm_61_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/cublas_api.h" "$(@D)/cuda/include/cublas_api.h" && cp "/usr/local/cuda-8.0/include/nppi_color_conversion.h" "$(@D)/cuda/include/nppi_color_conversion.h" && cp "/usr/local/cuda-8.0/include/math_functions_dbl_ptx3.hpp" "$(@D)/cuda/include/math_functions_dbl_ptx3.hpp" && cp "/usr/local/cuda-8.0/include/nppcore.h" "$(@D)/cuda/include/nppcore.h" && cp "/usr/local/cuda-8.0/include/cudaGL.h" "$(@D)/cuda/include/cudaGL.h" && cp "/usr/local/cuda-8.0/include/fatBinaryCtl.h" "$(@D)/cuda/include/fatBinaryCtl.h" && cp "/usr/local/cuda-8.0/include/npps_statistics_functions.h" "$(@D)/cuda/include/npps_statistics_functions.h" && cp "/usr/local/cuda-8.0/include/cudaVDPAU.h" "$(@D)/cuda/include/cudaVDPAU.h" && cp "/usr/local/cuda-8.0/include/curand_poisson.h" "$(@D)/cuda/include/curand_poisson.h" && cp "/usr/local/cuda-8.0/include/cusolverDn.h" "$(@D)/cuda/include/cusolverDn.h" && cp "/usr/local/cuda-8.0/include/cuda_profiler_api.h" "$(@D)/cuda/include/cuda_profiler_api.h" && cp "/usr/local/cuda-8.0/include/sm_20_atomic_functions.h" "$(@D)/cuda/include/sm_20_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/nvfunctional" "$(@D)/cuda/include/nvfunctional" +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/include/CL/cl.h" "$(@D)/cuda/include/CL/cl.h" && cp "/usr/local/cuda-9.0/include/CL/cl.hpp" "$(@D)/cuda/include/CL/cl.hpp" && cp "/usr/local/cuda-9.0/include/CL/cl_egl.h" "$(@D)/cuda/include/CL/cl_egl.h" && cp "/usr/local/cuda-9.0/include/CL/cl_ext.h" "$(@D)/cuda/include/CL/cl_ext.h" && cp "/usr/local/cuda-9.0/include/CL/cl_gl.h" "$(@D)/cuda/include/CL/cl_gl.h" && cp "/usr/local/cuda-9.0/include/CL/cl_gl_ext.h" "$(@D)/cuda/include/CL/cl_gl_ext.h" && cp "/usr/local/cuda-9.0/include/CL/cl_platform.h" "$(@D)/cuda/include/CL/cl_platform.h" && cp "/usr/local/cuda-9.0/include/CL/opencl.h" "$(@D)/cuda/include/CL/opencl.h" && cp "/usr/local/cuda-9.0/include/builtin_types.h" "$(@D)/cuda/include/builtin_types.h" && cp "/usr/local/cuda-9.0/include/channel_descriptor.h" "$(@D)/cuda/include/channel_descriptor.h" && cp "/usr/local/cuda-9.0/include/common_functions.h" "$(@D)/cuda/include/common_functions.h" && cp "/usr/local/cuda-9.0/include/cooperative_groups.h" "$(@D)/cuda/include/cooperative_groups.h" && cp "/usr/local/cuda-9.0/include/cooperative_groups_helpers.h" "$(@D)/cuda/include/cooperative_groups_helpers.h" && cp "/usr/local/cuda-9.0/include/crt/common_functions.h" "$(@D)/cuda/include/crt/common_functions.h" && cp "/usr/local/cuda-9.0/include/crt/device_double_functions.h" "$(@D)/cuda/include/crt/device_double_functions.h" && cp "/usr/local/cuda-9.0/include/crt/device_double_functions.hpp" "$(@D)/cuda/include/crt/device_double_functions.hpp" && cp "/usr/local/cuda-9.0/include/crt/device_functions.h" "$(@D)/cuda/include/crt/device_functions.h" && cp "/usr/local/cuda-9.0/include/crt/device_functions.hpp" "$(@D)/cuda/include/crt/device_functions.hpp" && cp "/usr/local/cuda-9.0/include/crt/func_macro.h" "$(@D)/cuda/include/crt/func_macro.h" && cp "/usr/local/cuda-9.0/include/crt/host_config.h" "$(@D)/cuda/include/crt/host_config.h" && cp "/usr/local/cuda-9.0/include/crt/host_defines.h" "$(@D)/cuda/include/crt/host_defines.h" && cp "/usr/local/cuda-9.0/include/crt/host_runtime.h" "$(@D)/cuda/include/crt/host_runtime.h" && cp "/usr/local/cuda-9.0/include/crt/math_functions.h" "$(@D)/cuda/include/crt/math_functions.h" && cp "/usr/local/cuda-9.0/include/crt/math_functions.hpp" "$(@D)/cuda/include/crt/math_functions.hpp" && cp "/usr/local/cuda-9.0/include/crt/mma.h" "$(@D)/cuda/include/crt/mma.h" && cp "/usr/local/cuda-9.0/include/crt/mma.hpp" "$(@D)/cuda/include/crt/mma.hpp" && cp "/usr/local/cuda-9.0/include/crt/nvfunctional" "$(@D)/cuda/include/crt/nvfunctional" && cp "/usr/local/cuda-9.0/include/crt/sm_70_rt.h" "$(@D)/cuda/include/crt/sm_70_rt.h" && cp "/usr/local/cuda-9.0/include/crt/sm_70_rt.hpp" "$(@D)/cuda/include/crt/sm_70_rt.hpp" && cp "/usr/local/cuda-9.0/include/crt/storage_class.h" "$(@D)/cuda/include/crt/storage_class.h" && cp "/usr/local/cuda-9.0/include/cuComplex.h" "$(@D)/cuda/include/cuComplex.h" && cp "/usr/local/cuda-9.0/include/cublas.h" "$(@D)/cuda/include/cublas.h" && cp "/usr/local/cuda-9.0/include/cublasXt.h" "$(@D)/cuda/include/cublasXt.h" && cp "/usr/local/cuda-9.0/include/cublas_api.h" "$(@D)/cuda/include/cublas_api.h" && cp "/usr/local/cuda-9.0/include/cublas_v2.h" "$(@D)/cuda/include/cublas_v2.h" && cp "/usr/local/cuda-9.0/include/cuda.h" "$(@D)/cuda/include/cuda.h" && cp "/usr/local/cuda-9.0/include/cudaEGL.h" "$(@D)/cuda/include/cudaEGL.h" && cp "/usr/local/cuda-9.0/include/cudaGL.h" "$(@D)/cuda/include/cudaGL.h" && cp "/usr/local/cuda-9.0/include/cudaProfiler.h" "$(@D)/cuda/include/cudaProfiler.h" && cp "/usr/local/cuda-9.0/include/cudaVDPAU.h" "$(@D)/cuda/include/cudaVDPAU.h" && cp "/usr/local/cuda-9.0/include/cuda_device_runtime_api.h" "$(@D)/cuda/include/cuda_device_runtime_api.h" && cp "/usr/local/cuda-9.0/include/cuda_fp16.h" "$(@D)/cuda/include/cuda_fp16.h" && cp "/usr/local/cuda-9.0/include/cuda_fp16.hpp" "$(@D)/cuda/include/cuda_fp16.hpp" && cp "/usr/local/cuda-9.0/include/cuda_gl_interop.h" "$(@D)/cuda/include/cuda_gl_interop.h" && cp "/usr/local/cuda-9.0/include/cuda_occupancy.h" "$(@D)/cuda/include/cuda_occupancy.h" && cp "/usr/local/cuda-9.0/include/cuda_profiler_api.h" "$(@D)/cuda/include/cuda_profiler_api.h" && cp "/usr/local/cuda-9.0/include/cuda_runtime.h" "$(@D)/cuda/include/cuda_runtime.h" && cp "/usr/local/cuda-9.0/include/cuda_runtime_api.h" "$(@D)/cuda/include/cuda_runtime_api.h" && cp "/usr/local/cuda-9.0/include/cuda_surface_types.h" "$(@D)/cuda/include/cuda_surface_types.h" && cp "/usr/local/cuda-9.0/include/cuda_texture_types.h" "$(@D)/cuda/include/cuda_texture_types.h" && cp "/usr/local/cuda-9.0/include/cuda_vdpau_interop.h" "$(@D)/cuda/include/cuda_vdpau_interop.h" && cp "/usr/local/cuda-9.0/include/cudalibxt.h" "$(@D)/cuda/include/cudalibxt.h" && cp "/usr/local/cuda-9.0/include/cudnn.h" "$(@D)/cuda/include/cudnn.h" && cp "/usr/local/cuda-9.0/include/cufft.h" "$(@D)/cuda/include/cufft.h" && cp "/usr/local/cuda-9.0/include/cufftXt.h" "$(@D)/cuda/include/cufftXt.h" && cp "/usr/local/cuda-9.0/include/cufftw.h" "$(@D)/cuda/include/cufftw.h" && cp "/usr/local/cuda-9.0/include/curand.h" "$(@D)/cuda/include/curand.h" && cp "/usr/local/cuda-9.0/include/curand_discrete.h" "$(@D)/cuda/include/curand_discrete.h" && cp "/usr/local/cuda-9.0/include/curand_discrete2.h" "$(@D)/cuda/include/curand_discrete2.h" && cp "/usr/local/cuda-9.0/include/curand_globals.h" "$(@D)/cuda/include/curand_globals.h" && cp "/usr/local/cuda-9.0/include/curand_kernel.h" "$(@D)/cuda/include/curand_kernel.h" && cp "/usr/local/cuda-9.0/include/curand_lognormal.h" "$(@D)/cuda/include/curand_lognormal.h" && cp "/usr/local/cuda-9.0/include/curand_mrg32k3a.h" "$(@D)/cuda/include/curand_mrg32k3a.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32.h" "$(@D)/cuda/include/curand_mtgp32.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32_host.h" "$(@D)/cuda/include/curand_mtgp32_host.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32_kernel.h" "$(@D)/cuda/include/curand_mtgp32_kernel.h" && cp "/usr/local/cuda-9.0/include/curand_mtgp32dc_p_11213.h" "$(@D)/cuda/include/curand_mtgp32dc_p_11213.h" && cp "/usr/local/cuda-9.0/include/curand_normal.h" "$(@D)/cuda/include/curand_normal.h" && cp "/usr/local/cuda-9.0/include/curand_normal_static.h" "$(@D)/cuda/include/curand_normal_static.h" && cp "/usr/local/cuda-9.0/include/curand_philox4x32_x.h" "$(@D)/cuda/include/curand_philox4x32_x.h" && cp "/usr/local/cuda-9.0/include/curand_poisson.h" "$(@D)/cuda/include/curand_poisson.h" && cp "/usr/local/cuda-9.0/include/curand_precalc.h" "$(@D)/cuda/include/curand_precalc.h" && cp "/usr/local/cuda-9.0/include/curand_uniform.h" "$(@D)/cuda/include/curand_uniform.h" && cp "/usr/local/cuda-9.0/include/cusolverDn.h" "$(@D)/cuda/include/cusolverDn.h" && cp "/usr/local/cuda-9.0/include/cusolverRf.h" "$(@D)/cuda/include/cusolverRf.h" && cp "/usr/local/cuda-9.0/include/cusolverSp.h" "$(@D)/cuda/include/cusolverSp.h" && cp "/usr/local/cuda-9.0/include/cusolverSp_LOWLEVEL_PREVIEW.h" "$(@D)/cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h" && cp "/usr/local/cuda-9.0/include/cusolver_common.h" "$(@D)/cuda/include/cusolver_common.h" && cp "/usr/local/cuda-9.0/include/cusparse.h" "$(@D)/cuda/include/cusparse.h" && cp "/usr/local/cuda-9.0/include/cusparse_v2.h" "$(@D)/cuda/include/cusparse_v2.h" && cp "/usr/local/cuda-9.0/include/device_atomic_functions.h" "$(@D)/cuda/include/device_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/device_atomic_functions.hpp" "$(@D)/cuda/include/device_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/device_double_functions.h" "$(@D)/cuda/include/device_double_functions.h" && cp "/usr/local/cuda-9.0/include/device_double_functions.hpp" "$(@D)/cuda/include/device_double_functions.hpp" && cp "/usr/local/cuda-9.0/include/device_functions.h" "$(@D)/cuda/include/device_functions.h" && cp "/usr/local/cuda-9.0/include/device_functions.hpp" "$(@D)/cuda/include/device_functions.hpp" && cp "/usr/local/cuda-9.0/include/device_functions_decls.h" "$(@D)/cuda/include/device_functions_decls.h" && cp "/usr/local/cuda-9.0/include/device_launch_parameters.h" "$(@D)/cuda/include/device_launch_parameters.h" && cp "/usr/local/cuda-9.0/include/device_types.h" "$(@D)/cuda/include/device_types.h" && cp "/usr/local/cuda-9.0/include/driver_functions.h" "$(@D)/cuda/include/driver_functions.h" && cp "/usr/local/cuda-9.0/include/driver_types.h" "$(@D)/cuda/include/driver_types.h" && cp "/usr/local/cuda-9.0/include/dynlink_cuda.h" "$(@D)/cuda/include/dynlink_cuda.h" && cp "/usr/local/cuda-9.0/include/dynlink_cuda_cuda.h" "$(@D)/cuda/include/dynlink_cuda_cuda.h" && cp "/usr/local/cuda-9.0/include/dynlink_cuviddec.h" "$(@D)/cuda/include/dynlink_cuviddec.h" && cp "/usr/local/cuda-9.0/include/dynlink_nvcuvid.h" "$(@D)/cuda/include/dynlink_nvcuvid.h" && cp "/usr/local/cuda-9.0/include/fatBinaryCtl.h" "$(@D)/cuda/include/fatBinaryCtl.h" && cp "/usr/local/cuda-9.0/include/fatbinary.h" "$(@D)/cuda/include/fatbinary.h" && cp "/usr/local/cuda-9.0/include/host_config.h" "$(@D)/cuda/include/host_config.h" && cp "/usr/local/cuda-9.0/include/host_defines.h" "$(@D)/cuda/include/host_defines.h" && cp "/usr/local/cuda-9.0/include/library_types.h" "$(@D)/cuda/include/library_types.h" && cp "/usr/local/cuda-9.0/include/math_constants.h" "$(@D)/cuda/include/math_constants.h" && cp "/usr/local/cuda-9.0/include/math_functions.h" "$(@D)/cuda/include/math_functions.h" && cp "/usr/local/cuda-9.0/include/math_functions.hpp" "$(@D)/cuda/include/math_functions.hpp" && cp "/usr/local/cuda-9.0/include/math_functions_dbl_ptx3.h" "$(@D)/cuda/include/math_functions_dbl_ptx3.h" && cp "/usr/local/cuda-9.0/include/math_functions_dbl_ptx3.hpp" "$(@D)/cuda/include/math_functions_dbl_ptx3.hpp" && cp "/usr/local/cuda-9.0/include/mma.h" "$(@D)/cuda/include/mma.h" && cp "/usr/local/cuda-9.0/include/npp.h" "$(@D)/cuda/include/npp.h" && cp "/usr/local/cuda-9.0/include/nppcore.h" "$(@D)/cuda/include/nppcore.h" && cp "/usr/local/cuda-9.0/include/nppdefs.h" "$(@D)/cuda/include/nppdefs.h" && cp "/usr/local/cuda-9.0/include/nppi.h" "$(@D)/cuda/include/nppi.h" && cp "/usr/local/cuda-9.0/include/nppi_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/nppi_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-9.0/include/nppi_color_conversion.h" "$(@D)/cuda/include/nppi_color_conversion.h" && cp "/usr/local/cuda-9.0/include/nppi_compression_functions.h" "$(@D)/cuda/include/nppi_compression_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_computer_vision.h" "$(@D)/cuda/include/nppi_computer_vision.h" && cp "/usr/local/cuda-9.0/include/nppi_data_exchange_and_initialization.h" "$(@D)/cuda/include/nppi_data_exchange_and_initialization.h" && cp "/usr/local/cuda-9.0/include/nppi_filtering_functions.h" "$(@D)/cuda/include/nppi_filtering_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_geometry_transforms.h" "$(@D)/cuda/include/nppi_geometry_transforms.h" && cp "/usr/local/cuda-9.0/include/nppi_linear_transforms.h" "$(@D)/cuda/include/nppi_linear_transforms.h" && cp "/usr/local/cuda-9.0/include/nppi_morphological_operations.h" "$(@D)/cuda/include/nppi_morphological_operations.h" && cp "/usr/local/cuda-9.0/include/nppi_statistics_functions.h" "$(@D)/cuda/include/nppi_statistics_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_support_functions.h" "$(@D)/cuda/include/nppi_support_functions.h" && cp "/usr/local/cuda-9.0/include/nppi_threshold_and_compare_operations.h" "$(@D)/cuda/include/nppi_threshold_and_compare_operations.h" && cp "/usr/local/cuda-9.0/include/npps.h" "$(@D)/cuda/include/npps.h" && cp "/usr/local/cuda-9.0/include/npps_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/npps_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-9.0/include/npps_conversion_functions.h" "$(@D)/cuda/include/npps_conversion_functions.h" && cp "/usr/local/cuda-9.0/include/npps_filtering_functions.h" "$(@D)/cuda/include/npps_filtering_functions.h" && cp "/usr/local/cuda-9.0/include/npps_initialization.h" "$(@D)/cuda/include/npps_initialization.h" && cp "/usr/local/cuda-9.0/include/npps_statistics_functions.h" "$(@D)/cuda/include/npps_statistics_functions.h" && cp "/usr/local/cuda-9.0/include/npps_support_functions.h" "$(@D)/cuda/include/npps_support_functions.h" && cp "/usr/local/cuda-9.0/include/nppversion.h" "$(@D)/cuda/include/nppversion.h" && cp "/usr/local/cuda-9.0/include/nvToolsExt.h" "$(@D)/cuda/include/nvToolsExt.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtCuda.h" "$(@D)/cuda/include/nvToolsExtCuda.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvToolsExtCudaRt.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtMeta.h" "$(@D)/cuda/include/nvToolsExtMeta.h" && cp "/usr/local/cuda-9.0/include/nvToolsExtSync.h" "$(@D)/cuda/include/nvToolsExtSync.h" && cp "/usr/local/cuda-9.0/include/nvblas.h" "$(@D)/cuda/include/nvblas.h" && cp "/usr/local/cuda-9.0/include/nvfunctional" "$(@D)/cuda/include/nvfunctional" && cp "/usr/local/cuda-9.0/include/nvgraph.h" "$(@D)/cuda/include/nvgraph.h" && cp "/usr/local/cuda-9.0/include/nvml.h" "$(@D)/cuda/include/nvml.h" && cp "/usr/local/cuda-9.0/include/nvrtc.h" "$(@D)/cuda/include/nvrtc.h" && cp "/usr/local/cuda-9.0/include/sm_20_atomic_functions.h" "$(@D)/cuda/include/sm_20_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_20_atomic_functions.hpp" "$(@D)/cuda/include/sm_20_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/sm_20_intrinsics.h" "$(@D)/cuda/include/sm_20_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_20_intrinsics.hpp" "$(@D)/cuda/include/sm_20_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sm_30_intrinsics.h" "$(@D)/cuda/include/sm_30_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_30_intrinsics.hpp" "$(@D)/cuda/include/sm_30_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sm_32_atomic_functions.h" "$(@D)/cuda/include/sm_32_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_32_atomic_functions.hpp" "$(@D)/cuda/include/sm_32_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/sm_32_intrinsics.h" "$(@D)/cuda/include/sm_32_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_32_intrinsics.hpp" "$(@D)/cuda/include/sm_32_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sm_35_atomic_functions.h" "$(@D)/cuda/include/sm_35_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_35_intrinsics.h" "$(@D)/cuda/include/sm_35_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_60_atomic_functions.h" "$(@D)/cuda/include/sm_60_atomic_functions.h" && cp "/usr/local/cuda-9.0/include/sm_60_atomic_functions.hpp" "$(@D)/cuda/include/sm_60_atomic_functions.hpp" && cp "/usr/local/cuda-9.0/include/sm_61_intrinsics.h" "$(@D)/cuda/include/sm_61_intrinsics.h" && cp "/usr/local/cuda-9.0/include/sm_61_intrinsics.hpp" "$(@D)/cuda/include/sm_61_intrinsics.hpp" && cp "/usr/local/cuda-9.0/include/sobol_direction_vectors.h" "$(@D)/cuda/include/sobol_direction_vectors.h" && cp "/usr/local/cuda-9.0/include/surface_functions.h" "$(@D)/cuda/include/surface_functions.h" && cp "/usr/local/cuda-9.0/include/surface_functions.hpp" "$(@D)/cuda/include/surface_functions.hpp" && cp "/usr/local/cuda-9.0/include/surface_indirect_functions.h" "$(@D)/cuda/include/surface_indirect_functions.h" && cp "/usr/local/cuda-9.0/include/surface_indirect_functions.hpp" "$(@D)/cuda/include/surface_indirect_functions.hpp" && cp "/usr/local/cuda-9.0/include/surface_types.h" "$(@D)/cuda/include/surface_types.h" && cp "/usr/local/cuda-9.0/include/texture_fetch_functions.h" "$(@D)/cuda/include/texture_fetch_functions.h" && cp "/usr/local/cuda-9.0/include/texture_fetch_functions.hpp" "$(@D)/cuda/include/texture_fetch_functions.hpp" && cp "/usr/local/cuda-9.0/include/texture_indirect_functions.h" "$(@D)/cuda/include/texture_indirect_functions.h" && cp "/usr/local/cuda-9.0/include/texture_indirect_functions.hpp" "$(@D)/cuda/include/texture_indirect_functions.hpp" && cp "/usr/local/cuda-9.0/include/texture_types.h" "$(@D)/cuda/include/texture_types.h" && cp "/usr/local/cuda-9.0/include/thrust/adjacent_difference.h" "$(@D)/cuda/include/thrust/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/advance.h" "$(@D)/cuda/include/thrust/advance.h" && cp "/usr/local/cuda-9.0/include/thrust/binary_search.h" "$(@D)/cuda/include/thrust/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/complex.h" "$(@D)/cuda/include/thrust/complex.h" && cp "/usr/local/cuda-9.0/include/thrust/copy.h" "$(@D)/cuda/include/thrust/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/count.h" "$(@D)/cuda/include/thrust/count.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/detail/adjacent_difference.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/advance.inl" "$(@D)/cuda/include/thrust/detail/advance.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/allocator_traits.h" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/allocator_traits.inl" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/copy_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/copy_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/default_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/default_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/destroy_range.h" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/destroy_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/fill_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/fill_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/malloc_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/malloc_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/no_throw_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/no_throw_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/tagged_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/tagged_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/temporary_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/allocator/temporary_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/binary_search.inl" "$(@D)/cuda/include/thrust/detail/binary_search.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/arithmetic.h" "$(@D)/cuda/include/thrust/detail/complex/arithmetic.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/c99math.h" "$(@D)/cuda/include/thrust/detail/complex/c99math.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/catrig.h" "$(@D)/cuda/include/thrust/detail/complex/catrig.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/catrigf.h" "$(@D)/cuda/include/thrust/detail/complex/catrigf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ccosh.h" "$(@D)/cuda/include/thrust/detail/complex/ccosh.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ccoshf.h" "$(@D)/cuda/include/thrust/detail/complex/ccoshf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cexp.h" "$(@D)/cuda/include/thrust/detail/complex/cexp.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cexpf.h" "$(@D)/cuda/include/thrust/detail/complex/cexpf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/clog.h" "$(@D)/cuda/include/thrust/detail/complex/clog.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/clogf.h" "$(@D)/cuda/include/thrust/detail/complex/clogf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/complex.inl" "$(@D)/cuda/include/thrust/detail/complex/complex.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cpow.h" "$(@D)/cuda/include/thrust/detail/complex/cpow.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cpowf.h" "$(@D)/cuda/include/thrust/detail/complex/cpowf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/cproj.h" "$(@D)/cuda/include/thrust/detail/complex/cproj.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csinh.h" "$(@D)/cuda/include/thrust/detail/complex/csinh.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csinhf.h" "$(@D)/cuda/include/thrust/detail/complex/csinhf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csqrt.h" "$(@D)/cuda/include/thrust/detail/complex/csqrt.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/csqrtf.h" "$(@D)/cuda/include/thrust/detail/complex/csqrtf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ctanh.h" "$(@D)/cuda/include/thrust/detail/complex/ctanh.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/ctanhf.h" "$(@D)/cuda/include/thrust/detail/complex/ctanhf.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/math_private.h" "$(@D)/cuda/include/thrust/detail/complex/math_private.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/complex/stream.h" "$(@D)/cuda/include/thrust/detail/complex/stream.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config.h" "$(@D)/cuda/include/thrust/detail/config.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/compiler.h" "$(@D)/cuda/include/thrust/detail/config/compiler.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/compiler_fence.h" "$(@D)/cuda/include/thrust/detail/config/compiler_fence.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/config.h" "$(@D)/cuda/include/thrust/detail/config/config.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/debug.h" "$(@D)/cuda/include/thrust/detail/config/debug.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/device_system.h" "$(@D)/cuda/include/thrust/detail/config/device_system.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/exec_check_disable.h" "$(@D)/cuda/include/thrust/detail/config/exec_check_disable.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/forceinline.h" "$(@D)/cuda/include/thrust/detail/config/forceinline.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/global_workarounds.h" "$(@D)/cuda/include/thrust/detail/config/global_workarounds.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/host_device.h" "$(@D)/cuda/include/thrust/detail/config/host_device.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/host_system.h" "$(@D)/cuda/include/thrust/detail/config/host_system.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/config/simple_defines.h" "$(@D)/cuda/include/thrust/detail/config/simple_defines.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/contiguous_storage.h" "$(@D)/cuda/include/thrust/detail/contiguous_storage.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/contiguous_storage.inl" "$(@D)/cuda/include/thrust/detail/contiguous_storage.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy.h" "$(@D)/cuda/include/thrust/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy.inl" "$(@D)/cuda/include/thrust/detail/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy_if.h" "$(@D)/cuda/include/thrust/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/copy_if.inl" "$(@D)/cuda/include/thrust/detail/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/count.inl" "$(@D)/cuda/include/thrust/detail/count.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/cstdint.h" "$(@D)/cuda/include/thrust/detail/cstdint.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_delete.inl" "$(@D)/cuda/include/thrust/detail/device_delete.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_free.inl" "$(@D)/cuda/include/thrust/detail/device_free.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_malloc.inl" "$(@D)/cuda/include/thrust/detail/device_malloc.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_new.inl" "$(@D)/cuda/include/thrust/detail/device_new.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_ptr.inl" "$(@D)/cuda/include/thrust/detail/device_ptr.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_reference.inl" "$(@D)/cuda/include/thrust/detail/device_reference.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/device_vector.inl" "$(@D)/cuda/include/thrust/detail/device_vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/dispatch/is_trivial_copy.h" "$(@D)/cuda/include/thrust/detail/dispatch/is_trivial_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/distance.inl" "$(@D)/cuda/include/thrust/detail/distance.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/equal.inl" "$(@D)/cuda/include/thrust/detail/equal.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/execute_with_allocator.h" "$(@D)/cuda/include/thrust/detail/execute_with_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/execution_policy.h" "$(@D)/cuda/include/thrust/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/extrema.inl" "$(@D)/cuda/include/thrust/detail/extrema.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/fill.inl" "$(@D)/cuda/include/thrust/detail/fill.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/find.inl" "$(@D)/cuda/include/thrust/detail/find.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/for_each.inl" "$(@D)/cuda/include/thrust/detail/for_each.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/function.h" "$(@D)/cuda/include/thrust/detail/function.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional.inl" "$(@D)/cuda/include/thrust/detail/functional.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/actor.h" "$(@D)/cuda/include/thrust/detail/functional/actor.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/actor.inl" "$(@D)/cuda/include/thrust/detail/functional/actor.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/argument.h" "$(@D)/cuda/include/thrust/detail/functional/argument.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/composite.h" "$(@D)/cuda/include/thrust/detail/functional/composite.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/arithmetic_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/arithmetic_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/assignment_operator.h" "$(@D)/cuda/include/thrust/detail/functional/operators/assignment_operator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/bitwise_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/bitwise_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/compound_assignment_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/logical_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/logical_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/operator_adaptors.h" "$(@D)/cuda/include/thrust/detail/functional/operators/operator_adaptors.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/operators/relational_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/relational_operators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/placeholder.h" "$(@D)/cuda/include/thrust/detail/functional/placeholder.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/functional/value.h" "$(@D)/cuda/include/thrust/detail/functional/value.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/gather.inl" "$(@D)/cuda/include/thrust/detail/gather.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/generate.inl" "$(@D)/cuda/include/thrust/detail/generate.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/get_iterator_value.h" "$(@D)/cuda/include/thrust/detail/get_iterator_value.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/host_vector.inl" "$(@D)/cuda/include/thrust/detail/host_vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/inner_product.inl" "$(@D)/cuda/include/thrust/detail/inner_product.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/integer_math.h" "$(@D)/cuda/include/thrust/detail/integer_math.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/integer_traits.h" "$(@D)/cuda/include/thrust/detail/integer_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/internal_functional.h" "$(@D)/cuda/include/thrust/detail/internal_functional.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/logical.inl" "$(@D)/cuda/include/thrust/detail/logical.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/merge.inl" "$(@D)/cuda/include/thrust/detail/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/minmax.h" "$(@D)/cuda/include/thrust/detail/minmax.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/mismatch.inl" "$(@D)/cuda/include/thrust/detail/mismatch.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/mpl/math.h" "$(@D)/cuda/include/thrust/detail/mpl/math.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/numeric_traits.h" "$(@D)/cuda/include/thrust/detail/numeric_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/overlapped_copy.h" "$(@D)/cuda/include/thrust/detail/overlapped_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/pair.inl" "$(@D)/cuda/include/thrust/detail/pair.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/partition.inl" "$(@D)/cuda/include/thrust/detail/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/pointer.h" "$(@D)/cuda/include/thrust/detail/pointer.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/pointer.inl" "$(@D)/cuda/include/thrust/detail/pointer.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/range/head_flags.h" "$(@D)/cuda/include/thrust/detail/range/head_flags.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/range/tail_flags.h" "$(@D)/cuda/include/thrust/detail/range/tail_flags.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/raw_pointer_cast.h" "$(@D)/cuda/include/thrust/detail/raw_pointer_cast.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/raw_reference_cast.h" "$(@D)/cuda/include/thrust/detail/raw_reference_cast.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/reduce.inl" "$(@D)/cuda/include/thrust/detail/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/reference.h" "$(@D)/cuda/include/thrust/detail/reference.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/reference.inl" "$(@D)/cuda/include/thrust/detail/reference.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/reference_forward_declaration.h" "$(@D)/cuda/include/thrust/detail/reference_forward_declaration.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/remove.inl" "$(@D)/cuda/include/thrust/detail/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/replace.inl" "$(@D)/cuda/include/thrust/detail/replace.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/reverse.inl" "$(@D)/cuda/include/thrust/detail/reverse.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/scan.inl" "$(@D)/cuda/include/thrust/detail/scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/scatter.inl" "$(@D)/cuda/include/thrust/detail/scatter.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/seq.h" "$(@D)/cuda/include/thrust/detail/seq.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/sequence.inl" "$(@D)/cuda/include/thrust/detail/sequence.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/set_operations.inl" "$(@D)/cuda/include/thrust/detail/set_operations.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/sort.inl" "$(@D)/cuda/include/thrust/detail/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/static_assert.h" "$(@D)/cuda/include/thrust/detail/static_assert.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/static_map.h" "$(@D)/cuda/include/thrust/detail/static_map.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/swap.h" "$(@D)/cuda/include/thrust/detail/swap.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/swap.inl" "$(@D)/cuda/include/thrust/detail/swap.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/swap_ranges.inl" "$(@D)/cuda/include/thrust/detail/swap_ranges.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/tabulate.inl" "$(@D)/cuda/include/thrust/detail/tabulate.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/temporary_array.h" "$(@D)/cuda/include/thrust/detail/temporary_array.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/temporary_array.inl" "$(@D)/cuda/include/thrust/detail/temporary_array.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/transform.inl" "$(@D)/cuda/include/thrust/detail/transform.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/transform_reduce.inl" "$(@D)/cuda/include/thrust/detail/transform_reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/transform_scan.inl" "$(@D)/cuda/include/thrust/detail/transform_scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/trivial_sequence.h" "$(@D)/cuda/include/thrust/detail/trivial_sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/tuple.inl" "$(@D)/cuda/include/thrust/detail/tuple.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/tuple_meta_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_meta_transform.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/tuple_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_transform.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" "$(@D)/cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/function_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/function_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/has_member_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_member_function.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/has_nested_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_nested_type.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/has_trivial_assign.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_trivial_assign.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/is_call_possible.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_call_possible.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/is_metafunction_defined.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_metafunction_defined.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/iterator/is_output_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/minimum_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/minimum_type.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/pointer_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/pointer_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/type_traits/result_of_adaptable_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_fill.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/unique.inl" "$(@D)/cuda/include/thrust/detail/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/detail/use_default.h" "$(@D)/cuda/include/thrust/detail/use_default.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/util/align.h" "$(@D)/cuda/include/thrust/detail/util/align.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/util/blocking.h" "$(@D)/cuda/include/thrust/detail/util/blocking.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/vector_base.h" "$(@D)/cuda/include/thrust/detail/vector_base.h" && cp "/usr/local/cuda-9.0/include/thrust/detail/vector_base.inl" "$(@D)/cuda/include/thrust/detail/vector_base.inl" && cp "/usr/local/cuda-9.0/include/thrust/device_allocator.h" "$(@D)/cuda/include/thrust/device_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/device_delete.h" "$(@D)/cuda/include/thrust/device_delete.h" && cp "/usr/local/cuda-9.0/include/thrust/device_free.h" "$(@D)/cuda/include/thrust/device_free.h" && cp "/usr/local/cuda-9.0/include/thrust/device_malloc.h" "$(@D)/cuda/include/thrust/device_malloc.h" && cp "/usr/local/cuda-9.0/include/thrust/device_malloc_allocator.h" "$(@D)/cuda/include/thrust/device_malloc_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/device_new.h" "$(@D)/cuda/include/thrust/device_new.h" && cp "/usr/local/cuda-9.0/include/thrust/device_new_allocator.h" "$(@D)/cuda/include/thrust/device_new_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/device_ptr.h" "$(@D)/cuda/include/thrust/device_ptr.h" && cp "/usr/local/cuda-9.0/include/thrust/device_reference.h" "$(@D)/cuda/include/thrust/device_reference.h" && cp "/usr/local/cuda-9.0/include/thrust/device_vector.h" "$(@D)/cuda/include/thrust/device_vector.h" && cp "/usr/local/cuda-9.0/include/thrust/distance.h" "$(@D)/cuda/include/thrust/distance.h" && cp "/usr/local/cuda-9.0/include/thrust/equal.h" "$(@D)/cuda/include/thrust/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/execution_policy.h" "$(@D)/cuda/include/thrust/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/extrema.h" "$(@D)/cuda/include/thrust/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/fill.h" "$(@D)/cuda/include/thrust/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/find.h" "$(@D)/cuda/include/thrust/find.h" && cp "/usr/local/cuda-9.0/include/thrust/for_each.h" "$(@D)/cuda/include/thrust/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/functional.h" "$(@D)/cuda/include/thrust/functional.h" && cp "/usr/local/cuda-9.0/include/thrust/gather.h" "$(@D)/cuda/include/thrust/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/generate.h" "$(@D)/cuda/include/thrust/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/host_vector.h" "$(@D)/cuda/include/thrust/host_vector.h" && cp "/usr/local/cuda-9.0/include/thrust/inner_product.h" "$(@D)/cuda/include/thrust/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/constant_iterator.h" "$(@D)/cuda/include/thrust/iterator/constant_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/counting_iterator.h" "$(@D)/cuda/include/thrust/iterator/counting_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/any_assign.h" "$(@D)/cuda/include/thrust/iterator/detail/any_assign.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/any_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/any_system_tag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/constant_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/constant_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/counting_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/counting_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/device_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/device_system_tag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/discard_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/discard_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/distance_from_result.h" "$(@D)/cuda/include/thrust/iterator/detail/distance_from_result.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/host_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/host_system_tag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/is_iterator_category.h" "$(@D)/cuda/include/thrust/iterator/detail/is_iterator_category.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/is_trivial_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/is_trivial_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_adaptor_base.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_adaptor_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_category_to_system.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_system.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_category_to_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_facade_category.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_facade_category.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_traits.inl" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traits.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/iterator_traversal_tags.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traversal_tags.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/join_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/join_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/minimum_category.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_category.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/minimum_system.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_system.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/normal_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/normal_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/permutation_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/permutation_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/retag.h" "$(@D)/cuda/include/thrust/iterator/detail/retag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/reverse_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/reverse_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/tagged_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/tagged_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/transform_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/transform_output_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_output_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/tuple_of_iterator_references.h" "$(@D)/cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/universal_categories.h" "$(@D)/cuda/include/thrust/iterator/detail/universal_categories.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/zip_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator.inl" && cp "/usr/local/cuda-9.0/include/thrust/iterator/detail/zip_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator_base.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/discard_iterator.h" "$(@D)/cuda/include/thrust/iterator/discard_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_adaptor.h" "$(@D)/cuda/include/thrust/iterator/iterator_adaptor.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_categories.h" "$(@D)/cuda/include/thrust/iterator/iterator_categories.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_facade.h" "$(@D)/cuda/include/thrust/iterator/iterator_facade.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/iterator_traits.h" "$(@D)/cuda/include/thrust/iterator/iterator_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/permutation_iterator.h" "$(@D)/cuda/include/thrust/iterator/permutation_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/retag.h" "$(@D)/cuda/include/thrust/iterator/retag.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/reverse_iterator.h" "$(@D)/cuda/include/thrust/iterator/reverse_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/transform_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/transform_output_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_output_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/iterator/zip_iterator.h" "$(@D)/cuda/include/thrust/iterator/zip_iterator.h" && cp "/usr/local/cuda-9.0/include/thrust/logical.h" "$(@D)/cuda/include/thrust/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/memory.h" "$(@D)/cuda/include/thrust/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/merge.h" "$(@D)/cuda/include/thrust/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/mismatch.h" "$(@D)/cuda/include/thrust/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/pair.h" "$(@D)/cuda/include/thrust/pair.h" && cp "/usr/local/cuda-9.0/include/thrust/partition.h" "$(@D)/cuda/include/thrust/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/random.h" "$(@D)/cuda/include/thrust/random.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/discard_block_engine.inl" "$(@D)/cuda/include/thrust/random/detail/discard_block_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_congruential_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_congruential_engine_discard.h" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine_discard.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_feedback_shift_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/mod.h" "$(@D)/cuda/include/thrust/random/detail/mod.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/normal_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/normal_distribution.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/normal_distribution_base.h" "$(@D)/cuda/include/thrust/random/detail/normal_distribution_base.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/random_core_access.h" "$(@D)/cuda/include/thrust/random/detail/random_core_access.h" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/subtract_with_carry_engine.inl" "$(@D)/cuda/include/thrust/random/detail/subtract_with_carry_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/uniform_int_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_int_distribution.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/uniform_real_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_real_distribution.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/xor_combine_engine.inl" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine.inl" && cp "/usr/local/cuda-9.0/include/thrust/random/detail/xor_combine_engine_max.h" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine_max.h" && cp "/usr/local/cuda-9.0/include/thrust/random/discard_block_engine.h" "$(@D)/cuda/include/thrust/random/discard_block_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/linear_congruential_engine.h" "$(@D)/cuda/include/thrust/random/linear_congruential_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/linear_feedback_shift_engine.h" "$(@D)/cuda/include/thrust/random/linear_feedback_shift_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/normal_distribution.h" "$(@D)/cuda/include/thrust/random/normal_distribution.h" && cp "/usr/local/cuda-9.0/include/thrust/random/subtract_with_carry_engine.h" "$(@D)/cuda/include/thrust/random/subtract_with_carry_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/random/uniform_int_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_int_distribution.h" && cp "/usr/local/cuda-9.0/include/thrust/random/uniform_real_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_real_distribution.h" && cp "/usr/local/cuda-9.0/include/thrust/random/xor_combine_engine.h" "$(@D)/cuda/include/thrust/random/xor_combine_engine.h" && cp "/usr/local/cuda-9.0/include/thrust/reduce.h" "$(@D)/cuda/include/thrust/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/remove.h" "$(@D)/cuda/include/thrust/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/replace.h" "$(@D)/cuda/include/thrust/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/reverse.h" "$(@D)/cuda/include/thrust/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/scan.h" "$(@D)/cuda/include/thrust/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/scatter.h" "$(@D)/cuda/include/thrust/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/sequence.h" "$(@D)/cuda/include/thrust/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/set_operations.h" "$(@D)/cuda/include/thrust/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/sort.h" "$(@D)/cuda/include/thrust/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/swap.h" "$(@D)/cuda/include/thrust/swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cpp/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cpp/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/count.h" "$(@D)/cuda/include/thrust/system/cpp/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/equal.h" "$(@D)/cuda/include/thrust/system/cpp/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cpp/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/find.h" "$(@D)/cuda/include/thrust/system/cpp/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cpp/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/gather.h" "$(@D)/cuda/include/thrust/system/cpp/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/generate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cpp/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cpp/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/logical.h" "$(@D)/cuda/include/thrust/system/cpp/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cpp/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/merge.h" "$(@D)/cuda/include/thrust/system/cpp/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cpp/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/par.h" "$(@D)/cuda/include/thrust/system/cpp/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/partition.h" "$(@D)/cuda/include/thrust/system/cpp/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/remove.h" "$(@D)/cuda/include/thrust/system/cpp/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/replace.h" "$(@D)/cuda/include/thrust/system/cpp/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cpp/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/sort.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cpp/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cpp/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/transform.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/unique.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/memory.h" "$(@D)/cuda/include/thrust/system/cpp/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cpp/vector.h" "$(@D)/cuda/include/thrust/system/cpp/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/config.h" "$(@D)/cuda/include/thrust/system/cuda/config.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cuda/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/agent_launcher.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/agent_launcher.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/alignment.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/alignment.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/triple_chevron_launch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/triple_chevron_launch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/core/util.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/util.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/count.h" "$(@D)/cuda/include/thrust/system/cuda/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cross_system.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_csrt.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_csrt.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_row_based.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_row_based.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/block_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/cub.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/cub.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_csrt.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_csrt.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_row_based.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_row_based.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/host/mutex.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/host/mutex.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_allocator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_arch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_debug.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_device.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_device.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_macro.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_namespace.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_ptx.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/util_type.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_type.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/equal.h" "$(@D)/cuda/include/thrust/system/cuda/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/error.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/error.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/find.h" "$(@D)/cuda/include/thrust/system/cuda/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/gather.h" "$(@D)/cuda/include/thrust/system/cuda/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/generate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/guarded_driver_types.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_driver_types.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cuda/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/internal/copy_cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/internal/copy_cross_system.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/internal/copy_device_to_device.h" "$(@D)/cuda/include/thrust/system/cuda/detail/internal/copy_device_to_device.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cuda/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/logical.h" "$(@D)/cuda/include/thrust/system/cuda/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cuda/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/memory_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/memory_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/par.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/par_to_seq.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par_to_seq.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/parallel_for.h" "$(@D)/cuda/include/thrust/system/cuda/detail/parallel_for.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/partition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/remove.h" "$(@D)/cuda/include/thrust/system/cuda/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/replace.h" "$(@D)/cuda/include/thrust/system/cuda/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cuda/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cuda/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/terminate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/terminate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/transform.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/unique.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/util.h" "$(@D)/cuda/include/thrust/system/cuda/detail/util.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/error.h" "$(@D)/cuda/include/thrust/system/cuda/error.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/experimental/pinned_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/experimental/pinned_allocator.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/memory.h" "$(@D)/cuda/include/thrust/system/cuda/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/cuda/vector.h" "$(@D)/cuda/include/thrust/system/cuda/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/adl/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/adl/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/count.h" "$(@D)/cuda/include/thrust/system/detail/adl/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/equal.h" "$(@D)/cuda/include/thrust/system/detail/adl/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/extrema.h" "$(@D)/cuda/include/thrust/system/detail/adl/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/find.h" "$(@D)/cuda/include/thrust/system/detail/adl/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/for_each.h" "$(@D)/cuda/include/thrust/system/detail/adl/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/gather.h" "$(@D)/cuda/include/thrust/system/detail/adl/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/generate.h" "$(@D)/cuda/include/thrust/system/detail/adl/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/get_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/adl/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/adl/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/logical.h" "$(@D)/cuda/include/thrust/system/detail/adl/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/adl/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/merge.h" "$(@D)/cuda/include/thrust/system/detail/adl/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/adl/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/partition.h" "$(@D)/cuda/include/thrust/system/detail/adl/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/remove.h" "$(@D)/cuda/include/thrust/system/detail/adl/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/replace.h" "$(@D)/cuda/include/thrust/system/detail/adl/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/reverse.h" "$(@D)/cuda/include/thrust/system/detail/adl/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/scatter.h" "$(@D)/cuda/include/thrust/system/detail/adl/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/sequence.h" "$(@D)/cuda/include/thrust/system/detail/adl/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/adl/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/sort.h" "$(@D)/cuda/include/thrust/system/detail/adl/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/adl/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/adl/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/adl/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/transform.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/unique.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/adl/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/bad_alloc.h" "$(@D)/cuda/include/thrust/system/detail/bad_alloc.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/errno.h" "$(@D)/cuda/include/thrust/system/detail/errno.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/error_category.inl" "$(@D)/cuda/include/thrust/system/detail/error_category.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/error_code.inl" "$(@D)/cuda/include/thrust/system/detail/error_code.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/error_condition.inl" "$(@D)/cuda/include/thrust/system/detail/error_condition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/advance.h" "$(@D)/cuda/include/thrust/system/detail/generic/advance.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/advance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/advance.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/copy_if.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/count.h" "$(@D)/cuda/include/thrust/system/detail/generic/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/count.inl" "$(@D)/cuda/include/thrust/system/detail/generic/count.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/distance.h" "$(@D)/cuda/include/thrust/system/detail/generic/distance.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/distance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/distance.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/equal.h" "$(@D)/cuda/include/thrust/system/detail/generic/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/equal.inl" "$(@D)/cuda/include/thrust/system/detail/generic/equal.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/extrema.h" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/extrema.inl" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/find.h" "$(@D)/cuda/include/thrust/system/detail/generic/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/find.inl" "$(@D)/cuda/include/thrust/system/detail/generic/find.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/for_each.h" "$(@D)/cuda/include/thrust/system/detail/generic/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/gather.h" "$(@D)/cuda/include/thrust/system/detail/generic/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/gather.inl" "$(@D)/cuda/include/thrust/system/detail/generic/gather.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/generate.h" "$(@D)/cuda/include/thrust/system/detail/generic/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/generate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/generate.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/inner_product.inl" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/logical.h" "$(@D)/cuda/include/thrust/system/detail/generic/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/memory.h" "$(@D)/cuda/include/thrust/system/detail/generic/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/memory.inl" "$(@D)/cuda/include/thrust/system/detail/generic/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/merge.h" "$(@D)/cuda/include/thrust/system/detail/generic/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/merge.inl" "$(@D)/cuda/include/thrust/system/detail/generic/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/mismatch.inl" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/partition.h" "$(@D)/cuda/include/thrust/system/detail/generic/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/partition.inl" "$(@D)/cuda/include/thrust/system/detail/generic/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/remove.h" "$(@D)/cuda/include/thrust/system/detail/generic/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/remove.inl" "$(@D)/cuda/include/thrust/system/detail/generic/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/replace.h" "$(@D)/cuda/include/thrust/system/detail/generic/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/replace.inl" "$(@D)/cuda/include/thrust/system/detail/generic/replace.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reverse.h" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/reverse.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scalar/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scalar/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scan_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scatter.h" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/scatter.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/select_system.h" "$(@D)/cuda/include/thrust/system/detail/generic/select_system.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sequence.h" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sequence.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/set_operations.inl" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sort.h" "$(@D)/cuda/include/thrust/system/detail/generic/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/sort.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/swap_ranges.inl" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/tabulate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/tag.h" "$(@D)/cuda/include/thrust/system/detail/generic/tag.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/temporary_buffer.inl" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/transform_scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/type_traits.h" "$(@D)/cuda/include/thrust/system/detail/generic/type_traits.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/generic/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/internal/decompose.h" "$(@D)/cuda/include/thrust/system/detail/internal/decompose.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/sequential/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/sequential/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy_backward.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_backward.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/count.h" "$(@D)/cuda/include/thrust/system/detail/sequential/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/equal.h" "$(@D)/cuda/include/thrust/system/detail/sequential/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/execution_policy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/extrema.h" "$(@D)/cuda/include/thrust/system/detail/sequential/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/find.h" "$(@D)/cuda/include/thrust/system/detail/sequential/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/for_each.h" "$(@D)/cuda/include/thrust/system/detail/sequential/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/gather.h" "$(@D)/cuda/include/thrust/system/detail/sequential/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/general_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/general_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/generate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/get_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/sequential/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/insertion_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/insertion_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/sequential/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/logical.h" "$(@D)/cuda/include/thrust/system/detail/sequential/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/sequential/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/merge.h" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/merge.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/sequential/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/partition.h" "$(@D)/cuda/include/thrust/system/detail/sequential/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/remove.h" "$(@D)/cuda/include/thrust/system/detail/sequential/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/replace.h" "$(@D)/cuda/include/thrust/system/detail/sequential/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/reverse.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/scatter.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/sequence.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/sequential/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/sequential/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/sequential/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/transform.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/trivial_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/trivial_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/unique.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/sequential/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/detail/system_error.inl" "$(@D)/cuda/include/thrust/system/detail/system_error.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/error_code.h" "$(@D)/cuda/include/thrust/system/error_code.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/omp/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/omp/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/count.h" "$(@D)/cuda/include/thrust/system/omp/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/equal.h" "$(@D)/cuda/include/thrust/system/omp/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/omp/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/find.h" "$(@D)/cuda/include/thrust/system/omp/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/gather.h" "$(@D)/cuda/include/thrust/system/omp/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/generate.h" "$(@D)/cuda/include/thrust/system/omp/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/omp/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/omp/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/logical.h" "$(@D)/cuda/include/thrust/system/omp/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/omp/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/omp/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/merge.h" "$(@D)/cuda/include/thrust/system/omp/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/omp/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/par.h" "$(@D)/cuda/include/thrust/system/omp/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/partition.h" "$(@D)/cuda/include/thrust/system/omp/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/partition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/remove.h" "$(@D)/cuda/include/thrust/system/omp/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/remove.inl" "$(@D)/cuda/include/thrust/system/omp/detail/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/replace.h" "$(@D)/cuda/include/thrust/system/omp/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/omp/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/omp/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/omp/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/omp/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/sort.h" "$(@D)/cuda/include/thrust/system/omp/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/sort.inl" "$(@D)/cuda/include/thrust/system/omp/detail/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/omp/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/omp/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/omp/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/transform.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/omp/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/memory.h" "$(@D)/cuda/include/thrust/system/omp/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/omp/vector.h" "$(@D)/cuda/include/thrust/system/omp/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system/system_error.h" "$(@D)/cuda/include/thrust/system/system_error.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/tbb/detail/adjacent_difference.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/assign_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/tbb/detail/binary_search.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/count.h" "$(@D)/cuda/include/thrust/system/tbb/detail/count.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/equal.h" "$(@D)/cuda/include/thrust/system/tbb/detail/equal.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/extrema.h" "$(@D)/cuda/include/thrust/system/tbb/detail/extrema.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/find.h" "$(@D)/cuda/include/thrust/system/tbb/detail/find.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/for_each.h" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/gather.h" "$(@D)/cuda/include/thrust/system/tbb/detail/gather.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/generate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/generate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/get_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/get_value.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/tbb/detail/inner_product.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/tbb/detail/iter_swap.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/logical.h" "$(@D)/cuda/include/thrust/system/tbb/detail/logical.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/tbb/detail/malloc_and_free.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/memory.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/memory.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/merge.h" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/merge.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/tbb/detail/mismatch.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/par.h" "$(@D)/cuda/include/thrust/system/tbb/detail/par.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/partition.h" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/partition.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_intervals.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/remove.h" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/remove.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/replace.h" "$(@D)/cuda/include/thrust/system/tbb/detail/replace.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/reverse.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reverse.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scan.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/scatter.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scatter.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/sequence.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sequence.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/tbb/detail/set_operations.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/sort.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/sort.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/tbb/detail/swap_ranges.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/tbb/detail/temporary_buffer.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/transform.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/detail/vector.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/vector.inl" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/execution_policy.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/memory.h" "$(@D)/cuda/include/thrust/system/tbb/memory.h" && cp "/usr/local/cuda-9.0/include/thrust/system/tbb/vector.h" "$(@D)/cuda/include/thrust/system/tbb/vector.h" && cp "/usr/local/cuda-9.0/include/thrust/system_error.h" "$(@D)/cuda/include/thrust/system_error.h" && cp "/usr/local/cuda-9.0/include/thrust/tabulate.h" "$(@D)/cuda/include/thrust/tabulate.h" && cp "/usr/local/cuda-9.0/include/thrust/transform.h" "$(@D)/cuda/include/thrust/transform.h" && cp "/usr/local/cuda-9.0/include/thrust/transform_reduce.h" "$(@D)/cuda/include/thrust/transform_reduce.h" && cp "/usr/local/cuda-9.0/include/thrust/transform_scan.h" "$(@D)/cuda/include/thrust/transform_scan.h" && cp "/usr/local/cuda-9.0/include/thrust/tuple.h" "$(@D)/cuda/include/thrust/tuple.h" && cp "/usr/local/cuda-9.0/include/thrust/uninitialized_copy.h" "$(@D)/cuda/include/thrust/uninitialized_copy.h" && cp "/usr/local/cuda-9.0/include/thrust/uninitialized_fill.h" "$(@D)/cuda/include/thrust/uninitialized_fill.h" && cp "/usr/local/cuda-9.0/include/thrust/unique.h" "$(@D)/cuda/include/thrust/unique.h" && cp "/usr/local/cuda-9.0/include/thrust/version.h" "$(@D)/cuda/include/thrust/version.h" && cp "/usr/local/cuda-9.0/include/vector_functions.h" "$(@D)/cuda/include/vector_functions.h" && cp "/usr/local/cuda-9.0/include/vector_functions.hpp" "$(@D)/cuda/include/vector_functions.hpp" && cp "/usr/local/cuda-9.0/include/vector_types.h" "$(@D)/cuda/include/vector_types.h" """, ) @@ -1264,72 +1192,69 @@ genrule( name = "cuda-nvvm", outs = [ "cuda/nvvm/bin/cicc", - "cuda/nvvm/libdevice/libdevice.compute_50.10.bc", - "cuda/nvvm/libdevice/libdevice.compute_30.10.bc", - "cuda/nvvm/libdevice/libdevice.compute_20.10.bc", - "cuda/nvvm/libdevice/libdevice.compute_35.10.bc", - "cuda/nvvm/lib64/libnvvm.so.3", - "cuda/nvvm/lib64/libnvvm.so", - "cuda/nvvm/lib64/libnvvm.so.3.1.0", "cuda/nvvm/include/nvvm.h", - "cuda/nvvm/libnvvm-samples/ptxgen/README.txt", - "cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c", - "cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt", + "cuda/nvvm/lib64/libnvvm.so", + "cuda/nvvm/lib64/libnvvm.so.3", + "cuda/nvvm/lib64/libnvvm.so.3.2.0", + "cuda/nvvm/libdevice/libdevice.10.bc", + "cuda/nvvm/libnvvm-samples/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/README.txt", "cuda/nvvm/libnvvm-samples/build.bat", - "cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt", - "cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu", + "cuda/nvvm/libnvvm-samples/build.sh", + "cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h", + "cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h", "cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt", "cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp", - "cuda/nvvm/libnvvm-samples/README.txt", - "cuda/nvvm/libnvvm-samples/simple/simple.c", - "cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu", + "cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/ptxgen/README.txt", + "cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c", + "cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt", "cuda/nvvm/libnvvm-samples/simple/README.txt", + "cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll", "cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll", - "cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt", - "cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h", - "cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h", - "cuda/nvvm/libnvvm-samples/build.sh", - "cuda/nvvm/libnvvm-samples/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/simple/simple.c", ], cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/nvvm/bin/cicc" "$(@D)/cuda/nvvm/bin/cicc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_50.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_50.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_30.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_30.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_20.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_20.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_35.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_35.10.bc" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so.3" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so" "$(@D)/cuda/nvvm/lib64/libnvvm.so" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so.3.1.0" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3.1.0" && cp "/usr/local/cuda-8.0/nvvm/include/nvvm.h" "$(@D)/cuda/nvvm/include/nvvm.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/ptxgen.c" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/build.bat" "$(@D)/cuda/nvvm/libnvvm-samples/build.bat" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple.c" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple.c" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple-gpu.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple-gpu64.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/common/include/DDSWriter.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/build.sh" "$(@D)/cuda/nvvm/libnvvm-samples/build.sh" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/CMakeLists.txt" +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/nvvm/bin/cicc" "$(@D)/cuda/nvvm/bin/cicc" && cp "/usr/local/cuda-9.0/nvvm/include/nvvm.h" "$(@D)/cuda/nvvm/include/nvvm.h" && cp "/usr/local/cuda-9.0/nvvm/lib64/libnvvm.so" "$(@D)/cuda/nvvm/lib64/libnvvm.so" && cp "/usr/local/cuda-9.0/nvvm/lib64/libnvvm.so.3" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3" && cp "/usr/local/cuda-9.0/nvvm/lib64/libnvvm.so.3.2.0" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3.2.0" && cp "/usr/local/cuda-9.0/nvvm/libdevice/libdevice.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.10.bc" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/build.bat" "$(@D)/cuda/nvvm/libnvvm-samples/build.bat" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/build.sh" "$(@D)/cuda/nvvm/libnvvm-samples/build.sh" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/common/include/DDSWriter.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/ptxgen/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/ptxgen/ptxgen.c" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/simple-gpu.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/simple-gpu64.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/simple.c" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple.c" """, ) genrule( name = "cuda-extras", outs = [ - "cuda/extras/CUPTI/include/cupti_result.h", + "cuda/extras/CUPTI/include/GL/gl.h", + "cuda/extras/CUPTI/include/GL/glew.h", + "cuda/extras/CUPTI/include/GL/glext.h", + "cuda/extras/CUPTI/include/GL/glu.h", + "cuda/extras/CUPTI/include/GL/glut.h", + "cuda/extras/CUPTI/include/GL/glx.h", + "cuda/extras/CUPTI/include/GL/glxext.h", + "cuda/extras/CUPTI/include/GL/wglew.h", + "cuda/extras/CUPTI/include/GL/wglext.h", + "cuda/extras/CUPTI/include/cuda_stdint.h", + "cuda/extras/CUPTI/include/cupti.h", + "cuda/extras/CUPTI/include/cupti_activity.h", + "cuda/extras/CUPTI/include/cupti_callbacks.h", + "cuda/extras/CUPTI/include/cupti_driver_cbid.h", "cuda/extras/CUPTI/include/cupti_events.h", - "cuda/extras/CUPTI/include/openacc/cupti_openacc.h", + "cuda/extras/CUPTI/include/cupti_metrics.h", + "cuda/extras/CUPTI/include/cupti_nvtx_cbid.h", + "cuda/extras/CUPTI/include/cupti_result.h", + "cuda/extras/CUPTI/include/cupti_runtime_cbid.h", "cuda/extras/CUPTI/include/cupti_version.h", - "cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h", + "cuda/extras/CUPTI/include/generated_cudaGL_meta.h", "cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h", - "cuda/extras/CUPTI/include/cupti_activity.h", - "cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h", + "cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h", "cuda/extras/CUPTI/include/generated_cuda_meta.h", - "cuda/extras/CUPTI/include/cupti_nvtx_cbid.h", - "cuda/extras/CUPTI/include/cuda_stdint.h", - "cuda/extras/CUPTI/include/generated_cudaGL_meta.h", + "cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h", "cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h", - "cuda/extras/CUPTI/include/cupti_metrics.h", - "cuda/extras/CUPTI/include/cupti_callbacks.h", - "cuda/extras/CUPTI/include/cupti_runtime_cbid.h", - "cuda/extras/CUPTI/include/cupti.h", - "cuda/extras/CUPTI/include/GL/glut.h", - "cuda/extras/CUPTI/include/GL/glu.h", - "cuda/extras/CUPTI/include/GL/glxext.h", - "cuda/extras/CUPTI/include/GL/wglext.h", - "cuda/extras/CUPTI/include/GL/glx.h", - "cuda/extras/CUPTI/include/GL/glext.h", - "cuda/extras/CUPTI/include/GL/wglew.h", - "cuda/extras/CUPTI/include/GL/gl.h", - "cuda/extras/CUPTI/include/GL/glew.h", - "cuda/extras/CUPTI/include/cupti_driver_cbid.h", "cuda/extras/CUPTI/include/generated_nvtx_meta.h", + "cuda/extras/CUPTI/include/openacc/cupti_openacc.h", ], cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_result.h" "$(@D)/cuda/extras/CUPTI/include/cupti_result.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_events.h" "$(@D)/cuda/extras/CUPTI/include/cupti_events.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/openacc/cupti_openacc.h" "$(@D)/cuda/extras/CUPTI/include/openacc/cupti_openacc.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_version.h" "$(@D)/cuda/extras/CUPTI/include/cupti_version.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cudaVDPAU_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_activity.h" "$(@D)/cuda/extras/CUPTI/include/cupti_activity.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_nvtx_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_nvtx_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cuda_stdint.h" "$(@D)/cuda/extras/CUPTI/include/cuda_stdint.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cudaGL_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaGL_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_metrics.h" "$(@D)/cuda/extras/CUPTI/include/cupti_metrics.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_callbacks.h" "$(@D)/cuda/extras/CUPTI/include/cupti_callbacks.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_runtime_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_runtime_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti.h" "$(@D)/cuda/extras/CUPTI/include/cupti.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glut.h" "$(@D)/cuda/extras/CUPTI/include/GL/glut.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glu.h" "$(@D)/cuda/extras/CUPTI/include/GL/glu.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glxext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glxext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/wglext.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glx.h" "$(@D)/cuda/extras/CUPTI/include/GL/glx.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/wglew.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglew.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/gl.h" "$(@D)/cuda/extras/CUPTI/include/GL/gl.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glew.h" "$(@D)/cuda/extras/CUPTI/include/GL/glew.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_driver_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_driver_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_nvtx_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_nvtx_meta.h" +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/gl.h" "$(@D)/cuda/extras/CUPTI/include/GL/gl.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glew.h" "$(@D)/cuda/extras/CUPTI/include/GL/glew.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glext.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glu.h" "$(@D)/cuda/extras/CUPTI/include/GL/glu.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glut.h" "$(@D)/cuda/extras/CUPTI/include/GL/glut.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glx.h" "$(@D)/cuda/extras/CUPTI/include/GL/glx.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/glxext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glxext.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/wglew.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglew.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/GL/wglext.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglext.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cuda_stdint.h" "$(@D)/cuda/extras/CUPTI/include/cuda_stdint.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti.h" "$(@D)/cuda/extras/CUPTI/include/cupti.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_activity.h" "$(@D)/cuda/extras/CUPTI/include/cupti_activity.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_callbacks.h" "$(@D)/cuda/extras/CUPTI/include/cupti_callbacks.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_driver_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_driver_cbid.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_events.h" "$(@D)/cuda/extras/CUPTI/include/cupti_events.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_metrics.h" "$(@D)/cuda/extras/CUPTI/include/cupti_metrics.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_nvtx_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_nvtx_cbid.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_result.h" "$(@D)/cuda/extras/CUPTI/include/cupti_result.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_runtime_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_runtime_cbid.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/cupti_version.h" "$(@D)/cuda/extras/CUPTI/include/cupti_version.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cudaGL_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaGL_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cudaVDPAU_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/generated_nvtx_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_nvtx_meta.h" && cp "/usr/local/cuda-9.0/extras/CUPTI/include/openacc/cupti_openacc.h" "$(@D)/cuda/extras/CUPTI/include/openacc/cupti_openacc.h" """, ) @@ -1337,26 +1262,21 @@ genrule( name = "cuda-lib", outs = [ "cuda/lib/libcuda.so", - "cuda/lib/libcudart.so.8.0", + "cuda/lib/libcudart.so.9.0", "cuda/lib/libcudart_static.a", - "cuda/lib/libcublas.so.8.0", - "cuda/lib/libcusolver.so.8.0", - "cuda/lib/libcurand.so.8.0", - "cuda/lib/libcufft.so.8.0", - "cuda/lib/libcudnn.so.6", - "cuda/lib/libcupti.so.8.0", + "cuda/lib/libcublas.so.9.0", + "cuda/lib/libcusolver.so.9.0", + "cuda/lib/libcurand.so.9.0", + "cuda/lib/libcufft.so.9.0", + "cuda/lib/libcudnn.so.7", + "cuda/lib/libcupti.so.9.0", ], cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61" "$(@D)/cuda/lib/libcudart.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcublas.so.8.0.88" "$(@D)/cuda/lib/libcublas.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcusolver.so.8.0.61" "$(@D)/cuda/lib/libcusolver.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcurand.so.8.0.61" "$(@D)/cuda/lib/libcurand.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcufft.so.8.0.61" "$(@D)/cuda/lib/libcufft.so.8.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21" "$(@D)/cuda/lib/libcudnn.so.6" && cp "/usr/local/cuda-8.0/extras/CUPTI/lib64/libcupti.so.8.0.61" "$(@D)/cuda/lib/libcupti.so.8.0" +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0" "$(@D)/cuda/lib/libcupti.so.9.0" """, ) -genrule( +filegroup( name = "cudnn-include", - outs = [ - "cuda/include/cudnn.h", - ], - cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/include/cudnn.h" "$(@D)/cudnn.h" - """, + srcs = [], ) diff --git a/third_party/toolchains/gpus/py/BUILD b/third_party/toolchains/gpus/py/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2d5ace93ff5054927cda61b0302af4edd8fe56c1 --- /dev/null +++ b/third_party/toolchains/gpus/py/BUILD @@ -0,0 +1,171 @@ +# A build file to configure python remote repository used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "python_headers", + hdrs = [":python_include"], + data = select({ + ":windows": [":python_import_lib"], + "//conditions:default": [], + }), + includes = ["python_include"], + linkopts = select({ + # TODO(pcloudy): Ideally, this should just go into deps after resolving + # https://github.com/bazelbuild/bazel/issues/3237, + ":windows": ["$(locations :python_import_lib)"], + "//conditions:default": [], + }), +) + +cc_library( + name = "numpy_headers", + hdrs = [":numpy_include"], + includes = ["numpy_include"], +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +genrule( + name = "python_include", + outs = [ + "python_include/Python-ast.h", + "python_include/Python.h", + "python_include/abstract.h", + "python_include/asdl.h", + "python_include/ast.h", + "python_include/bitset.h", + "python_include/boolobject.h", + "python_include/bufferobject.h", + "python_include/bytearrayobject.h", + "python_include/bytes_methods.h", + "python_include/bytesobject.h", + "python_include/cStringIO.h", + "python_include/cellobject.h", + "python_include/ceval.h", + "python_include/classobject.h", + "python_include/cobject.h", + "python_include/code.h", + "python_include/codecs.h", + "python_include/compile.h", + "python_include/complexobject.h", + "python_include/datetime.h", + "python_include/descrobject.h", + "python_include/dictobject.h", + "python_include/dtoa.h", + "python_include/enumobject.h", + "python_include/errcode.h", + "python_include/eval.h", + "python_include/fileobject.h", + "python_include/floatobject.h", + "python_include/frameobject.h", + "python_include/funcobject.h", + "python_include/genobject.h", + "python_include/graminit.h", + "python_include/grammar.h", + "python_include/import.h", + "python_include/intobject.h", + "python_include/intrcheck.h", + "python_include/iterobject.h", + "python_include/listobject.h", + "python_include/longintrepr.h", + "python_include/longobject.h", + "python_include/marshal.h", + "python_include/memoryobject.h", + "python_include/metagrammar.h", + "python_include/methodobject.h", + "python_include/modsupport.h", + "python_include/moduleobject.h", + "python_include/node.h", + "python_include/object.h", + "python_include/objimpl.h", + "python_include/opcode.h", + "python_include/osdefs.h", + "python_include/parsetok.h", + "python_include/patchlevel.h", + "python_include/pgen.h", + "python_include/pgenheaders.h", + "python_include/py_curses.h", + "python_include/pyarena.h", + "python_include/pycapsule.h", + "python_include/pyconfig.h", + "python_include/pyctype.h", + "python_include/pydebug.h", + "python_include/pyerrors.h", + "python_include/pyexpat.h", + "python_include/pyfpe.h", + "python_include/pygetopt.h", + "python_include/pymacconfig.h", + "python_include/pymactoolbox.h", + "python_include/pymath.h", + "python_include/pymem.h", + "python_include/pyport.h", + "python_include/pystate.h", + "python_include/pystrcmp.h", + "python_include/pystrtod.h", + "python_include/pythonrun.h", + "python_include/pythread.h", + "python_include/rangeobject.h", + "python_include/setobject.h", + "python_include/sliceobject.h", + "python_include/stringobject.h", + "python_include/structmember.h", + "python_include/structseq.h", + "python_include/symtable.h", + "python_include/sysmodule.h", + "python_include/timefuncs.h", + "python_include/token.h", + "python_include/traceback.h", + "python_include/tupleobject.h", + "python_include/ucnhash.h", + "python_include/unicodeobject.h", + "python_include/warnings.h", + "python_include/weakrefobject.h", + ], + cmd = """ +cp "/usr/include/python2.7/Python-ast.h" "$(@D)/python_include/Python-ast.h" && cp "/usr/include/python2.7/Python.h" "$(@D)/python_include/Python.h" && cp "/usr/include/python2.7/abstract.h" "$(@D)/python_include/abstract.h" && cp "/usr/include/python2.7/asdl.h" "$(@D)/python_include/asdl.h" && cp "/usr/include/python2.7/ast.h" "$(@D)/python_include/ast.h" && cp "/usr/include/python2.7/bitset.h" "$(@D)/python_include/bitset.h" && cp "/usr/include/python2.7/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/usr/include/python2.7/bufferobject.h" "$(@D)/python_include/bufferobject.h" && cp "/usr/include/python2.7/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/usr/include/python2.7/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp "/usr/include/python2.7/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/usr/include/python2.7/cStringIO.h" "$(@D)/python_include/cStringIO.h" && cp "/usr/include/python2.7/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/usr/include/python2.7/ceval.h" "$(@D)/python_include/ceval.h" && cp "/usr/include/python2.7/classobject.h" "$(@D)/python_include/classobject.h" && cp "/usr/include/python2.7/cobject.h" "$(@D)/python_include/cobject.h" && cp "/usr/include/python2.7/code.h" "$(@D)/python_include/code.h" && cp "/usr/include/python2.7/codecs.h" "$(@D)/python_include/codecs.h" && cp "/usr/include/python2.7/compile.h" "$(@D)/python_include/compile.h" && cp "/usr/include/python2.7/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/usr/include/python2.7/datetime.h" "$(@D)/python_include/datetime.h" && cp "/usr/include/python2.7/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/usr/include/python2.7/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/usr/include/python2.7/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/usr/include/python2.7/enumobject.h" "$(@D)/python_include/enumobject.h" && cp "/usr/include/python2.7/errcode.h" "$(@D)/python_include/errcode.h" && cp "/usr/include/python2.7/eval.h" "$(@D)/python_include/eval.h" && cp "/usr/include/python2.7/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/usr/include/python2.7/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/usr/include/python2.7/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/usr/include/python2.7/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/usr/include/python2.7/genobject.h" "$(@D)/python_include/genobject.h" && cp "/usr/include/python2.7/graminit.h" "$(@D)/python_include/graminit.h" && cp "/usr/include/python2.7/grammar.h" "$(@D)/python_include/grammar.h" && cp "/usr/include/python2.7/import.h" "$(@D)/python_include/import.h" && cp "/usr/include/python2.7/intobject.h" "$(@D)/python_include/intobject.h" && cp "/usr/include/python2.7/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/usr/include/python2.7/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/usr/include/python2.7/listobject.h" "$(@D)/python_include/listobject.h" && cp "/usr/include/python2.7/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/usr/include/python2.7/longobject.h" "$(@D)/python_include/longobject.h" && cp "/usr/include/python2.7/marshal.h" "$(@D)/python_include/marshal.h" && cp "/usr/include/python2.7/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/usr/include/python2.7/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/usr/include/python2.7/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/usr/include/python2.7/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/usr/include/python2.7/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/usr/include/python2.7/node.h" "$(@D)/python_include/node.h" && cp "/usr/include/python2.7/object.h" "$(@D)/python_include/object.h" && cp "/usr/include/python2.7/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/usr/include/python2.7/opcode.h" "$(@D)/python_include/opcode.h" && cp "/usr/include/python2.7/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/usr/include/python2.7/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/usr/include/python2.7/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/usr/include/python2.7/pgen.h" "$(@D)/python_include/pgen.h" && cp "/usr/include/python2.7/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/usr/include/python2.7/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/usr/include/python2.7/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/usr/include/python2.7/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/usr/include/python2.7/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/usr/include/python2.7/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/usr/include/python2.7/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/usr/include/python2.7/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/usr/include/python2.7/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/usr/include/python2.7/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/usr/include/python2.7/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/usr/include/python2.7/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/usr/include/python2.7/pymactoolbox.h" "$(@D)/python_include/pymactoolbox.h" && cp "/usr/include/python2.7/pymath.h" "$(@D)/python_include/pymath.h" && cp "/usr/include/python2.7/pymem.h" "$(@D)/python_include/pymem.h" && cp "/usr/include/python2.7/pyport.h" "$(@D)/python_include/pyport.h" && cp "/usr/include/python2.7/pystate.h" "$(@D)/python_include/pystate.h" && cp "/usr/include/python2.7/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/usr/include/python2.7/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/usr/include/python2.7/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/usr/include/python2.7/pythread.h" "$(@D)/python_include/pythread.h" && cp "/usr/include/python2.7/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/usr/include/python2.7/setobject.h" "$(@D)/python_include/setobject.h" && cp "/usr/include/python2.7/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/usr/include/python2.7/stringobject.h" "$(@D)/python_include/stringobject.h" && cp "/usr/include/python2.7/structmember.h" "$(@D)/python_include/structmember.h" && cp "/usr/include/python2.7/structseq.h" "$(@D)/python_include/structseq.h" && cp "/usr/include/python2.7/symtable.h" "$(@D)/python_include/symtable.h" && cp "/usr/include/python2.7/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/usr/include/python2.7/timefuncs.h" "$(@D)/python_include/timefuncs.h" && cp "/usr/include/python2.7/token.h" "$(@D)/python_include/token.h" && cp "/usr/include/python2.7/traceback.h" "$(@D)/python_include/traceback.h" && cp "/usr/include/python2.7/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/usr/include/python2.7/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/usr/include/python2.7/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/usr/include/python2.7/warnings.h" "$(@D)/python_include/warnings.h" && cp "/usr/include/python2.7/weakrefobject.h" "$(@D)/python_include/weakrefobject.h" + """, +) + +genrule( + name = "numpy_include", + outs = [ + "numpy_include/numpy/__multiarray_api.h", + "numpy_include/numpy/__ufunc_api.h", + "numpy_include/numpy/_neighborhood_iterator_imp.h", + "numpy_include/numpy/_numpyconfig.h", + "numpy_include/numpy/arrayobject.h", + "numpy_include/numpy/arrayscalars.h", + "numpy_include/numpy/halffloat.h", + "numpy_include/numpy/multiarray_api.txt", + "numpy_include/numpy/ndarrayobject.h", + "numpy_include/numpy/ndarraytypes.h", + "numpy_include/numpy/noprefix.h", + "numpy_include/numpy/npy_1_7_deprecated_api.h", + "numpy_include/numpy/npy_3kcompat.h", + "numpy_include/numpy/npy_common.h", + "numpy_include/numpy/npy_cpu.h", + "numpy_include/numpy/npy_endian.h", + "numpy_include/numpy/npy_interrupt.h", + "numpy_include/numpy/npy_math.h", + "numpy_include/numpy/npy_no_deprecated_api.h", + "numpy_include/numpy/npy_os.h", + "numpy_include/numpy/numpyconfig.h", + "numpy_include/numpy/old_defines.h", + "numpy_include/numpy/oldnumeric.h", + "numpy_include/numpy/ufunc_api.txt", + "numpy_include/numpy/ufuncobject.h", + "numpy_include/numpy/utils.h", + ], + cmd = """ +cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/usr/local/lib/python2.7/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h" + """, +) diff --git a/tools/bazel.rc b/tools/bazel.rc index 8b8c71756171387b7a4b834ea94015a00313492e..1c1e6afb65ab8da5b689d58ecaec6ac6c8a69bb8 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -27,11 +27,14 @@ build --define framework_shared_object=true build:mkl --define=using_mkl=true build:mkl -c opt +build:download_clang --crosstool_top=@local_config_download_clang//:toolchain +build:download_clang --define=using_clang=true + build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain -build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true +build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true build:win-cuda --define=using_cuda=true --define=using_cuda_nvcc=true